gemm_universal_with_broadcast, +2 sources.

Ying Zhang
2022-04-20 15:35:58 -07:00
parent faab7536fc
commit fb063251f2
9 changed files with 1220 additions and 243 deletions

View File

@ -53,6 +53,7 @@ namespace thread {
template <typename T>
struct Identity {
static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE
T operator()(T value) const {
return value;
@ -160,7 +161,7 @@ struct LeakyReLU {
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
@ -194,7 +195,7 @@ struct LeakyReLU<Array<T, N> > {
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
@ -464,7 +465,7 @@ struct HardSwish<Array<half_t, N> > {
maximum<Array<T, N> > mx;
multiplies<Array<T, N> > mul;
plus<Array<T, N> > add;
return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
}
@ -493,7 +494,7 @@ struct GELU {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
@ -509,7 +510,7 @@ struct GELU<float> {
return cutlass::constants::half<float>() * scalar *
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_HOST_DEVICE
@ -525,7 +526,7 @@ struct GELU<double> {
return cutlass::constants::half<double>() * scalar *
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
}
using Params = LinearCombinationGenericParams<double>;
CUTLASS_HOST_DEVICE
@ -548,7 +549,7 @@ struct GELU<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
@ -572,7 +573,7 @@ struct GELU_taylor {
}
using Params = LinearCombinationGenericParams<T>;
};
template <int N>
@ -618,7 +619,7 @@ struct GELU_taylor<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
};
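All of the GELU specializations above evaluate the same closed form, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))). A minimal host-side reference (plain C++ with std::erf instead of the cutlass::constants helpers) for cross-checking:

#include <cmath>

// Reference GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
inline double gelu_reference(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}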

View File

@ -46,11 +46,13 @@ namespace epilogue {
namespace thread {
/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
/// or the form UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp_,
template <typename T> class UnaryOp_>
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_=BinaryOp1_>
class LinearCombinationResidualBlock {
public:
@ -62,7 +64,8 @@ public:
static int const kCount = kElementsPerAccess;
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
@ -138,7 +141,7 @@ public:
FragmentC const &residual,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp binary_op;
BinaryOp1 binary_op;
ActivationOp activation;
FragmentCompute tmp_Accum =
@ -154,6 +157,31 @@ public:
frag_Z = convert_z(result_Z);
}
/// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
FragmentC const &residual1, FragmentC const &residual2,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp1 binary_op1;
BinaryOp2 binary_op2;
ActivationOp activation;
FragmentCompute tmp_Accum =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute tmp_residual1 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
FragmentCompute tmp_residual2 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);
FragmentCompute z =
binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
}
/// Should never be called
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,

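A usage sketch for the two-residual functor above; the element types, vector width, and operator choices are illustrative assumptions, not part of this commit. Since BinaryOp2_ defaults to BinaryOp1_, existing single-residual instantiations continue to compile unchanged.

// Hypothetical instantiation computing
// D = Identity(plus(plus(ReLu(alpha * AB + bias), beta * C1), beta * C2)).
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
    cutlass::half_t,                      // ElementOutput
    float,                                // ElementAccumulator
    float,                                // ElementCompute
    cutlass::half_t,                      // ElementC
    8,                                    // ElementsPerAccess
    cutlass::epilogue::thread::ReLu,      // ActivationOp
    cutlass::plus,                        // BinaryOp1
    cutlass::epilogue::thread::Identity,  // UnaryOp
    cutlass::plus>;                       // BinaryOp2 (defaults to BinaryOp1)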
View File

@ -89,7 +89,7 @@ template <
bool StoreT = true
>
struct EpilogueWithBroadcastOpBase {
using ElementOutput = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
@ -119,7 +119,7 @@ struct EpilogueWithBroadcastOpBase {
/// Constructor from Params
EpilogueWithBroadcastOpBase(Params const &params_) { }
/// Determine if the source is needed. May return false if
bool is_source_needed() const {
return true;
}
@ -130,19 +130,19 @@ struct EpilogueWithBroadcastOpBase {
/// Applies the operation when is_source_needed() is true
CUTLASS_HOST_DEVICE
void operator()(
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentC const &frag_C,
FragmentC const &frag_C1,
FragmentC const &frag_C2,
FragmentCompute const &V) const {
}
/// Applies the operation when is_source_needed() is false
CUTLASS_HOST_DEVICE
void operator()(
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentCompute const &V) const {
@ -160,11 +160,11 @@ struct EpilogueWithBroadcastOpBase {
///
/// if (ElementwiseOp::kStoreZ) {
/// store(converted_u);
/// }
///
/// if (ElementwiseOp::kStoreT) {
/// store(v);
/// }
///
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
@ -182,24 +182,24 @@ template <
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class EpilogueWithBroadcast :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
@ -235,7 +235,7 @@ public:
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementCompute,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Output element
@ -261,14 +261,14 @@ public:
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
/// Tensor access type
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
@ -304,7 +304,7 @@ public:
/// I'm not sure what I meant here.
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
/// Shape of the shared memory allocation for the epilogue
using StorageShape = MatrixShape<
kThreadRows,
Shape::kN
@ -351,7 +351,7 @@ public:
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
@ -367,7 +367,7 @@ public:
/// Constructor
CUTLASS_DEVICE
EpilogueWithBroadcast(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
@ -386,32 +386,34 @@ public:
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
OutputTileIterator source_iterator1, ///< Tile iterator for source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord()) {
BroadcastFragment broadcast_fragment;
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
if (!output_op.is_source_needed()) {
compute_source_not_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
tensor_iterator);
}
else {
compute_source_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
source_iterator,
source_iterator1,
source_iterator2,
tensor_iterator);
}
}
@ -427,7 +429,7 @@ private:
) {
broadcast_fragment.clear();
// If no pointer is supplied, set with all zeros and avoid memory accesses
if (!broadcast_ptr) {
return;
@ -513,9 +515,9 @@ private:
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
//
// Iterator over warp-level accumulator fragment
@ -525,7 +527,7 @@ private:
//
// Iterate over accumulator tile
//
// CUTLASS_PRAGMA_UNROLL
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
@ -534,7 +536,7 @@ private:
//
// Convert and store fragment
//
__syncthreads();
@ -638,7 +640,7 @@ private:
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
@ -646,12 +648,15 @@ private:
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
OutputTileIterator source_iterator1, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
OutputTileIterator source_iterator2, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
typename OutputTileIterator::Fragment source_fragment1;
source_fragment1.clear();
typename OutputTileIterator::Fragment source_fragment2;
source_fragment2.clear();
//
// Iterator over warp-level accumulator fragment
@ -661,7 +666,7 @@ private:
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
@ -670,13 +675,18 @@ private:
// Load the source
//
source_iterator.load(source_fragment);
++source_iterator;
source_iterator1.load(source_fragment1);
++source_iterator1;
if (source_iterator2.enabled()) {
source_iterator2.load(source_fragment2);
++source_iterator2;
}
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
@ -720,8 +730,10 @@ private:
frag_T,
output_op,
aligned_accum_fragment[0],
source_fragment,
broadcast_fragment);
source_fragment1,
source_fragment2,
broadcast_fragment,
source_iterator2.enabled());
//
// Conditionally store fragments
@ -746,8 +758,10 @@ private:
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
typename OutputTileIterator::Fragment const &frag_C,
BroadcastFragment const &frag_Broadcast) {
typename OutputTileIterator::Fragment const &frag_C1,
typename OutputTileIterator::Fragment const &frag_C2,
BroadcastFragment const &frag_Broadcast,
bool frag_C2_enabled) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
@ -755,28 +769,39 @@ private:
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
OutputAccessType const *frag_C_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C);
OutputAccessType const *frag_C1_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C1);
OutputAccessType const *frag_C2_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C2);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
if (frag_C2_enabled) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_C2_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
} else {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
}
@ -795,23 +820,23 @@ private:
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}

View File

@ -100,10 +100,10 @@ public:
/// Fragment object
using Fragment = Array<
Element,
ThreadMap::Iterations::kColumn *
ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
/// Memory access size
@ -121,15 +121,15 @@ public:
Params() { }
CUTLASS_HOST_DEVICE
Params(Layout const &layout):
PredicatedTileIteratorParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>()
)
{ }
CUTLASS_HOST_DEVICE
Params(Base const &base) :
Base(base) { }
};
@ -176,7 +176,7 @@ private:
PredicatedTileIteratorParams params_;
/// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load().
uint8_t *byte_pointer_;
uint8_t *byte_pointer_{nullptr};
/// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ may be with different address computation compared to byte_pointer_.
uint8_t *store_byte_pointer_;
@ -200,13 +200,13 @@ private:
int state_[3];
/// Scatter indices
int const *indices_;
/// Whether to perform Permute Op
bool PermuteD;
/// PermuteDLayout
mutable PermuteDLayout permute_layout_;
//
// Static asserts about internal strides
//
@ -236,7 +236,7 @@ public:
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const *indices = nullptr
):
params_(params), indices_(indices)
{
@ -252,13 +252,18 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] = ((thread_offset.column()
+ ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
if (!pointer) {
mask_.clear();
} else {
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
}
if (ScatterD && !indices) {
@ -266,8 +271,8 @@ public:
}
// Initialize byte_pointer_
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
if (ScatterD) {
@ -284,7 +289,7 @@ public:
}else{
PermuteD = true;
store_byte_pointer_ = reinterpret_cast<uint8_t *>(pointer);
permute_layout_ = PermuteDLayout(extent,
params_.stride * kElementsPerAccess / sizeof(AccessType));
}
@ -315,11 +320,11 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
@ -339,7 +344,7 @@ public:
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
@ -389,11 +394,11 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
@ -413,7 +418,7 @@ public:
bool guard = row_guard && mask_.predicates[column];
int col_offset = column * ThreadMap::Delta::kColumn;
if (PermuteD) {
int col = col_offset + thread_start_column_;
int row = row_offset + thread_start_row_;
@ -436,7 +441,7 @@ public:
(void *)&memory_pointer[0],
guard);
}
if (!PermuteD) {
memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess);
}
@ -483,11 +488,11 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
@ -511,7 +516,7 @@ public:
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
@ -553,11 +558,11 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
@ -585,7 +590,7 @@ public:
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
@ -655,7 +660,7 @@ public:
}
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
@ -663,7 +668,7 @@ public:
byte_pointer_ += params_.advance_group;
store_byte_pointer_ += params_.advance_group;
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
@ -673,7 +678,7 @@ public:
byte_pointer_ += params_.advance_cluster;
store_byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
@ -706,6 +711,10 @@ public:
CUTLASS_DEVICE void set_mask(Mask const &mask) {
mask_ = mask;
}
CUTLASS_DEVICE bool enabled() {
return (byte_pointer_ != nullptr);
}
};
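The enabled() accessor added above is the hook that makes a second source operand optional: constructing the iterator with a null pointer clears the predicate mask and leaves byte_pointer_ null, so enabled() reports false and the epilogue falls back to the single-residual path. A sketch of that pattern, borrowing the names used by GemmWithFusedEpilogue later in this commit:

// Sketch: a null ptr_C2 yields a disabled iterator, which selects the
// single-residual epilogue path.
typename Epilogue::OutputTileIterator iterator_C2(
    params.params_C2,
    /*pointer=*/nullptr,   // no second residual supplied
    params.problem_size.mn(),
    thread_idx,
    threadblock_offset);
bool has_second_residual = iterator_C2.enabled();  // false here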
////////////////////////////////////////////////////////////////////////////////
@ -717,7 +726,7 @@ public:
template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
int InterleavedN ///< Number of Interleaved N
>
class InterleavedPredicatedTileIterator {
public:
@ -751,14 +760,14 @@ public:
Params() { }
CUTLASS_HOST_DEVICE
Params(Layout const &layout):
Base(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_InterleavedPredicatedTileIteratorDesc<Element, ThreadMap>()
) { }
CUTLASS_HOST_DEVICE
Params(Base const &base) :
Base(base) { }
};
@ -862,8 +871,8 @@ public:
}
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.strided()) * LongIndex(params_.stride) +
LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess;
// Initialize internal state counter
@ -891,7 +900,7 @@ public:
bool guard = col_guard && mask_.predicates[iteration_contiguous_];
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
*frag_ptr,
@ -1091,7 +1100,7 @@ private:
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in pq
Index extent_pq_;
/// A thread's starting row position (assuming steady-state predicates have
@ -1133,7 +1142,7 @@ public:
):
params_(params) {
MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_col_ = extent.c();
extent_pq_ = extent.h() * extent.w();
extent_row_ = extent.n() * extent_pq_;
@ -1188,7 +1197,7 @@ public:
reinterpret_cast<AccessType const *>(byte_pointer);
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
*frag_ptr,

View File

@ -0,0 +1,354 @@
/*! \file
\brief gemm_universal which takes LinearCombinationResidualBlock as the epilogue output op.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*!
The universal GEMM with a broadcast epilogue.
Supports
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
ElementC_, ElementAccumulator_, ElementAccumulator_,
ElementC_, ElementC_, 16 / sizeof(ElementC_)>,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversalWithBroadcast :
public GemmUniversalBase<
typename kernel::DefaultGemmWithBroadcast<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
> {
public:
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Base = GemmUniversalBase<
typename kernel::DefaultGemmWithBroadcast<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
>;
using Arguments = typename Base::Arguments;
using GemmKernel = typename Base::GemmKernel;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for column-major output exchanges problem size and operands.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// Operation performed by GEMM
typename Operator_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB>
class GemmUniversalWithBroadcast<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
layout::ColumnMajor, // partially specialized on LayoutC
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
WarpShape_, InstructionShape_, EpilogueOutputOp_,
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
Operator_, TransformA, TransformB> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using UnderlyingOperator = typename GemmUniversalWithBroadcast<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA,
Operator,
kTransformB,
kTransformA
>::Base;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
static int const kAlignmentC = EpilogueOutputOp::kCount;
/// Argument structure
using Arguments = typename UnderlyingOperator::Arguments;
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmUniversalWithBroadcast() { }
/// Helper to construct a transposed equivalent for the underlying GEMM operator
static Arguments to_underlying_arguments(Arguments const &args) {
return args.transposed_problem();
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
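A host-side sketch of driving the new operator end to end. The tile shapes, element types, scalars, and pointer/stride variables below are illustrative assumptions (the pointers are presumed valid device allocations), the epilogue Params form is assumed, and the argument order mirrors the two-residual kernel Arguments constructor added later in this commit; ptr_C2 may be nullptr to fall back to a single residual.

using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    cutlass::half_t, cutlass::layout::RowMajor,     // C and D
    float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EpilogueOp>;  // e.g. the LinearCombinationResidualBlock sketch shown earlier

// Assumed device allocations (e.g. from cudaMalloc) and problem sizes.
cutlass::half_t *ptr_A, *ptr_B, *ptr_C1, *ptr_C2, *ptr_D, *ptr_Vector;
int M = 1024, N = 512, K = 256;
float alpha = 1.0f, beta = 1.0f;
int lda = K, ldb = K, ldc1 = N, ldc2 = N, ldd = N;

typename Gemm::Arguments args(
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K},
    /*batch_count=*/1,
    {alpha, beta},                   // epilogue params (assumed form)
    ptr_A, ptr_B,
    ptr_C1, ptr_C2,                  // two residual sources; ptr_C2 may be nullptr
    ptr_D,
    ptr_Vector,                      // broadcast (bias) vector
    /*ptr_Tensor=*/nullptr,
    /*batch_stride_A=*/0, /*batch_stride_B=*/0,
    /*batch_stride_C1=*/0, /*batch_stride_C2=*/0,
    /*batch_stride_D=*/0, /*batch_stride_Vector=*/0, /*batch_stride_Tensor=*/0,
    lda, ldb, ldc1, ldc2, ldd,
    /*ldr=*/0, /*ldt=*/0);

Gemm gemm_op;
if (gemm_op.can_implement(args) == cutlass::Status::kSuccess) {
  gemm_op.initialize(args, /*workspace=*/nullptr);
  gemm_op();                         // equivalent to gemm_op.run()
}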

View File

@ -52,7 +52,7 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
@ -92,7 +92,7 @@ public:
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
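As an illustrative check (not part of the commit): for half_t operands sizeof_bits is 16, so kSplitKAlignment = max(128 / 16, 128 / 16) = 8 elements, i.e. every split-K slice starts on a 128-bit boundary.

static_assert(cutlass::const_max(128 / cutlass::sizeof_bits<cutlass::half_t>::value,
                                 128 / cutlass::sizeof_bits<cutlass::half_t>::value) == 8,
              "split-K slices begin on 128-bit boundaries for half_t");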
@ -115,7 +115,8 @@ public:
void const * ptr_A;
void const * ptr_B;
void const * ptr_C;
void const * ptr_C1;
void const * ptr_C2;
void * ptr_D;
void * ptr_Vector;
@ -123,14 +124,16 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_C1;
int64_t batch_stride_C2;
int64_t batch_stride_D;
int64_t batch_stride_Vector;
int64_t batch_stride_Tensor;
typename LayoutA::Stride::Index lda;
typename LayoutB::Stride::Index ldb;
typename LayoutC::Stride::Index ldc;
typename LayoutC::Stride::Index ldc1;
typename LayoutC::Stride::Index ldc2;
typename LayoutC::Stride::Index ldd;
typename LayoutC::Stride::Index ldr;
typename LayoutC::Stride::Index ldt;
@ -138,11 +141,62 @@ public:
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { }
ptr_A(nullptr), ptr_B(nullptr), ptr_C1(nullptr), ptr_C2(nullptr), ptr_D(nullptr) { }
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C1,
void const * ptr_C2,
void * ptr_D,
void * ptr_Vector,
void * ptr_Tensor,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C1,
int64_t batch_stride_C2,
int64_t batch_stride_D,
int64_t batch_stride_Vector,
int64_t batch_stride_Tensor,
typename LayoutA::Stride::Index lda,
typename LayoutB::Stride::Index ldb,
typename LayoutC::Stride::Index ldc1,
typename LayoutC::Stride::Index ldc2,
typename LayoutC::Stride::Index ldd,
typename LayoutC::Stride::Index ldr,
typename LayoutC::Stride::Index ldt
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D),
ptr_Vector(ptr_Vector),
ptr_Tensor(ptr_Tensor),
batch_stride_A(batch_stride_A),
batch_stride_B(batch_stride_B),
batch_stride_C1(batch_stride_C1),
batch_stride_C2(batch_stride_C2),
batch_stride_D(batch_stride_D),
batch_stride_Vector(batch_stride_Vector),
batch_stride_Tensor(batch_stride_Tensor),
lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt)
{
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size);
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
CUTLASS_TRACE_HOST(" ldt: " << this->ldt);
}
/// constructs an arguments structure
Arguments(
@ -168,33 +222,18 @@ public:
typename LayoutC::Stride::Index ldd,
typename LayoutC::Stride::Index ldr,
typename LayoutC::Stride::Index ldt
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
ptr_Vector(ptr_Vector),
ptr_Tensor(ptr_Tensor),
batch_stride_A(batch_stride_A),
batch_stride_B(batch_stride_B),
batch_stride_C(batch_stride_C),
batch_stride_D(batch_stride_D),
batch_stride_Vector(batch_stride_Vector),
batch_stride_Tensor(batch_stride_Tensor),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt)
{
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size);
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
CUTLASS_TRACE_HOST(" ldt: " << this->ldt);
}
): Arguments(
mode, problem_size, batch_count, epilogue,
ptr_A, ptr_B, ptr_C1, nullptr, ptr_D, ptr_Vector, ptr_Tensor,
batch_stride_A, batch_stride_B, batch_stride_C, 0, batch_stride_D,
batch_stride_Vector, batch_stride_Tensor,
lda, ldb, ldc, 0, ldd, ldr, ldt) {}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
@ -217,10 +256,11 @@ public:
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_C1;
typename Epilogue::OutputTileIterator::Params params_C2;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::TensorTileIterator::Params params_Tensor;
typename EpilogueOutputOp::Params output_op;
@ -230,9 +270,10 @@ public:
void * ptr_A;
void * ptr_B;
void * ptr_C;
void * ptr_C1;
void * ptr_C2;
void * ptr_D;
void * ptr_Vector;
typename LayoutC::Stride::Index ldr;
@ -240,7 +281,8 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_C1;
int64_t batch_stride_C2;
int64_t batch_stride_D;
int64_t batch_stride_Vector;
int64_t batch_stride_Tensor;
@ -256,21 +298,24 @@ public:
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_C1(0),
params_C2(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_C1(nullptr),
ptr_C2(nullptr),
ptr_D(nullptr),
ptr_Vector(nullptr),
ldr(0),
ptr_Tensor(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C(0),
batch_stride_C1(0),
batch_stride_C2(0),
batch_stride_D(0),
batch_stride_Vector(0),
batch_stride_Tensor(0),
@ -288,7 +333,8 @@ public:
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(args.lda),
params_B(args.ldb),
params_C(args.ldc),
params_C1(args.ldc1),
params_C2(args.ldc2),
params_D(args.ldd),
params_Tensor(args.ldt),
output_op(args.epilogue),
@ -297,15 +343,17 @@ public:
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_C1(const_cast<void *>(args.ptr_C1)),
ptr_C2(const_cast<void *>(args.ptr_C2)),
ptr_D(args.ptr_D),
ptr_Vector(args.ptr_Vector),
ldr(args.ldr),
ptr_Tensor(args.ptr_Tensor),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_C1(args.batch_stride_C1),
batch_stride_C2(args.batch_stride_C2),
batch_stride_D(args.batch_stride_D),
batch_stride_Vector(args.batch_stride_Vector),
batch_stride_Tensor(args.batch_stride_Tensor),
@ -326,7 +374,8 @@ public:
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
ptr_C1 = const_cast<void *>(args.ptr_C1);
ptr_C2 = const_cast<void *>(args.ptr_C2);
ptr_D = args.ptr_D;
ptr_Vector = args.ptr_Vector;
@ -335,7 +384,8 @@ public:
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_C1 = args.batch_stride_C1;
batch_stride_C2 = args.batch_stride_C2;
batch_stride_D = args.batch_stride_D;
batch_stride_Vector = args.batch_stride_Vector;
batch_stride_Tensor = args.batch_stride_Tensor;
@ -364,7 +414,7 @@ public:
//
CUTLASS_DEVICE
GemmWithFusedEpilogue() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
@ -458,7 +508,7 @@ public:
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
@ -466,12 +516,12 @@ public:
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
@ -537,10 +587,10 @@ public:
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
@ -563,28 +613,36 @@ public:
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_C1 = static_cast<ElementC *>(params.ptr_C1);
ElementC *ptr_C2 = static_cast<ElementC *>(params.ptr_C2);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
typename Epilogue::ElementTensor *ptr_Tensor = static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);
// Define the reduction output pointer and move to the appropriate place
typename Epilogue::ElementVector *ptr_Vector =
static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);
//
// Fetch pointers based on mode.
//
// Special path when split-K not enabled.
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) {
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
ptr_C,
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
ptr_C1,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_C2(
params.params_C2,
ptr_C2,
params.problem_size.mn(),
thread_idx,
threadblock_offset
@ -610,9 +668,9 @@ public:
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Move to appropriate location for this output tile
@ -625,7 +683,8 @@ public:
ptr_Vector,
iterator_D,
accumulators,
iterator_C,
iterator_C1,
iterator_C2,
tensor_iterator,
params.problem_size.mn(),
threadblock_offset);
@ -637,7 +696,7 @@ public:
// Slower path when split-K or batching is needed
//
#if SPLIT_K_ENABLED
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
@ -646,7 +705,7 @@ public:
// If performing a reduction via split-K, fetch the initial synchronization
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
@ -658,7 +717,10 @@ public:
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
if (ptr_C2) {
ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2;
}
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
if (ptr_Tensor) {
ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
@ -668,7 +730,10 @@ public:
}
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_C = static_cast<ElementC * const *>(params.ptr_C)[threadblock_tile_offset.k()];
ptr_C1 = static_cast<ElementC * const *>(params.ptr_C1)[threadblock_tile_offset.k()];
if (ptr_C2) {
ptr_C2 = static_cast<ElementC * const *>(params.ptr_C2)[threadblock_tile_offset.k()];
}
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
if (ptr_Tensor) {
ptr_Tensor = static_cast<typename Epilogue::ElementTensor * const *>(params.ptr_Tensor)[threadblock_tile_offset.k()];
@ -680,9 +745,16 @@ public:
#endif
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
ptr_C,
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
ptr_C1,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_C2(
params.params_C2,
ptr_C2,
params.problem_size.mn(),
thread_idx,
threadblock_offset
@ -711,18 +783,18 @@ public:
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
#if SPLIT_K_ENABLED
// Wait on the semaphore - this latency may have been covered by iterator construction
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
iterator_C1 = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
@ -744,7 +816,8 @@ public:
: ptr_Vector,
iterator_D,
accumulators,
iterator_C,
iterator_C1,
iterator_C2,
tensor_iterator,
params.problem_size.mn(),
threadblock_offset);
@ -754,7 +827,7 @@ public:
//
#if SPLIT_K_ENABLED
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@ -766,7 +839,7 @@ public:
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
#endif

View File

@ -60,6 +60,13 @@ add_custom_target(
test_unit_gemv_device
)
add_custom_target(
cutlass_test_unit_gemm_device_new
DEPENDS
cutlass_test_unit_gemm_device_exp
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_simt
@ -68,7 +75,7 @@ cutlass_test_unit_add_executable(
simt_sgemm_nt_sm80.cu
simt_sgemm_tn_sm80.cu
simt_cgemm_nn_sm50.cu
simt_cgemm_nt_sm50.cu
simt_cgemm_tn_sm50.cu
@ -163,7 +170,7 @@ cutlass_test_unit_add_executable(
gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu
gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu
gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu
gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu
gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu
gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu
@ -213,7 +220,7 @@ cutlass_test_unit_add_executable(
gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu
gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu
gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu
gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu
gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu
gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu
gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu
@ -232,7 +239,7 @@ cutlass_test_unit_add_executable(
gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu
gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu
@ -253,7 +260,7 @@ cutlass_test_unit_add_executable(
gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu
gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu
gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu
gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu
gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu
gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu
@ -326,8 +333,8 @@ cutlass_test_unit_add_executable(
BATCH_SOURCES ON
BATCH_SIZE 4
gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu
gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu
gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu
)
@ -403,7 +410,7 @@ add_dependencies(
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop
gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu
gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu
@ -430,10 +437,10 @@ cutlass_test_unit_add_executable(
BATCH_SOURCES ON
BATCH_SIZE 4
## SYRK
# Syrk SM80 f64 tests
syrk_f64n_f64t_tensor_op_f64_sm80.cu
syrk_f64t_f64n_tensor_op_f64_sm80.cu
# Syrk SM80 f32 tests
syrk_tf32n_f32t_tensor_op_f32_sm80.cu
@ -452,7 +459,7 @@ cutlass_test_unit_add_executable(
syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu
syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu
## HERK
# Herk SM80 complex f64 tests
herk_cf64h_cf64n_tensor_op_f64_sm80.cu
@ -550,7 +557,7 @@ cutlass_test_unit_add_executable(
hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu
hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu
hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu
# Hemm SM80 complex f32 tests
hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu
hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu
@ -581,4 +588,10 @@ cutlass_test_unit_add_executable(
her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_exp
gemm_broadcast_test.cu
)
endif()

View File

@ -0,0 +1,474 @@
#include <fstream>
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "../../common/cutlass_unit_test.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_elementwise.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gemm.h"
template<typename GemmElement, typename LayoutA, typename LayoutB, typename LayoutC>
struct TestbedUtils {
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_D;
uint64_t seed;
cutlass::HostTensor<GemmElement, LayoutA> tensor_A; // Input A
cutlass::HostTensor<GemmElement, LayoutB> tensor_B; // Input B
cutlass::HostTensor<GemmElement, LayoutC> tensor_C; // Input C
cutlass::HostTensor<GemmElement, LayoutC> tensor_D1; // Input D
cutlass::HostTensor<GemmElement, LayoutC> tensor_D2; // Input D
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y1; // Input Y
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y2; // Input Y
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y_ref;
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y_transpose_ref;
//
// Methods
//
TestbedUtils(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_D(init_D_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
// TODO: Implement the rest
EXPECT_TRUE(false) << "Not implemented";
return false;
}
return true;
}
/// Initializes data structures
void initialize(cutlass::gemm::GemmCoord problem_size) {
//
// Allocate the GEMM workspace
//
tensor_A.resize(problem_size.mk());
tensor_B.resize(problem_size.kn());
tensor_C.resize({1, problem_size.n()});
tensor_D1.resize(problem_size.mn());
tensor_D2.resize(problem_size.mn());
tensor_Y1.resize(problem_size.mn());
tensor_Y2.resize(problem_size.mn());
tensor_Y_ref.resize(problem_size.mn());
tensor_Y_transpose_ref.resize(problem_size.mn());
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
EXPECT_TRUE(initialize_tensor(tensor_D1.host_view(), init_D, seed + 2016));
EXPECT_TRUE(initialize_tensor(tensor_D2.host_view(), init_D, seed + 2015));
EXPECT_TRUE(initialize_tensor(tensor_Y1.host_view(), cutlass::Distribution::AllZeros, 0));
EXPECT_TRUE(initialize_tensor(tensor_Y2.host_view(), cutlass::Distribution::AllZeros, 0));
EXPECT_TRUE(initialize_tensor(tensor_Y_ref.host_view(), cutlass::Distribution::AllZeros, 0));
EXPECT_TRUE(initialize_tensor(tensor_Y_transpose_ref.host_view(), cutlass::Distribution::AllZeros, 0));
// Random initialization can produce an all-zeros tensor, so force a non-zero value
// in the upper-left corner of each operand.
tensor_A.host_view().at({0, 0}) = GemmElement(1);
tensor_B.host_view().at({0, 0}) = GemmElement(1);
tensor_C.host_view().at({0, 0}) = GemmElement(1);
tensor_D1.host_view().at({0, 0}) = GemmElement(1);
tensor_D2.host_view().at({0, 0}) = GemmElement(1);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D1.sync_device();
tensor_D2.sync_device();
}
/// Compares the reference result against the computed result; dumps all operands to a file when they differ
bool compare_reference(
cutlass::gemm::GemmCoord problem_size, cutlass::HostTensor<GemmElement, LayoutC>& tensor_Y_ref, cutlass::HostTensor<GemmElement, LayoutC>& tensor_Y) {
tensor_Y_ref.sync_host();
tensor_Y.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y_ref.host_view()), 0);
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y.host_view()), 0);
float norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Y_ref.host_view(), tensor_Y.host_view(), float());
bool passed = (norm_diff <= 0.1f);
EXPECT_LT(norm_diff, 0.1f) << " tensor_Y is incorrect";
// Dump the operands and results only when the comparison fails.
if (!passed) {
std::ofstream file("errors_testbed_gemm_broadcast_new.txt");
file
<< "problem: " << problem_size << "\n\n";
file
<< "capacity: \n"
<< "A: " << tensor_A.capacity()
<< "\nB: " << tensor_B.capacity()
<< "\nC: " << tensor_C.capacity()
<< "\nD1: " << tensor_D1.capacity()
<< "\nD2: " << tensor_D2.capacity()
<< "\nY: " << tensor_Y.capacity()
<< "\nY_ref: " << tensor_Y_ref.capacity()
<< "\n\n";
file
<< "A =\n" << tensor_A.host_view()
<< "\n\nB =\n" << tensor_B.host_view()
<< "\n\nC =\n" << tensor_C.host_view()
<< "\n\nD1 =\n" << tensor_D1.host_view()
<< "\n\nD2 =\n" << tensor_D2.host_view()
<< "\n\nY =\n" << tensor_Y.host_view()
<< "\n\nY_ref =\n" << tensor_Y_ref.host_view();
}
return passed;
}
};
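// Minimal usage sketch (illustrative only; M, N, K and the kernel launch are
// placeholders, the tests below show the real flow):
//
//   TestbedUtils<cutlass::half_t, cutlass::layout::RowMajor,
//                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor> utils;
//   utils.initialize({M, N, K});   // allocate and randomly fill all operands
//   /* launch a kernel that writes utils.tensor_Y1.device_data() */
//   utils.compare_reference({M, N, K}, utils.tensor_Y_ref, utils.tensor_Y1);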
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
TEST(SM80_Device_GemmWithBroadcast_CUSTOM_f16t_f16n_f16t_tensor_op_f16, 64x64_32x2_32x32x32_16x8x8) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
const int M = 1024;
const int K = 10240;
const int N = 512;
cutlass::gemm::GemmCoord problem_size{M, N, K};
TestbedUtils<cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor> utils;
utils.initialize(problem_size);
{
// Create the reference GEMM: computes Y_ref = A B + broadcast(C) via beta = 1 with ldc = 0.
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
2,  // stages
1,  // alignment A
1>; // alignment B
//
// Initialize the GEMM operator
//
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
1 /* batch_count */,
{cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
utils.tensor_A.device_data(),
utils.tensor_B.device_data(),
utils.tensor_C.device_data(),
utils.tensor_Y_ref.device_data(),
0, // batch_stride_A
0, // batch_stride_B
0, // batch_stride_C
0, // batch_stride_D
K, // lda
K, // ldb
0, // ldc: stride 0 broadcasts the single row of C across all M rows
N, // ldd
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
//
// Run the GEMM
//
status = gemm_op();
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
}
{
cutlass::gemm::GemmCoord problem_size{N, M, K};
// Create a second reference GEMM that computes the transposed product Y^T = B^T A^T into a column-major tensor.
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
2,
1,
1>;
//
// Initialize the GEMM operator
//
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
1 /* batch_count */,
{cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
utils.tensor_B.device_data(),
utils.tensor_A.device_data(),
utils.tensor_C.device_data(),
utils.tensor_Y_transpose_ref.device_data(),
0, // batch_stride_A
0, // batch_stride_B
0, // batch_stride_C
0, // batch_stride_D
K, // lda (operand A is tensor_B here)
K, // ldb (operand B is tensor_A here)
0, // ldc: broadcast the vector C
N  // ldd
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
//
// Run the GEMM
//
status = gemm_op();
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
}
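// Sanity check on the two references: an M-by-N row-major tensor and an N-by-M
// column-major tensor share the same linear memory layout, so Y_ref and
// Y_transpose_ref should agree (up to floating-point accumulation order).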
utils.compare_reference(problem_size, utils.tensor_Y_ref, utils.tensor_Y_transpose_ref);
{
// Create a GemmUniversalWithBroadcast with one residual source: computes Y1 = (A B + broadcast(C)) .* D1.
using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombinationResidualBlock<
ElementOutput, ElementAccumulator, ElementAccumulator,
ElementAccumulator, 128 / cutlass::sizeof_bits<ElementOutput>::value,
cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
2,
1,
1>;
//
// Initialize the GEMM operator
//
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
1 /* batch_count */,
{cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
utils.tensor_A.device_data(),  // ptr_A
utils.tensor_B.device_data(),  // ptr_B
utils.tensor_D1.device_data(), // ptr_C: elementwise residual
utils.tensor_Y1.device_data(), // ptr_D: output
utils.tensor_C.device_data(),  // ptr_Vector: broadcast bias
nullptr,                       // ptr_Tensor: auxiliary output, unused here
0, // batch_stride_A
0, // batch_stride_B
0, // batch_stride_C
0, // batch_stride_D
0, // batch_stride_Vector
0, // batch_stride_Tensor
K, // lda
K, // ldb
N, // ldc
N, // ldd
0, // ldr: the bias vector is broadcast
0  // ldt
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
//
// Run the GEMM
//
status = gemm_op();
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
}
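// Build the host reference for the single-residual epilogue (BinaryOp1 = multiplies):
// Y1 should equal (A B + broadcast(C)) .* D1, so scale the plain GEMM reference
// elementwise by D1 before comparing.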
utils.tensor_Y_ref.sync_host();
cutlass::reference::host::TensorMul(utils.tensor_Y_ref.host_view(), utils.tensor_D1.host_view());
utils.tensor_Y_ref.sync_device();
utils.compare_reference(problem_size, utils.tensor_Y_ref, utils.tensor_Y1);
{
// Create a GemmUniversalWithBroadcast with two residual sources (the new path in this
// commit): computes Y2 = ((A B + broadcast(C)) .* D1) .+ D2.
using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombinationResidualBlock<
ElementOutput, ElementAccumulator, ElementAccumulator,
ElementAccumulator, 128 / cutlass::sizeof_bits<ElementOutput>::value,
cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity, cutlass::plus>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
2,
1,
1>;
//
// Initialize the GEMM operator
//
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
1 /* batch_count */,
{cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
utils.tensor_A.device_data(),  // ptr_A
utils.tensor_B.device_data(),  // ptr_B
utils.tensor_D1.device_data(), // ptr_C1: first residual
utils.tensor_D2.device_data(), // ptr_C2: second residual (new in this commit)
utils.tensor_Y2.device_data(), // ptr_D: output
utils.tensor_C.device_data(),  // ptr_Vector: broadcast bias
nullptr,                       // ptr_Tensor: auxiliary output, unused here
0, // batch_stride_A
0, // batch_stride_B
0, // batch_stride_C1
0, // batch_stride_C2
0, // batch_stride_D
0, // batch_stride_Vector
0, // batch_stride_Tensor
K, // lda
K, // ldb
N, // ldc1
N, // ldc2
N, // ldd
0, // ldr: the bias vector is broadcast
0  // ldt
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
//
// Run the GEMM
//
status = gemm_op();
EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
}
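// Extend the host reference for the two-source epilogue (BinaryOp2 = plus):
// Y2 should equal ((A B + broadcast(C)) .* D1) .+ D2, so add D2 elementwise to the
// previous reference before comparing.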
utils.tensor_Y_ref.sync_host();
cutlass::reference::host::TensorAdd(utils.tensor_Y_ref.host_view(), utils.tensor_D2.host_view());
utils.tensor_Y_ref.sync_device();
utils.compare_reference(problem_size, utils.tensor_Y_ref, utils.tensor_Y2);
}
#endif

View File

@ -53,7 +53,7 @@ namespace detail {
/// Helper to apply a binary operator in place
template <
typename ElementA,
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
@ -68,8 +68,8 @@ struct TensorFuncBinaryOp {
/// View of left-hand-side tensor
TensorView<ElementD, LayoutD> view_d;
TensorRef<ElementA, LayoutA> ref_a;
TensorRef<ElementB, LayoutB> ref_b;
TensorRef<ElementA, LayoutA> view_a;
TensorRef<ElementB, LayoutB> view_b;
BinaryFunc func;
//
@ -82,8 +82,8 @@ struct TensorFuncBinaryOp {
/// Constructor
TensorFuncBinaryOp(
TensorView<ElementD, LayoutD> const & view_d_,
TensorRef<ElementA, LayoutA> const & ref_a_,
TensorRef<ElementB, LayoutB> const & ref_b_,
TensorRef<ElementA, LayoutA> const & view_a_,
TensorRef<ElementB, LayoutB> const & view_b_,
BinaryFunc func = BinaryFunc()
):
view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { }
@ -118,7 +118,7 @@ void TensorAdd(
) {
detail::TensorFuncBinaryOp<
ElementD,
ElementD,
LayoutD,
ElementA,
LayoutA,
@ -129,7 +129,7 @@ void TensorAdd(
TensorForEach(
d.extent(),
func);
func);
}
/// Adds a tensor in place: d = d .+ a
@ -164,7 +164,7 @@ void TensorSub(
) {
detail::TensorFuncBinaryOp<
ElementD,
ElementD,
LayoutD,
ElementA,
LayoutA,
@ -191,7 +191,7 @@ void TensorSub(
TensorView<ElementD, LayoutD> d, ///< destination tensor view
TensorRef<ElementA, LayoutA> a ///< A tensor reference
) {
TensorSub(d, d, a);
}
@ -211,9 +211,9 @@ void TensorMul(
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
TensorRef<ElementB, LayoutB> b ///< B tensor reference
) {
detail::TensorFuncBinaryOp<
ElementD,
ElementD,
LayoutD,
ElementA,
LayoutA,
@ -257,9 +257,9 @@ void TensorDiv(
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
TensorRef<ElementB, LayoutB> b ///< B tensor reference
) {
detail::TensorFuncBinaryOp<
ElementD,
ElementD,
LayoutD,
ElementA,
LayoutA,
@ -284,7 +284,7 @@ void TensorDiv(
TensorView<ElementD, LayoutD> d, ///< destination tensor view
TensorRef<ElementA, LayoutA> a ///< A tensor reference
) {
TensorMul(d, d, a);
TensorDiv(d, d, a);
}
@ -304,15 +304,15 @@ void TensorModulus(
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
TensorRef<ElementB, LayoutB> b ///< B tensor reference
) {
detail::TensorFuncBinaryOp<
ElementD,
ElementD,
LayoutD,
ElementA,
LayoutA,
ElementB,
LayoutB,
cutlass::modulus<ElementD>
cutlass::divides<ElementD>
> func(d, a, b);
TensorForEach(
@ -331,7 +331,7 @@ void TensorModulus(
TensorView<ElementD, LayoutD> d, ///< destination tensor view
TensorRef<ElementA, LayoutA> a ///< A tensor reference
) {
TensorMul(d, d, a);
TensorModulus(d, d, a);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
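// Usage sketch for the in-place element-wise helpers exercised by the test above
// (assuming host_view() converts to the TensorView/TensorRef parameters, as it
// does elsewhere in this commit):
//
//   cutlass::reference::host::TensorMul(y.host_view(), d1.host_view()); // y = y .* d1
//   cutlass::reference::host::TensorAdd(y.host_view(), d2.host_view()); // y = y .+ d2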