CUTLASS 2.1 (#83)
CUTLASS 2.1 contributes: - BLAS-style host-side API added to CUTLASS Library - Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores - Minor enhancements and bug fixes
This commit is contained in:
@ -128,13 +128,13 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
int stride_;
|
||||
/// amount (in byte) to increment pointer to move to next access along
|
||||
/// strided dimension
|
||||
int inc_strided_;
|
||||
LongIndex inc_strided_;
|
||||
/// amount (in byte) to increment pointer from last access to first access
|
||||
/// of next tile
|
||||
int inc_next_;
|
||||
LongIndex inc_next_;
|
||||
/// amount (in byte) to increment pointer from first access of current tile
|
||||
/// to first access of next tile
|
||||
int inc_advance_;
|
||||
LongIndex inc_advance_;
|
||||
|
||||
public:
|
||||
|
||||
@ -145,20 +145,20 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
/// Construct the Params object given a pitch-linear tensor's layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const &layout) : stride_(layout.stride(0)) {
|
||||
inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
|
||||
inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
|
||||
sizeof_bits<Element>::value / 8;
|
||||
|
||||
if (kAdvanceRank) {
|
||||
// advance along strided dimension
|
||||
inc_advance_ =
|
||||
Shape::kStrided * stride_ * sizeof_bits<Element>::value / 8;
|
||||
Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
|
||||
} else {
|
||||
// advance along contiguous dimension
|
||||
inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
|
||||
ThreadMap::Delta::kStrided * stride_ *
|
||||
inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
|
||||
ThreadMap::Delta::kStrided * LongIndex(stride_) *
|
||||
sizeof_bits<Element>::value / 8;
|
||||
};
|
||||
};
|
||||
@ -280,7 +280,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
TensorCoord residue_extent;
|
||||
if (kAdvanceRank) {
|
||||
|
||||
Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
|
||||
Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
|
||||
if (!residue_size) {
|
||||
residue_size = Shape::kStrided;
|
||||
}
|
||||
@ -288,18 +288,19 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
residue_offset_ = make_Coord(0, residue_size);
|
||||
residue_extent = make_Coord(
|
||||
extent_.contiguous(),
|
||||
min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
|
||||
min(threadblock_offset.strided() + residue_size, extent_.strided())
|
||||
);
|
||||
|
||||
} else {
|
||||
|
||||
Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
|
||||
Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
|
||||
if (!residue_size) {
|
||||
residue_size = Shape::kContiguous;
|
||||
}
|
||||
|
||||
residue_offset_ = make_Coord(residue_size, 0);
|
||||
|
||||
residue_extent = make_Coord(
|
||||
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
|
||||
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
|
||||
extent_.strided()
|
||||
);
|
||||
}
|
||||
@ -362,18 +363,18 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
compute_predicates_(extent_, true);
|
||||
|
||||
if (kAdvanceRank) {
|
||||
pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
|
||||
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
|
||||
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
||||
} else {
|
||||
pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
|
||||
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
|
||||
pointer_ += Shape::kStrided * tile_offset.strided();
|
||||
}
|
||||
} else {
|
||||
if (kAdvanceRank) {
|
||||
pointer_ += params_.inc_advance_ * tile_offset.strided();
|
||||
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
|
||||
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
||||
} else {
|
||||
pointer_ += params_.inc_advance_ * tile_offset.contiguous();
|
||||
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
|
||||
pointer_ += Shape::kStrided * tile_offset.strided();
|
||||
}
|
||||
}
|
||||
|
||||
@ -296,7 +296,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
||||
|
||||
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
||||
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -310,10 +315,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
||||
|
||||
address_iterator_.set_iteration_index(idx);
|
||||
auto ptr = (address_iterator_.get() + pointer_offset);
|
||||
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
|
||||
|
||||
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
||||
|
||||
if (address_iterator_.valid()) {
|
||||
frag_ptr[idx] = *ptr;
|
||||
frag_ptr[idx] = *access_ptr;
|
||||
}
|
||||
++address_iterator_;
|
||||
}
|
||||
@ -323,11 +330,17 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
|
||||
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
||||
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
||||
address_iterator_.set_iteration_index(0);
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
|
||||
@ -340,8 +353,11 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
|
||||
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
||||
|
||||
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
||||
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
||||
|
||||
if (address_iterator_.valid()) {
|
||||
*(address_iterator_.get() + pointer_offset) = frag_ptr[idx];
|
||||
*access_ptr = frag_ptr[idx];
|
||||
}
|
||||
++address_iterator_;
|
||||
}
|
||||
@ -351,7 +367,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
|
||||
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -528,6 +544,12 @@ public:
|
||||
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
||||
iterator_.load_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
@ -540,6 +562,12 @@ public:
|
||||
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
||||
iterator_.store_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &frag) {
|
||||
@ -721,6 +749,12 @@ public:
|
||||
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
||||
iterator_.load_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
@ -732,6 +766,12 @@ public:
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
||||
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
||||
iterator_.store_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -149,6 +149,12 @@ public:
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
||||
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(Fragment &frag, Index byte_offset) {
|
||||
address_iterator_.set_iteration_index(0);
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
@ -157,7 +163,11 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
int access_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset);
|
||||
|
||||
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
|
||||
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
||||
|
||||
frag_ptr[access_idx] = *access_ptr;
|
||||
++address_iterator_;
|
||||
}
|
||||
}
|
||||
@ -172,6 +182,11 @@ public:
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
||||
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void store_with_byte_offset(Fragment const &frag, Index byte_offset) {
|
||||
address_iterator_.set_iteration_index(0);
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
|
||||
@ -180,7 +195,11 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
int access_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
*(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
|
||||
|
||||
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
||||
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
||||
|
||||
*access_ptr = frag_ptr[access_idx];
|
||||
++address_iterator_;
|
||||
}
|
||||
}
|
||||
@ -189,7 +208,7 @@ public:
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &frag) {
|
||||
store_with_pointer_offset(frag, 0);
|
||||
store_with_byte_offset(frag, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@ -567,6 +586,11 @@ class RegularTileIterator<Shape_, Element_,
|
||||
/// Store a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
||||
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void store_with_byte_offset(Fragment const &frag, Index byte_offset) {
|
||||
address_iterator_.set_iteration_index(0);
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
|
||||
@ -575,7 +599,11 @@ class RegularTileIterator<Shape_, Element_,
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
int access_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
*(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
|
||||
|
||||
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
||||
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
||||
|
||||
*access_ptr = frag_ptr[access_idx];
|
||||
++address_iterator_;
|
||||
}
|
||||
}
|
||||
@ -806,3 +834,5 @@ class RegularTileIterator<Shape_, Element_,
|
||||
} // namespace threadblock
|
||||
} // namespace transform
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1146,7 +1146,7 @@ class RegularTileIterator<
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess;
|
||||
Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
@ -1185,13 +1185,14 @@ class RegularTileIterator<
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
|
||||
Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess;
|
||||
Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)];
|
||||
|
||||
access_ptr += 16 * (s / 2);
|
||||
access_ptr += 16 * (s / 2) + vec_pointer_offset;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
@ -1199,8 +1200,7 @@ class RegularTileIterator<
|
||||
for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) {
|
||||
|
||||
int access_offset =
|
||||
c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size +
|
||||
vec_pointer_offset + i * line_size;
|
||||
c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size;
|
||||
|
||||
int access_idx = (c + s * ThreadMap::Iterations::kContiguous) *
|
||||
Detail::kIterarionsPerAccess + i;
|
||||
|
||||
Reference in New Issue
Block a user