CUTLASS 2.1 (#83)

CUTLASS 2.1 contributes:
- BLAS-style host-side API added to CUTLASS Library
- Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
- Minor enhancements and bug fixes
This commit is contained in:
Andrew Kerr
2020-04-07 13:51:25 -07:00
committed by GitHub
parent 7c0cd26d13
commit 96dab34ad9
196 changed files with 20653 additions and 1995 deletions

View File

@ -128,13 +128,13 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
int stride_;
/// amount (in byte) to increment pointer to move to next access along
/// strided dimension
int inc_strided_;
LongIndex inc_strided_;
/// amount (in byte) to increment pointer from last access to first access
/// of next tile
int inc_next_;
LongIndex inc_next_;
/// amount (in byte) to increment pointer from first access of current tile
/// to first access of next tile
int inc_advance_;
LongIndex inc_advance_;
public:
@ -145,20 +145,20 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const &layout) : stride_(layout.stride(0)) {
inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
sizeof_bits<Element>::value / 8;
if (kAdvanceRank) {
// advance along strided dimension
inc_advance_ =
Shape::kStrided * stride_ * sizeof_bits<Element>::value / 8;
Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
} else {
// advance along contiguous dimension
inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
}
inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
ThreadMap::Delta::kStrided * stride_ *
inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
ThreadMap::Delta::kStrided * LongIndex(stride_) *
sizeof_bits<Element>::value / 8;
};
};
@ -280,7 +280,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
TensorCoord residue_extent;
if (kAdvanceRank) {
Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
if (!residue_size) {
residue_size = Shape::kStrided;
}
@ -288,18 +288,19 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
residue_offset_ = make_Coord(0, residue_size);
residue_extent = make_Coord(
extent_.contiguous(),
min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
min(threadblock_offset.strided() + residue_size, extent_.strided())
);
} else {
Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
if (!residue_size) {
residue_size = Shape::kContiguous;
}
residue_offset_ = make_Coord(residue_size, 0);
residue_extent = make_Coord(
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
extent_.strided()
);
}
@ -362,18 +363,18 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
compute_predicates_(extent_, true);
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
pointer_ += Shape::kContiguous * tile_offset.contiguous();
} else {
pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
pointer_ += Shape::kStrided * tile_offset.strided();
}
} else {
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * tile_offset.strided();
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
pointer_ += Shape::kContiguous * tile_offset.contiguous();
} else {
pointer_ += params_.inc_advance_ * tile_offset.contiguous();
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
pointer_ += Shape::kStrided * tile_offset.strided();
}
}

View File

@ -296,7 +296,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
@ -310,10 +315,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
address_iterator_.set_iteration_index(idx);
auto ptr = (address_iterator_.get() + pointer_offset);
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
if (address_iterator_.valid()) {
frag_ptr[idx] = *ptr;
frag_ptr[idx] = *access_ptr;
}
++address_iterator_;
}
@ -323,11 +330,17 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
@ -340,8 +353,11 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
if (address_iterator_.valid()) {
*(address_iterator_.get() + pointer_offset) = frag_ptr[idx];
*access_ptr = frag_ptr[idx];
}
++address_iterator_;
}
@ -351,7 +367,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
@ -528,6 +544,12 @@ public:
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
@ -540,6 +562,12 @@ public:
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) {
@ -721,6 +749,12 @@ public:
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
@ -732,6 +766,12 @@ public:
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE

View File

@ -149,6 +149,12 @@ public:
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, Index byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
@ -157,7 +163,11 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int access_idx = c + s * ThreadMap::Iterations::kContiguous;
frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset);
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
frag_ptr[access_idx] = *access_ptr;
++address_iterator_;
}
}
@ -172,6 +182,11 @@ public:
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, Index byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
@ -180,7 +195,11 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int access_idx = c + s * ThreadMap::Iterations::kContiguous;
*(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
*access_ptr = frag_ptr[access_idx];
++address_iterator_;
}
}
@ -189,7 +208,7 @@ public:
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
store_with_byte_offset(frag, 0);
}
};
@ -567,6 +586,11 @@ class RegularTileIterator<Shape_, Element_,
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, Index byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
@ -575,7 +599,11 @@ class RegularTileIterator<Shape_, Element_,
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int access_idx = c + s * ThreadMap::Iterations::kContiguous;
*(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
*access_ptr = frag_ptr[access_idx];
++address_iterator_;
}
}
@ -806,3 +834,5 @@ class RegularTileIterator<Shape_, Element_,
} // namespace threadblock
} // namespace transform
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1146,7 +1146,7 @@ class RegularTileIterator<
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess;
Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
@ -1185,13 +1185,14 @@ class RegularTileIterator<
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess;
Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)];
access_ptr += 16 * (s / 2);
access_ptr += 16 * (s / 2) + vec_pointer_offset;
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
@ -1199,8 +1200,7 @@ class RegularTileIterator<
for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) {
int access_offset =
c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size +
vec_pointer_offset + i * line_size;
c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size;
int access_idx = (c + s * ThreadMap::Iterations::kContiguous) *
Detail::kIterarionsPerAccess + i;