Add residual support for shmem staging iterator used in back-to-back GEMM fusion. This allows support of problem_size_0_n that is not multiple of 32. (#590)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Haicheng Wu
2022-08-15 11:19:24 -04:00
committed by GitHub
parent e66bfcb1f8
commit 497b499d9d
8 changed files with 462 additions and 35 deletions

View File

@ -0,0 +1,362 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/// Tile access iterator
/// Each iteration acess in the tile is
/// used as multiplicand for one
/// warp-level matrix multiplication
template <
/// Size of the tile (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand_,
/// Data type of A elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Shape of one matrix production operation (concept: MatrixShape)
typename InstructionShape_,
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
int OpDelta_,
/// Number of threads participating in one matrix operation
int Threads = 32,
/// Enable Residual Support
bool EnableResidual = false,
/// Number of partitions along K dimension
int PartitionsK_ = 1
>
class MmaTensorOpMultiplicandTileAccessIterator {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Operand tag
static Operand const kOperand = Operand_;
/// Basic check
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = Layout_;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = InstructionShape_;
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
static int const kOpDelta = OpDelta_;
/// Number of participating threads
static int const kThreads = 32;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Number of elements accessed per Shared Memory load
static int const kElementsPerAccess =
(sizeof_bits<Element>::value >= 32 ? 1 : 32 / sizeof_bits<Element>::value);
using InstructionCount = MatrixShape<
Shape::kRow / InstructionShape::kRow,
Shape::kColumn / InstructionShape::kColumn
>;
static int const kIterations = (kOperand == Operand::kA) ?
InstructionCount::kColumn : InstructionCount::kRow;
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<
Element,
(kOperand == Operand::kA) ?
(Shape::kRow * InstructionShape::kColumn / kThreads) :
(Shape::kColumn * InstructionShape::kRow / kThreads)
>;
/// Memory access type
using AccessType = AlignedArray<Element, kElementsPerAccess>;
private:
/// Underlying tensor reference
TensorRef ref_;
/// Extent of tensor
MatrixCoord extent_;
/// Origin
MatrixCoord origin_;
/// Used to load residual tile
bool is_residual_;
/// residual offset of each thread
TensorCoord residual_offset_;
/// Iterations in a tile
int iterations_;
public:
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileAccessIterator(
TensorRef const &ref,
TensorCoord extent,
int lane_id
): ref_(ref), extent_(extent), is_residual_(false), iterations_(0) {
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
}
else {
origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
}
ref_.add_coord_offset(origin_);
if(EnableResidual) {
// compute residual offset
if (kOperand == Operand::kA) {
typename TensorCoord::Index residual_size =
extent_.column() % Shape::kColumn;
if(residual_size) {
is_residual_ = true;
residual_offset_ = make_Coord(0, residual_size);
}
}
else {
typename TensorCoord::Index residual_size =
extent_.row() % Shape::kRow;
if(residual_size) {
is_residual_ = true;
residual_offset_ = make_Coord(residual_size, 0);
}
}
}
}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileAccessIterator(
TensorRef const &ref,
int lane_id
): MmaTensorOpMultiplicandTileAccessIterator(ref,
{Shape::kRow, Shape::kColumn}, lane_id) {
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileAccessIterator &add_tile_offset(TensorCoord const &tile_offset) {
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
origin_ += coord_offset;
ref_.add_coord_offset(coord_offset);
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
void advance() {
if(EnableResidual && is_residual_) {
is_residual_ = false;
origin_ += residual_offset_;
ref_.add_coord_offset(residual_offset_);
}
else {
if (kOperand == Operand::kA) {
add_tile_offset({0, 1});
}
else {
add_tile_offset({1, 0});
}
}
iterations_ = 0;
}
/// increase iterations in a tile
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileAccessIterator & operator++() {
iterations_++;
if(iterations_ >= kIterations)
advance();
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
int const kWarpShapeDivisibleInner =
(kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow);
// Take advantage of Tensor Op's 8 x 4T access pattern
int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
AccessType *access_ptr = reinterpret_cast<AccessType *>(&frag);
if (kOperand == Operand::kA) {
int const kTilesPerInstruction = InstructionShape::kRow / 8;
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) {
int access_idx =
access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kColumn);
MatrixCoord access_coord = origin_ + offset;
// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) {
access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
ref_.data() + ref_.offset(offset));
// }
// else {
// AccessType zero;
// zero.clear();
// access_ptr[access_idx] = zero;
// }
}
}
}
}
else {
CUTLASS_PRAGMA_UNROLL
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
MatrixCoord offset(
inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kRow,
inst_n_idx * 8);
MatrixCoord access_coord = origin_ + offset;
// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) {
access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
ref_.data() + ref_.offset(offset));
// }
// else {
// AccessType zero;
// zero.clear();
// access_ptr[access_idx] = zero;
// }
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -56,8 +56,20 @@ namespace threadblock {
/// PredicatedVectorAccessIterator
///
template <typename Shape, typename WarpShape,
typename Element, typename Layout, int ElementsPerAccess>
template <
/// Shape of the vector accessed by the entire threadblock
typename Shape,
/// Shape of the vector accessed by the warp
typename WarpShape,
/// Type of Element
typename Element,
/// Layout of the vector
typename Layout,
/// Number of elements for each access
int ElementsPerAccess,
/// Support residual tile
bool EnableResidualAccess = false
>
class PredicatedVectorAccessIterator;
////////////////////////////////////////////////////////////////////////////////
@ -65,8 +77,21 @@ class PredicatedVectorAccessIterator;
/// Vector access iterator specialized for vectors, e.g. scale and bias
/// Thread arrangements are for TensorOps
///
template <typename Shape_, typename WarpShape_, typename Element_, int ElementsPerAccess>
class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::PitchLinear, ElementsPerAccess> {
template <
typename Shape_,
typename WarpShape_,
typename Element_,
int ElementsPerAccess,
bool EnableResidualAccess
>
class PredicatedVectorAccessIterator <
Shape_,
WarpShape_,
Element_,
layout::PitchLinear,
ElementsPerAccess,
EnableResidualAccess
> {
public:
using Shape = Shape_;
@ -116,6 +141,12 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
/// iteration index
LongIndex iteration_;
/// residual access
bool is_residual_;
/// residual offset of each thread
TensorCoord residual_offset_;
public:
/// Constructs a vector access iterator
CUTLASS_HOST_DEVICE
@ -132,7 +163,8 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
TensorCoord const &threadblock_offset)
: pointer_(reinterpret_cast<BytePointer>(
const_cast<NonConstPointer>(pointer))),
extent_(extent) {
extent_(extent),
is_residual_(false) {
int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous;
@ -143,6 +175,15 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0);
set_iteration_index(0);
if(EnableResidualAccess) {
// compute residual offset
typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous;
if (residual_size) {
is_residual_ = true;
residual_offset_ = make_Coord(residual_size, 0);
}
}
}
/// Construct a PredicatedVectorAccessIterator with zero threadblock offset
@ -170,6 +211,7 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
CUTLASS_DEVICE
void add_tile_offset(
TensorCoord const &tile_offset) {
thread_offset_ =
thread_offset_ +
TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0);
@ -198,7 +240,12 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
/// Increment and return an instance to self.
CUTLASS_HOST_DEVICE
void advance() {
add_tile_offset(TensorCoord(1, 0));
if(EnableResidualAccess && is_residual_) {
is_residual_ = false;
thread_offset_ += residual_offset_;
}
else
add_tile_offset(TensorCoord(1, 0));
}
/// Increment and return an instance to self.
@ -221,8 +268,21 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
/// Specialization of PredicatedVectorAccessIterator for row-major data.
///
template <typename Shape_, typename WarpShape_, typename Element_, int ElementsPerAccess>
class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::RowMajor, ElementsPerAccess> {
template <
typename Shape_,
typename WarpShape_,
typename Element_,
int ElementsPerAccess,
bool EnableResidualAccess
>
class PredicatedVectorAccessIterator<
Shape_,
WarpShape_,
Element_,
layout::RowMajor,
ElementsPerAccess,
EnableResidualAccess
> {
public:
using Shape = Shape_;
@ -245,7 +305,8 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::RowMa
layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
Element,
layout::PitchLinear,
ElementsPerAccess>;
ElementsPerAccess,
EnableResidualAccess>;
using AccessType = typename UnderlyingIterator::AccessType;
static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;