/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.

    This iterator uses masks to guard out-of-bounds accesses. The first tile this
    iterator visits may be partial, while the remaining tiles are complete. Consequently,
    predicates only need to be computed twice: once before the first tile and once
    for the remaining full tiles, which can all share the same predicates.

    A precomputed "Params" object minimizes the amount of state that must be stored in registers,
    and integer addition is used to advance the pointer through memory.
*/
#pragma once

#include "cutlass/arch/memory.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace transform {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
/// PredicatedTileIterator
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
/// Regular tile iterator using a precomputed control structure to minimize register liveness
/// and integer arithmetic.
///
/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
///
/// Base pointer and tensor extents may be specified at the time the iterator is constructed.
/// Subsequently, they are assumed to be immutable.
///
/// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
/// Subsequent additions to the logical coordinate offset may be performed but are relatively expensive.
///
/// Visitation order is intended to first visit a "residual" tile that may be partially full in
/// both the advance dimension and the steady-state dimension. This is assumed to be the last
/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
/// accesses may be performed without updating internal predicates and are efficient in terms of
/// live register state and pointer arithmetic instructions.
///
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
/// outside any looping structure to minimize integer arithmetic.
///
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
/// the iterator.
///
///
/// Example:
///
/// An efficient pipeline structure may be constructed as follows:
///
// template <typename Iterator>
// __global__ void kernel(
//   typename Iterator::Params params,
//   typename Iterator::Element *ptr,
//   TensorCoord extent) {
//
//   typename Iterator::Fragment fragment;
//
//   TensorCoord threadblock_offset(0, 0);
//
//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
//
//   fragment = *iter;        // load "residue" tile first
//   ++iter;                  // advance to first "steady state" tile and update internal masks
//
//   #pragma unroll
//   for (int i = Remaining - 1; i >= 0; --i) {
//
//     f(fragment);
//
//     if (!i) {
//       iter.clear_mask();   // lightweight operation to clear masks - subsequent loads become NO-OPs.
//     }
//
//     fragment = *iter;      // load tile during "steady state" phase
//     ++iter;                // advance to next tile - lightweight due to steady-state masks
//   }
// }
//
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
//
//   using Iterator = transform::threadblock::PredicatedTileIterator;
//
//   typename Iterator::Params params(view.layout());
//
//   kernel<Iterator><<<grid, block>>>(params, view.data(), view.extent());
// }
///
///
template <
  typename Shape,
  typename Element,
  typename Layout,
  int AdvanceRank,
  typename ThreadMap,
  int AccessSize = ThreadMap::kElementsPerAccess,
  bool Gather = false
>
class PredicatedTileIterator;
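// As an illustrative sketch only (not part of this header), a concrete
// instantiation for a 128x8 pitch-linear tile of float, advancing along the
// strided dimension, might look as follows. The tile shape, thread count, and
// access width here are assumptions chosen for the example.
//
// using ExampleThreadMap = transform::PitchLinearStripminedThreadMap<
//     layout::PitchLinearShape<128, 8>,   // tile visited by the threadblock
//     256,                                // participating threads
//     4>;                                 // elements per vectorized access
//
// using ExampleIterator = transform::threadblock::PredicatedTileIterator<
//     layout::PitchLinearShape<128, 8>,   // Shape
//     float,                              // Element
//     layout::PitchLinear,                // Layout
//     1,                                  // AdvanceRank (strided dimension)
//     ExampleThreadMap>;                  // AccessSize and Gather use their defaults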
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for pitch-linear data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <typename Shape_, typename Element_, int AdvanceRank,
          typename ThreadMap_, int AccessSize, bool Gather>
class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
                             ThreadMap_, AccessSize, Gather> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::PitchLinear;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  /// Type used for internal memory accesses
  using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;

  /// Underlying iterator to compute the addresses
  using TileAccessIterator =
      PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
                                   ThreadMap, AccessType, Gather>;

  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
                                               ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename TileAccessIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   public:
    using Base = typename TileAccessIterator::Params::Base;

    friend PredicatedTileIterator;

   private:
    /// Parameters object
    typename TileAccessIterator::Params params_;

   public:
    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout) : params_(layout) {}

    /// Default constructor
    Params() = default;

    CUTLASS_HOST_DEVICE
    Params(Base const &base)
        : params_(base) {}
  };

 private:
  /// Internal pointer type permits fast address arithmetic
  using BytePointer = char *;

 private:
  //
  // Data members
  //

  /// Data member to the tile access iterator
  TileAccessIterator address_iterator_;
 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      /// Precomputed parameters object
      Params const &params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const &threadblock_offset,
      /// Gather indices
      int const *indices = nullptr)
      : address_iterator_(params.params_, pointer, extent, thread_id,
                          threadblock_offset, indices) {}

  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      Params const &params,  ///< Precomputed parameters object
      Pointer pointer,       ///< Pointer to start of tensor
      TensorCoord extent,    ///< Extent of tensor
      int thread_id          ///< ID of each participating thread
      )
      : PredicatedTileIterator(params, pointer, extent, thread_id,
                               make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    address_iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    if (kAdvanceRank)
      address_iterator_.add_tile_offset({0, 1});
    else
      address_iterator_.add_tile_offset({1, 0});

    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() { address_iterator_.enable_mask(); }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {

          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          address_iterator_.set_iteration_index(idx);
          char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;

          AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);

          cutlass::arch::global_load<AccessType,
                                     sizeof(AccessType)
                                    >(
              frag_ptr[idx], access_ptr, address_iterator_.valid());

          ++address_iterator_;
        }
      }
    }
  }
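  // For reference, the flattened access index above visits vectors fastest,
  // then the contiguous dimension, then the strided dimension. As a
  // hypothetical worked example with Iterations::kContiguous = 2,
  // Iterations::kStrided = 2, and kAccessesPerVector = 2:
  //
  //   (s, c, v): (0,0,0) (0,0,1) (0,1,0) (0,1,1) (1,0,0) (1,0,1) (1,1,0) (1,1,1)
  //   idx:          0       1       2       3       4       5       6       7
  //
  // which matches idx = v + kAccessesPerVector * (c + s * Iterations::kContiguous).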
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) { load_with_byte_offset(frag, 0); }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    address_iterator_.set_iteration_index(0);
    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {

          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
          AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);

          if (address_iterator_.valid()) {
            *access_ptr = frag_ptr[idx];
          }
          ++address_iterator_;
        }
      }
    }
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
  typename Shape_,
  typename Element_,
  int AdvanceRank,
  typename ThreadMap_,
  int AccessSize,
  bool Gather
>
class PredicatedTileIterator<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, AccessSize, Gather> {
 public:

  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
                "Specialization for pitch-linear iterator may advance along the "
                "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  using UnderlyingIterator = PredicatedTileIterator<
    layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
    Element,
    layout::PitchLinear,
    (kAdvanceRank == 0 ? 0 : 1),
    ThreadMap,
    AccessSize,
    Gather
  >;
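  // A column-major tensor is adapted by identifying its rows with the contiguous
  // dimension and its columns with the strided dimension of the underlying
  // pitch-linear iterator; the advance rank carries over unchanged.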
  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIterator;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:

    /// Default constructor
    Params() = default;

    /// Construct the Params object given a column-major tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0)))
    {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const &base)
        : params_(base) {}
  };
 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,                    ///< Precomputed parameters object
    Pointer pointer,                         ///< Pointer to start of tensor
    TensorCoord extent,                      ///< Extent of tensor
    int thread_id,                           ///< ID of each participating thread
    TensorCoord const &threadblock_offset,   ///< Initial offset of threadblock
    int const *indices = nullptr             ///< Gather indices
  ):
    iterator_(
      params.params_,
      pointer,
      layout::PitchLinearCoord(extent.row(), extent.column()),
      thread_id,
      layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()),
      indices)
  { }

  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,    ///< Precomputed parameters object
    Pointer pointer,         ///< Pointer to start of tensor
    TensorCoord extent,      ///< Extent of tensor
    int thread_id            ///< ID of each participating thread
  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) {
    iterator_.get_mask(mask);
  }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
  typename Shape_,
  typename Element_,
  int AdvanceRank,
  typename ThreadMap_,
  int AccessSize,
  bool Gather
>
class PredicatedTileIterator<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, AccessSize, Gather> {
 public:

  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
                "Specialization for pitch-linear iterator may advance along the "
                "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::RowMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  using UnderlyingIterator = PredicatedTileIterator<
    layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
    Element,
    layout::PitchLinear,
    (kAdvanceRank == 0 ? 1 : 0),
    ThreadMap,
    AccessSize,
    Gather
  >;
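  // A row-major tensor is adapted by identifying its columns with the contiguous
  // dimension and its rows with the strided dimension of the underlying
  // pitch-linear iterator, so the advance rank is swapped accordingly.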
  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIterator;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:

    /// Default constructor
    Params() = default;

    /// Construct the Params object given a row-major tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const &base)
        : params_(base) {}
  };
 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,                    ///< Precomputed parameters object
    Pointer pointer,                         ///< Pointer to start of tensor
    TensorCoord extent,                      ///< Extent of tensor
    int thread_id,                           ///< ID of each participating thread
    TensorCoord const &threadblock_offset,   ///< Initial offset of threadblock
    int const *indices = nullptr             ///< Gather indices
  ):
    iterator_(
      params.params_,
      pointer,
      layout::PitchLinearCoord(extent.column(), extent.row()),
      thread_id,
      layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()),
      indices
    ) { }

  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,    ///< Precomputed parameters object
    Pointer pointer,         ///< Pointer to start of tensor
    TensorCoord extent,      ///< Extent of tensor
    int thread_id            ///< ID of each participating thread
  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) {
    iterator_.get_mask(mask);
  }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for affine rank-2 data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <typename Shape_, typename Element_, int AdvanceRank,
          typename ThreadMap_, int AccessSize>
class PredicatedTileIterator<Shape_, Element_, layout::AffineRankN<2>, AdvanceRank,
                             ThreadMap_, AccessSize, false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");
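  // Note: an AffineRankN<2> layout carries a nontrivial stride in both dimensions,
  // so address computation is delegated entirely to the underlying
  // PredicatedTileAccessIterator, and gather/scatter is not supported here
  // (the Gather template argument is fixed to false).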
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRankN<2>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  /// Type used for internal memory accesses
  using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;

  /// Underlying iterator to compute the addresses
  using TileAccessIterator =
      PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
                                   ThreadMap, AccessType>;

  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
                                               ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename TileAccessIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   public:
    friend PredicatedTileIterator;

   private:
    /// Parameters object
    typename TileAccessIterator::Params params_;

   public:
    /// Construct the Params object given an affine rank-2 tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout) : params_(layout) {}

    /// Default constructor
    Params() = default;
  };

 private:
  /// Internal pointer type permits fast address arithmetic
  using BytePointer = char *;

 private:
  //
  // Data members
  //

  /// Data member to the tile access iterator
  TileAccessIterator address_iterator_;
 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      /// Precomputed parameters object
      Params const &params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const &threadblock_offset,
      int const *indices = nullptr  ///< gather/scatter indices - not supported by this specialization
      )
      : address_iterator_(params.params_, pointer, extent, thread_id,
                          threadblock_offset) {}

  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      Params const &params,  ///< Precomputed parameters object
      Pointer pointer,       ///< Pointer to start of tensor
      TensorCoord extent,    ///< Extent of tensor
      int thread_id          ///< ID of each participating thread
      )
      : PredicatedTileIterator(params, pointer, extent, thread_id,
                               make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    address_iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    if (kAdvanceRank)
      address_iterator_.add_tile_offset(make_Coord(0, 1));
    else
      address_iterator_.add_tile_offset(make_Coord(1, 0));

    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() { address_iterator_.enable_mask(); }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {

          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          address_iterator_.set_iteration_index(idx);
          char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;

          AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);

          cutlass::arch::global_load<AccessType,
                                     sizeof(AccessType)
                                    >(
              frag_ptr[idx], access_ptr, address_iterator_.valid());

          ++address_iterator_;
        }
      }
    }
  }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) { load_with_byte_offset(frag, 0); }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    address_iterator_.set_iteration_index(0);
    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {

          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
          AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);

          if (address_iterator_.valid()) {
            *access_ptr = frag_ptr[idx];
          }
          ++address_iterator_;
        }
      }
    }
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for affine rank-2 column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
  typename Shape_,
  typename Element_,
  int AdvanceRank,
  typename ThreadMap_,
  int AccessSize
>
class PredicatedTileIterator<Shape_, Element_, layout::AffineRank2ColumnMajor, AdvanceRank, ThreadMap_, AccessSize, false> {
 public:

  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
                "Specialization for pitch-linear iterator may advance along the "
                "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRank2ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  // Map to the underlying AffineRankN<2> layout
  using UnderlyingIterator = PredicatedTileIterator<
    layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
    Element,
    layout::AffineRankN<2>,
    (kAdvanceRank == 0 ? 0 : 1),
    ThreadMap,
    AccessSize
  >;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIterator;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:

    /// Default constructor
    Params() = default;

    /// Construct the Params object given an AffineRankN<2> tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1)))
    {}
  };
 private:
  //
  // Data members
  //

  /// Underlying AffineRankN<2> tile iterator
  UnderlyingIterator iterator_;

 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,                    ///< Precomputed parameters object
    Pointer pointer,                         ///< Pointer to start of tensor
    TensorCoord extent,                      ///< Extent of tensor
    int thread_id,                           ///< ID of each participating thread
    TensorCoord const &threadblock_offset,   ///< Initial offset of threadblock
    int const *indices = nullptr             ///< gather/scatter indices - not supported by this specialization
  ):
    iterator_(
      params.params_,
      pointer,
      layout::PitchLinearCoord(extent.row(), extent.column()),
      thread_id,
      layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
    ) { }

  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,    ///< Precomputed parameters object
    Pointer pointer,         ///< Pointer to start of tensor
    TensorCoord extent,      ///< Extent of tensor
    int thread_id            ///< ID of each participating thread
  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) {
    iterator_.get_mask(mask);
  }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for affine rank-2 row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
  typename Shape_,
  typename Element_,
  int AdvanceRank,
  typename ThreadMap_,
  int AccessSize
>
class PredicatedTileIterator<Shape_, Element_, layout::AffineRank2RowMajor, AdvanceRank, ThreadMap_, AccessSize, false> {
 public:

  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
                "Specialization for pitch-linear iterator may advance along the "
                "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRank2RowMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  // Map to the underlying AffineRankN<2> layout
  using UnderlyingIterator = PredicatedTileIterator<
    layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
    Element,
    layout::AffineRankN<2>,
    (kAdvanceRank == 0 ? 1 : 0),
    ThreadMap,
    AccessSize
  >;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIterator;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:

    /// Default constructor
    Params() = default;

    /// Construct the Params object given an AffineRankN<2> tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
  };
 private:
  //
  // Data members
  //

  /// Underlying AffineRankN<2> tile iterator
  UnderlyingIterator iterator_;

 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,                    ///< Precomputed parameters object
    Pointer pointer,                         ///< Pointer to start of tensor
    TensorCoord extent,                      ///< Extent of tensor
    int thread_id,                           ///< ID of each participating thread
    TensorCoord const &threadblock_offset,   ///< Initial offset of threadblock
    int const *indices = nullptr             ///< gather/scatter indices - not supported by this specialization
  ):
    iterator_(
      params.params_,
      pointer,
      layout::PitchLinearCoord(extent.column(), extent.row()),
      thread_id,
      layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
    ) { }

  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
    Params const &params,    ///< Precomputed parameters object
    Pointer pointer,         ///< Pointer to start of tensor
    TensorCoord extent,      ///< Extent of tensor
    int thread_id            ///< ID of each participating thread
  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the iterator's
  /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
  /// are lightweight and need only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) {
    iterator_.get_mask(mask);
  }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for interleaved data. It is mapped
/// to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <typename Shape_, typename Element_, int AdvanceRank,
          typename ThreadMap_, int AccessSize, int InterleavedK>
class PredicatedTileIterator<Shape_, Element_,
                             layout::ColumnMajorInterleaved<InterleavedK>,
                             AdvanceRank, ThreadMap_, AccessSize, false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  static int const kInterleavedK = InterleavedK;
  using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  using UnderlyingIterator = PredicatedTileIterator<
      layout::PitchLinearShape<Shape::kRow * kInterleavedK,
                               Shape::kColumn / kInterleavedK>,
      Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>;
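  // The interleaved tensor is viewed as pitch-linear by folding kInterleavedK
  // columns into the contiguous dimension. As a hypothetical worked example, a
  // 64x32 tile over ColumnMajorInterleaved<4> maps to a pitch-linear tile of
  // extent (64 * 4) x (32 / 4) = 256 x 8; extents and threadblock offsets are
  // scaled the same way in the constructor below.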
  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
                                               ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIterator;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:

    /// Default constructor
    Params() = default;

    /// Construct the Params object given an interleaved tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const &base)
        : params_(base) {}
  };
 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      /// Precomputed parameters object
      Params const &params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const &threadblock_offset,
      int const *indices = nullptr  ///< gather/scatter indices - not supported by this specialization
      )
      : iterator_(params.params_, pointer,
                  layout::PitchLinearCoord(extent.row() * kInterleavedK,
                                           extent.column() / kInterleavedK),
                  thread_id,
                  layout::PitchLinearCoord(
                      threadblock_offset.row() * kInterleavedK,
                      threadblock_offset.column() / kInterleavedK)) {}
  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      Params const &params,  ///< Precomputed parameters object
      Pointer pointer,       ///< Pointer to start of tensor
      TensorCoord extent,    ///< Extent of tensor
      int thread_id          ///< ID of each participating thread
      )
      : PredicatedTileIterator(params, pointer, extent, thread_id,
                               make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() { iterator_.enable_mask(); }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator for interleaved data. It is
/// mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <typename Shape_, typename Element_, int AdvanceRank,
          typename ThreadMap_, int AccessSize, int InterleavedK>
class PredicatedTileIterator<Shape_, Element_,
                             layout::RowMajorInterleaved<InterleavedK>,
                             AdvanceRank, ThreadMap_, AccessSize, false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  static int const kInterleavedK = InterleavedK;
  using Layout = layout::RowMajorInterleaved<kInterleavedK>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  using UnderlyingIterator = PredicatedTileIterator<
      layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
                               Shape::kRow / kInterleavedK>,
      Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>;
  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
                                               ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIterator;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:

    /// Default constructor
    Params() = default;

    /// Construct the Params object given an interleaved tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const &base)
        : params_(base) {}
  };
 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:

  /// Default constructor
  PredicatedTileIterator() = default;

  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      /// Precomputed parameters object
      Params const &params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const &threadblock_offset,
      int const *indices = nullptr  ///< gather/scatter indices - not supported by this specialization
      )
      : iterator_(params.params_, pointer,
                  layout::PitchLinearCoord(extent.column() * kInterleavedK,
                                           extent.row() / kInterleavedK),
                  thread_id,
                  layout::PitchLinearCoord(
                      threadblock_offset.column() * kInterleavedK,
                      threadblock_offset.row() / kInterleavedK)) {}
  /// Construct a PredicatedTileIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator(
      Params const &params,  ///< Precomputed parameters object
      Pointer pointer,       ///< Pointer to start of tensor
      TensorCoord extent,    ///< Extent of tensor
      int thread_id          ///< ID of each participating thread
      )
      : PredicatedTileIterator(params, pointer, extent, thread_id,
                               make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }
  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator operator++(int) {
    PredicatedTileIterator self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() { iterator_.enable_mask(); }

  /// Sets the predicate mask, overriding the value stored in the predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace transform
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////