Add residual support for the shmem staging iterator used in back-to-back GEMM fusion. This allows problem_size_0_n values that are not a multiple of 32. (#590)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
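For context, a minimal sketch of how the new EnableResidual flag introduced below might be supplied when instantiating the warp-level iterator. The tile shape, element type, and instruction shape here are illustrative assumptions, not values prescribed by this commit:

    // Hypothetical instantiation (shapes and element type are assumptions):
    // a warp-level B-operand iterator with residual-tile handling enabled.
    using WarpIterator = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
        cutlass::MatrixShape<32, 32>,   // Shape of the tile staged in shared memory
        cutlass::gemm::Operand::kB,     // Operand identity
        cutlass::half_t,                // Element type
        cutlass::layout::RowMajor,      // Layout of the operand
        cutlass::MatrixShape<16, 8>,    // Shape of one MMA instruction
        1,                              // Delta between MMA operations
        32,                             // Threads per warp
        true                            // EnableResidual: tolerate partial tiles
    >;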
include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h (new file, 362 lines)

@@ -0,0 +1,362 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"

#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"

#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/// Tile access iterator
/// Each access in the tile is
/// used as a multiplicand for one
/// warp-level matrix multiplication
template <
    /// Size of the tile (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand_,
    /// Data type of elements
    typename Element_,
    /// Layout of operand
    typename Layout_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    int OpDelta_,
    /// Number of threads participating in one matrix operation
    int Threads = 32,
    /// Enable residual support
    bool EnableResidual = false,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1
>
class MmaTensorOpMultiplicandTileAccessIterator {
 public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand_;

  /// Basic check
  static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
    "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");

  /// Element type
  using Element = Element_;

  /// Layout of source tile
  using Layout = Layout_;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;

  /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
  static int const kOpDelta = OpDelta_;

  /// Number of participating threads
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Number of elements accessed per Shared Memory load
  static int const kElementsPerAccess =
      (sizeof_bits<Element>::value >= 32 ? 1 : 32 / sizeof_bits<Element>::value);

  using InstructionCount = MatrixShape<
    Shape::kRow / InstructionShape::kRow,
    Shape::kColumn / InstructionShape::kColumn
  >;

  static int const kIterations = (kOperand == Operand::kA) ?
    InstructionCount::kColumn : InstructionCount::kRow;

 public:

  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<
    Element,
    (kOperand == Operand::kA) ?
      (Shape::kRow * InstructionShape::kColumn / kThreads) :
      (Shape::kColumn * InstructionShape::kRow / kThreads)
  >;

  /// Memory access type
  using AccessType = AlignedArray<Element, kElementsPerAccess>;

 private:

  /// Underlying tensor reference
  TensorRef ref_;

  /// Extent of tensor
  MatrixCoord extent_;

  /// Origin
  MatrixCoord origin_;

  /// Used to load residual tile
  bool is_residual_;

  /// residual offset of each thread
  TensorCoord residual_offset_;

  /// Iterations in a tile
  int iterations_;

 public:

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaTensorOpMultiplicandTileAccessIterator(
    TensorRef const &ref,
    TensorCoord extent,
    int lane_id
  ): ref_(ref), extent_(extent), is_residual_(false), iterations_(0) {

    // Map the lane to its starting element within the Tensor Op's 8 x 4-thread
    // access pattern (8 rows of 4 threads for A; transposed for B).
    if (kOperand == Operand::kA) {
      origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
    }
    else {
      origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
    }

    ref_.add_coord_offset(origin_);

    if(EnableResidual) {
      // compute residual offset
      if (kOperand == Operand::kA) {
        typename TensorCoord::Index residual_size =
          extent_.column() % Shape::kColumn;
        if(residual_size) {
          is_residual_ = true;
          residual_offset_ = make_Coord(0, residual_size);
        }
      }
      else {
        typename TensorCoord::Index residual_size =
          extent_.row() % Shape::kRow;
        if(residual_size) {
          is_residual_ = true;
          residual_offset_ = make_Coord(residual_size, 0);
        }
      }
    }
  }

  /// Constructor from TensorRef; assumes the full tile shape as the extent
  CUTLASS_HOST_DEVICE
  MmaTensorOpMultiplicandTileAccessIterator(
    TensorRef const &ref,
    int lane_id
  ): MmaTensorOpMultiplicandTileAccessIterator(ref,
    {Shape::kRow, Shape::kColumn}, lane_id) {
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaTensorOpMultiplicandTileAccessIterator &add_tile_offset(TensorCoord const &tile_offset) {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
    origin_ += coord_offset;

    ref_.add_coord_offset(coord_offset);

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  void advance() {

    // Consume the residual (partial) tile first, if present; afterwards step
    // through whole tiles along the advance dimension.
    if(EnableResidual && is_residual_) {
      is_residual_ = false;

      origin_ += residual_offset_;
      ref_.add_coord_offset(residual_offset_);
    }
    else {
      if (kOperand == Operand::kA) {
        add_tile_offset({0, 1});
      }
      else {
        add_tile_offset({1, 0});
      }
    }

    iterations_ = 0;
  }

  /// Increments the iteration index within a tile; advances to the next tile
  /// once all iterations have been consumed
  CUTLASS_HOST_DEVICE
  MmaTensorOpMultiplicandTileAccessIterator & operator++() {

    iterations_++;

    if(iterations_ >= kIterations)
      advance();

    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {

    int const kWarpShapeDivisibleInner =
      (kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow);

    // Take advantage of Tensor Op's 8 x 4T access pattern
    int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;

    AccessType *access_ptr = reinterpret_cast<AccessType *>(&frag);

    if (kOperand == Operand::kA) {
      int const kTilesPerInstruction = InstructionShape::kRow / 8;

      CUTLASS_PRAGMA_UNROLL
      for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) {

        CUTLASS_PRAGMA_UNROLL
        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {

          CUTLASS_PRAGMA_UNROLL
          for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) {
            int access_idx =
              access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx);

            MatrixCoord offset(
              access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
              inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kColumn);

            MatrixCoord access_coord = origin_ + offset;

            // if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) {

              access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
                ref_.data() + ref_.offset(offset));
            // }
            // else {
            //   AccessType zero;
            //   zero.clear();
            //   access_ptr[access_idx] = zero;
            // }
          }
        }
      }
    }
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) {

        CUTLASS_PRAGMA_UNROLL
        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
          int access_idx = inner_idx + kAccessesInner * inst_n_idx;

          MatrixCoord offset(
            inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kRow,
            inst_n_idx * 8);

          MatrixCoord access_coord = origin_ + offset;

          // if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) {

            access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
              ref_.data() + ref_.offset(offset));
          // }
          // else {
          //   AccessType zero;
          //   zero.clear();
          //   access_ptr[access_idx] = zero;
          // }
        }
      }
    }
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
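To make the residual flow concrete: when the extent along the iteration dimension is not a multiple of the tile shape, the constructor records residual_offset_, and the first wrap-around advance() consumes the partial tile before whole-tile stepping resumes. A hypothetical driver loop, reusing the WarpIterator alias sketched earlier (all names here are illustrative, not from this commit):

    WarpIterator warp_iter(ref, extent, lane_id);  // e.g. extent.row() % Shape::kRow != 0 for B
    WarpIterator::Fragment frag;
    for (int i = 0; i < num_tiles * WarpIterator::kIterations; ++i) {
      warp_iter.load(frag);  // load one instruction tile's worth of operands
      ++warp_iter;           // fires advance() every kIterations steps; the first
                             // advance() applies residual_offset_ instead of a whole tile
      // ... issue the warp-level MMA on frag ...
    }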

@@ -56,8 +56,20 @@ namespace threadblock {
 
 /// PredicatedVectorAccessIterator
 ///
-template <typename Shape, typename WarpShape,
-    typename Element, typename Layout, int ElementsPerAccess>
+template <
+    /// Shape of the vector accessed by the entire threadblock
+    typename Shape,
+    /// Shape of the vector accessed by the warp
+    typename WarpShape,
+    /// Type of Element
+    typename Element,
+    /// Layout of the vector
+    typename Layout,
+    /// Number of elements for each access
+    int ElementsPerAccess,
+    /// Support residual tile
+    bool EnableResidualAccess = false
+>
 class PredicatedVectorAccessIterator;
 
 ////////////////////////////////////////////////////////////////////////////////

@@ -65,8 +77,21 @@ class PredicatedVectorAccessIterator;
 /// Vector access iterator specialized for vectors, e.g. scale and bias
 /// Thread arrangements are for TensorOps
 ///
-template <typename Shape_, typename WarpShape_, typename Element_, int ElementsPerAccess>
-class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::PitchLinear, ElementsPerAccess> {
+template <
+    typename Shape_,
+    typename WarpShape_,
+    typename Element_,
+    int ElementsPerAccess,
+    bool EnableResidualAccess
+>
+class PredicatedVectorAccessIterator <
+    Shape_,
+    WarpShape_,
+    Element_,
+    layout::PitchLinear,
+    ElementsPerAccess,
+    EnableResidualAccess
+> {
  public:
 
   using Shape = Shape_;

@@ -116,6 +141,12 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
   /// iteration index
   LongIndex iteration_;
 
+  /// residual access
+  bool is_residual_;
+
+  /// residual offset of each thread
+  TensorCoord residual_offset_;
+
  public:
   /// Constructs a vector access iterator
   CUTLASS_HOST_DEVICE

@@ -132,7 +163,8 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
       TensorCoord const &threadblock_offset)
       : pointer_(reinterpret_cast<BytePointer>(
             const_cast<NonConstPointer>(pointer))),
-        extent_(extent) {
+        extent_(extent),
+        is_residual_(false) {
 
 
     int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous;

@@ -143,6 +175,15 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
         TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0);
 
     set_iteration_index(0);
+
+    if(EnableResidualAccess) {
+      // compute residual offset
+      typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous;
+      if (residual_size) {
+        is_residual_ = true;
+        residual_offset_ = make_Coord(residual_size, 0);
+      }
+    }
   }
 
   /// Construct a PredicatedVectorAccessIterator with zero threadblock offset

@@ -170,6 +211,7 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
   CUTLASS_DEVICE
   void add_tile_offset(
       TensorCoord const &tile_offset) {
+
     thread_offset_ =
         thread_offset_ +
         TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0);

@@ -198,7 +240,12 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
   /// Increment and return an instance to self.
   CUTLASS_HOST_DEVICE
   void advance() {
-    add_tile_offset(TensorCoord(1, 0));
+    if(EnableResidualAccess && is_residual_) {
+      is_residual_ = false;
+      thread_offset_ += residual_offset_;
+    }
+    else
+      add_tile_offset(TensorCoord(1, 0));
   }
 
   /// Increment and return an instance to self.

@@ -221,8 +268,21 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::Pitch
 
 /// Specialization of PredicatedVectorAccessIterator for row-major data.
 ///
-template <typename Shape_, typename WarpShape_, typename Element_, int ElementsPerAccess>
-class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::RowMajor, ElementsPerAccess> {
+template <
+    typename Shape_,
+    typename WarpShape_,
+    typename Element_,
+    int ElementsPerAccess,
+    bool EnableResidualAccess
+>
+class PredicatedVectorAccessIterator<
+    Shape_,
+    WarpShape_,
+    Element_,
+    layout::RowMajor,
+    ElementsPerAccess,
+    EnableResidualAccess
+> {
  public:
 
   using Shape = Shape_;

@@ -245,7 +305,8 @@ class PredicatedVectorAccessIterator<Shape_, WarpShape_, Element_, layout::RowMa
       layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
       Element,
       layout::PitchLinear,
-      ElementsPerAccess>;
+      ElementsPerAccess,
+      EnableResidualAccess>;
 
   using AccessType = typename UnderlyingIterator::AccessType;
   static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
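Analogously, callers would pass the new trailing EnableResidualAccess argument to PredicatedVectorAccessIterator. The enclosing namespace is not shown in this diff, and the shapes below are illustrative assumptions:

    // Hypothetical instantiation of the extended vector iterator
    // (e.g. for a scale/bias vector), with residual access enabled.
    using VectorIterator = PredicatedVectorAccessIterator<
        cutlass::layout::PitchLinearShape<128, 1>,  // Shape seen by the threadblock
        cutlass::layout::PitchLinearShape<64, 1>,   // Shape seen by one warp
        cutlass::half_t,                            // Element
        cutlass::layout::PitchLinear,               // Layout
        8,                                          // ElementsPerAccess
        true                                        // EnableResidualAccess
    >;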