gemm_universal_with_broadcast, +2 sources.
This commit is contained in:
@ -53,6 +53,7 @@ namespace thread {
|
||||
|
||||
template <typename T>
|
||||
struct Identity {
|
||||
static const bool kIsHeavy=false;
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T value) const {
|
||||
return value;
|
||||
@ -160,7 +161,7 @@ struct LeakyReLU {
|
||||
Params():
|
||||
LinearCombinationGenericParams<T>(),
|
||||
leaky_alpha(T(1)) {}
|
||||
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
T alpha,
|
||||
@ -194,7 +195,7 @@ struct LeakyReLU<Array<T, N> > {
|
||||
Params():
|
||||
LinearCombinationGenericParams<T>(),
|
||||
leaky_alpha(T(1)) {}
|
||||
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
T alpha,
|
||||
@ -464,7 +465,7 @@ struct HardSwish<Array<half_t, N> > {
|
||||
maximum<Array<T, N> > mx;
|
||||
multiplies<Array<T, N> > mul;
|
||||
plus<Array<T, N> > add;
|
||||
|
||||
|
||||
return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
|
||||
}
|
||||
|
||||
@ -493,7 +494,7 @@ struct GELU {
|
||||
return T(cutlass::constants::half<T>() * scalar *
|
||||
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
|
||||
}
|
||||
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -509,7 +510,7 @@ struct GELU<float> {
|
||||
return cutlass::constants::half<float>() * scalar *
|
||||
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
|
||||
}
|
||||
|
||||
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -525,7 +526,7 @@ struct GELU<double> {
|
||||
return cutlass::constants::half<double>() * scalar *
|
||||
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
|
||||
}
|
||||
|
||||
|
||||
using Params = LinearCombinationGenericParams<double>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -548,7 +549,7 @@ struct GELU<Array<T, N> > {
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -572,7 +573,7 @@ struct GELU_taylor {
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
|
||||
|
||||
};
|
||||
|
||||
template <int N>
|
||||
@ -618,7 +619,7 @@ struct GELU_taylor<Array<T, N> > {
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
};
|
||||
|
||||
|
||||
@ -46,11 +46,13 @@ namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
|
||||
// or form UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
|
||||
template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp_,
|
||||
template <typename T> class UnaryOp_>
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_,
|
||||
template <typename T> class BinaryOp2_=BinaryOp1_>
|
||||
class LinearCombinationResidualBlock {
|
||||
public:
|
||||
|
||||
@ -62,7 +64,8 @@ public:
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
|
||||
using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
|
||||
using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
|
||||
using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
|
||||
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
|
||||
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
@ -138,7 +141,7 @@ public:
|
||||
FragmentC const &residual,
|
||||
FragmentCompute const &bias) const {
|
||||
UnaryOp unary_op;
|
||||
BinaryOp binary_op;
|
||||
BinaryOp1 binary_op;
|
||||
ActivationOp activation;
|
||||
|
||||
FragmentCompute tmp_Accum =
|
||||
@ -154,6 +157,31 @@ public:
|
||||
frag_Z = convert_z(result_Z);
|
||||
}
|
||||
|
||||
/// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
|
||||
FragmentC const &residual1, FragmentC const &residual2,
|
||||
FragmentCompute const &bias) const {
|
||||
UnaryOp unary_op;
|
||||
BinaryOp1 binary_op1;
|
||||
BinaryOp2 binary_op2;
|
||||
ActivationOp activation;
|
||||
|
||||
FragmentCompute tmp_Accum =
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
||||
FragmentCompute tmp_residual1 =
|
||||
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
|
||||
FragmentCompute tmp_residual2 =
|
||||
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);
|
||||
|
||||
FragmentCompute z =
|
||||
binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
|
||||
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
|
||||
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
|
||||
frag_Z = convert_z(result_Z);
|
||||
}
|
||||
|
||||
/// Should never be called
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
|
||||
|
||||
@ -89,7 +89,7 @@ template <
|
||||
bool StoreT = true
|
||||
>
|
||||
struct EpilogueWithBroadcastOpBase {
|
||||
|
||||
|
||||
using ElementOutput = ElementC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
@ -119,7 +119,7 @@ struct EpilogueWithBroadcastOpBase {
|
||||
/// Constructor from Params
|
||||
EpilogueWithBroadcastOpBase(Params const ¶ms_) { }
|
||||
|
||||
/// Determine if the source is needed. May return false if
|
||||
/// Determine if the source is needed. May return false if
|
||||
bool is_source_needed() const {
|
||||
return true;
|
||||
}
|
||||
@ -130,19 +130,19 @@ struct EpilogueWithBroadcastOpBase {
|
||||
/// Applies the operation when is_source_needed() is true
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
FragmentZ &frag_Z,
|
||||
FragmentT &frag_T,
|
||||
FragmentZ &frag_Z,
|
||||
FragmentT &frag_T,
|
||||
FragmentAccumulator const &AB,
|
||||
FragmentC const &frag_C,
|
||||
FragmentC const &frag_C1,
|
||||
FragmentC const &frag_C2,
|
||||
FragmentCompute const &V) const {
|
||||
|
||||
}
|
||||
|
||||
/// Applies the operation when is_source_needed() is false
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
FragmentZ &frag_Z,
|
||||
FragmentT &frag_T,
|
||||
FragmentZ &frag_Z,
|
||||
FragmentT &frag_T,
|
||||
FragmentAccumulator const &AB,
|
||||
FragmentCompute const &V) const {
|
||||
|
||||
@ -160,11 +160,11 @@ struct EpilogueWithBroadcastOpBase {
|
||||
///
|
||||
/// if (ElementwiseOp::kStoreZ) {
|
||||
/// store(converted_u);
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// if (ElementwiseOp::kStoreT) {
|
||||
/// store(v);
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
template <
|
||||
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
||||
@ -182,24 +182,24 @@ template <
|
||||
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
||||
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
|
||||
>
|
||||
class EpilogueWithBroadcast :
|
||||
class EpilogueWithBroadcast :
|
||||
public EpilogueBase<
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Padding_,
|
||||
FragmentsPerPartition> {
|
||||
|
||||
public:
|
||||
|
||||
using Base = EpilogueBase<
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Padding_,
|
||||
FragmentsPerPartition>;
|
||||
|
||||
@ -235,7 +235,7 @@ public:
|
||||
|
||||
/// Fragment object used to store the broadcast values
|
||||
using BroadcastFragment = Array<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Output element
|
||||
@ -261,14 +261,14 @@ public:
|
||||
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Array type used by output functor
|
||||
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Array type used by output functor
|
||||
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Tensor access type
|
||||
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = typename Base::WarpCount;
|
||||
|
||||
@ -304,7 +304,7 @@ public:
|
||||
/// I'm not sure what I meant here.
|
||||
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
||||
|
||||
/// Shape of the shared memory allocation for the epilogue
|
||||
/// Shape of the shared memory allocation for the epilogue
|
||||
using StorageShape = MatrixShape<
|
||||
kThreadRows,
|
||||
Shape::kN
|
||||
@ -351,7 +351,7 @@ public:
|
||||
|
||||
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
||||
|
||||
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
||||
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
||||
"Divisibility");
|
||||
|
||||
private:
|
||||
@ -367,7 +367,7 @@ public:
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
EpilogueWithBroadcast(
|
||||
SharedStorage &shared_storage, ///< Shared storage object
|
||||
SharedStorage &shared_storage, ///< Shared storage object
|
||||
int thread_idx, ///< ID of a thread within the threadblock
|
||||
int warp_idx, ///< ID of warp within threadblock
|
||||
int lane_idx ///< Id of thread within warp
|
||||
@ -386,32 +386,34 @@ public:
|
||||
ElementVector const * broadcast_ptr, ///< Broadcast vector
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
|
||||
OutputTileIterator source_iterator1, ///< Tile iterator for source accumulator matrix
|
||||
OutputTileIterator source_iterator2, ///< Tile iterator for source accumulator matrix
|
||||
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
|
||||
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
|
||||
MatrixCoord(Shape::kM, Shape::kN),
|
||||
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
|
||||
MatrixCoord()) {
|
||||
|
||||
|
||||
BroadcastFragment broadcast_fragment;
|
||||
|
||||
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
|
||||
|
||||
if (!output_op.is_source_needed()) {
|
||||
compute_source_not_needed_(
|
||||
output_op,
|
||||
broadcast_fragment,
|
||||
destination_iterator,
|
||||
output_op,
|
||||
broadcast_fragment,
|
||||
destination_iterator,
|
||||
accumulators,
|
||||
tensor_iterator);
|
||||
}
|
||||
else {
|
||||
compute_source_needed_(
|
||||
output_op,
|
||||
broadcast_fragment,
|
||||
destination_iterator,
|
||||
accumulators,
|
||||
source_iterator,
|
||||
output_op,
|
||||
broadcast_fragment,
|
||||
destination_iterator,
|
||||
accumulators,
|
||||
source_iterator1,
|
||||
source_iterator2,
|
||||
tensor_iterator);
|
||||
}
|
||||
}
|
||||
@ -427,7 +429,7 @@ private:
|
||||
) {
|
||||
|
||||
broadcast_fragment.clear();
|
||||
|
||||
|
||||
// If no pointer is supplied, set with all zeros and avoid memory accesses
|
||||
if (!broadcast_ptr) {
|
||||
return;
|
||||
@ -513,9 +515,9 @@ private:
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand
|
||||
) {
|
||||
) {
|
||||
|
||||
//
|
||||
// Iterator over warp-level accumulator fragment
|
||||
@ -525,7 +527,7 @@ private:
|
||||
|
||||
//
|
||||
// Iterate over accumulator tile
|
||||
//
|
||||
//
|
||||
|
||||
// CUTLASS_PRAGMA_UNROLL
|
||||
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
|
||||
@ -534,7 +536,7 @@ private:
|
||||
//
|
||||
// Convert and store fragment
|
||||
//
|
||||
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@ -638,7 +640,7 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
/// Streams the result to global memory
|
||||
CUTLASS_DEVICE
|
||||
void compute_source_needed_(
|
||||
@ -646,12 +648,15 @@ private:
|
||||
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
OutputTileIterator source_iterator1, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
OutputTileIterator source_iterator2, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand
|
||||
) {
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment;
|
||||
source_fragment.clear();
|
||||
) {
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment1;
|
||||
source_fragment1.clear();
|
||||
typename OutputTileIterator::Fragment source_fragment2;
|
||||
source_fragment2.clear();
|
||||
|
||||
//
|
||||
// Iterator over warp-level accumulator fragment
|
||||
@ -661,7 +666,7 @@ private:
|
||||
|
||||
//
|
||||
// Iterate over accumulator tile
|
||||
//
|
||||
//
|
||||
|
||||
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
|
||||
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
||||
@ -670,13 +675,18 @@ private:
|
||||
// Load the source
|
||||
//
|
||||
|
||||
source_iterator.load(source_fragment);
|
||||
++source_iterator;
|
||||
source_iterator1.load(source_fragment1);
|
||||
++source_iterator1;
|
||||
|
||||
if (source_iterator2.enabled()) {
|
||||
source_iterator2.load(source_fragment2);
|
||||
++source_iterator2;
|
||||
}
|
||||
|
||||
//
|
||||
// Convert and store fragment
|
||||
//
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
||||
@ -720,8 +730,10 @@ private:
|
||||
frag_T,
|
||||
output_op,
|
||||
aligned_accum_fragment[0],
|
||||
source_fragment,
|
||||
broadcast_fragment);
|
||||
source_fragment1,
|
||||
source_fragment2,
|
||||
broadcast_fragment,
|
||||
source_iterator2.enabled());
|
||||
|
||||
//
|
||||
// Conditionally store fragments
|
||||
@ -746,8 +758,10 @@ private:
|
||||
typename TensorTileIterator::Fragment &frag_T,
|
||||
OutputOp const &output_op,
|
||||
typename SharedLoadIterator::Fragment const &frag_AB,
|
||||
typename OutputTileIterator::Fragment const &frag_C,
|
||||
BroadcastFragment const &frag_Broadcast) {
|
||||
typename OutputTileIterator::Fragment const &frag_C1,
|
||||
typename OutputTileIterator::Fragment const &frag_C2,
|
||||
BroadcastFragment const &frag_Broadcast,
|
||||
bool frag_C2_enabled) {
|
||||
|
||||
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
|
||||
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
|
||||
@ -755,28 +769,39 @@ private:
|
||||
|
||||
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
||||
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
|
||||
|
||||
AccumulatorAccessType const *frag_AB_ptr =
|
||||
|
||||
AccumulatorAccessType const *frag_AB_ptr =
|
||||
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
||||
|
||||
OutputAccessType const *frag_C_ptr =
|
||||
reinterpret_cast<OutputAccessType const *>(&frag_C);
|
||||
OutputAccessType const *frag_C1_ptr =
|
||||
reinterpret_cast<OutputAccessType const *>(&frag_C1);
|
||||
OutputAccessType const *frag_C2_ptr =
|
||||
reinterpret_cast<OutputAccessType const *>(&frag_C2);
|
||||
|
||||
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
||||
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
||||
|
||||
int const kOutputOpIterations =
|
||||
int const kOutputOpIterations =
|
||||
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kOutputOpIterations; ++i) {
|
||||
|
||||
output_op(
|
||||
frag_Z_ptr[i],
|
||||
frag_T_ptr[i],
|
||||
frag_AB_ptr[i],
|
||||
frag_C_ptr[i],
|
||||
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
||||
if (frag_C2_enabled) {
|
||||
output_op(
|
||||
frag_Z_ptr[i],
|
||||
frag_T_ptr[i],
|
||||
frag_AB_ptr[i],
|
||||
frag_C1_ptr[i],
|
||||
frag_C2_ptr[i],
|
||||
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
||||
} else {
|
||||
output_op(
|
||||
frag_Z_ptr[i],
|
||||
frag_T_ptr[i],
|
||||
frag_AB_ptr[i],
|
||||
frag_C1_ptr[i],
|
||||
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -795,23 +820,23 @@ private:
|
||||
|
||||
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
|
||||
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
|
||||
|
||||
AccumulatorAccessType const *frag_AB_ptr =
|
||||
|
||||
AccumulatorAccessType const *frag_AB_ptr =
|
||||
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
|
||||
|
||||
AccessTypeBroadcast const *frag_Broadcast_ptr =
|
||||
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
|
||||
|
||||
int const kOutputOpIterations =
|
||||
int const kOutputOpIterations =
|
||||
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kOutputOpIterations; ++i) {
|
||||
|
||||
output_op(
|
||||
frag_Z_ptr[i],
|
||||
frag_T_ptr[i],
|
||||
frag_AB_ptr[i],
|
||||
frag_Z_ptr[i],
|
||||
frag_T_ptr[i],
|
||||
frag_AB_ptr[i],
|
||||
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,10 +100,10 @@ public:
|
||||
|
||||
/// Fragment object
|
||||
using Fragment = Array<
|
||||
Element,
|
||||
ThreadMap::Iterations::kColumn *
|
||||
ThreadMap::Iterations::kRow *
|
||||
ThreadMap::Iterations::kGroup *
|
||||
Element,
|
||||
ThreadMap::Iterations::kColumn *
|
||||
ThreadMap::Iterations::kRow *
|
||||
ThreadMap::Iterations::kGroup *
|
||||
ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Memory access size
|
||||
@ -121,15 +121,15 @@ public:
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const &layout):
|
||||
Params(Layout const &layout):
|
||||
PredicatedTileIteratorParams(
|
||||
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
|
||||
make_OutputTileThreadMapDesc<ThreadMap>()
|
||||
)
|
||||
)
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Base const &base) :
|
||||
Params(Base const &base) :
|
||||
Base(base) { }
|
||||
};
|
||||
|
||||
@ -176,7 +176,7 @@ private:
|
||||
PredicatedTileIteratorParams params_;
|
||||
|
||||
/// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load().
|
||||
uint8_t *byte_pointer_;
|
||||
uint8_t *byte_pointer_{nullptr};
|
||||
|
||||
/// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ may be with different address computation compared to byte_pointer_.
|
||||
uint8_t *store_byte_pointer_;
|
||||
@ -200,13 +200,13 @@ private:
|
||||
int state_[3];
|
||||
|
||||
/// Scatter indices
|
||||
int const *indices_;
|
||||
int const *indices_;
|
||||
|
||||
/// Whether to perform Permute Op
|
||||
bool PermuteD;
|
||||
/// PermuteDLayout
|
||||
mutable PermuteDLayout permute_layout_;
|
||||
|
||||
|
||||
//
|
||||
// Static asserts about internal strides
|
||||
//
|
||||
@ -236,7 +236,7 @@ public:
|
||||
int thread_idx,
|
||||
TensorCoord threadblock_offset = TensorCoord(),
|
||||
int const *indices = nullptr
|
||||
):
|
||||
):
|
||||
params_(params), indices_(indices)
|
||||
{
|
||||
|
||||
@ -252,13 +252,18 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
|
||||
|
||||
mask_.predicates[c] = ((thread_offset.column()
|
||||
mask_.predicates[c] = ((thread_offset.column()
|
||||
+ ThreadMap::Delta::kColumn * c) < extent.column());
|
||||
}
|
||||
|
||||
// Null pointer performs no accesses
|
||||
if (!pointer) {
|
||||
mask_.clear();
|
||||
} else {
|
||||
// Initialize pointer
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
|
||||
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
|
||||
}
|
||||
|
||||
if (ScatterD && !indices) {
|
||||
@ -266,8 +271,8 @@ public:
|
||||
}
|
||||
|
||||
// Initialize byte_pointer_
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
|
||||
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
|
||||
|
||||
if (ScatterD) {
|
||||
@ -284,7 +289,7 @@ public:
|
||||
}else{
|
||||
PermuteD = true;
|
||||
store_byte_pointer_ = reinterpret_cast<uint8_t *>(pointer);
|
||||
permute_layout_ = PermuteDLayout(extent,
|
||||
permute_layout_ = PermuteDLayout(extent,
|
||||
params_.stride * kElementsPerAccess / sizeof(AccessType));
|
||||
}
|
||||
|
||||
@ -315,11 +320,11 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
|
||||
int frag_row_idx =
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
+ cluster * ThreadMap::Delta::kCluster;
|
||||
|
||||
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||
@ -339,7 +344,7 @@ public:
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
|
||||
@ -389,11 +394,11 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
|
||||
int frag_row_idx =
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
+ cluster * ThreadMap::Delta::kCluster;
|
||||
|
||||
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||
@ -413,7 +418,7 @@ public:
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
int col_offset = column * ThreadMap::Delta::kColumn;
|
||||
|
||||
|
||||
if (PermuteD) {
|
||||
int col = col_offset + thread_start_column_;
|
||||
int row = row_offset + thread_start_row_;
|
||||
@ -436,7 +441,7 @@ public:
|
||||
(void *)&memory_pointer[0],
|
||||
guard);
|
||||
}
|
||||
|
||||
|
||||
if (!PermuteD) {
|
||||
memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess);
|
||||
}
|
||||
@ -483,11 +488,11 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
|
||||
int frag_row_idx =
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
+ cluster * ThreadMap::Delta::kCluster;
|
||||
|
||||
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||
@ -511,7 +516,7 @@ public:
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
|
||||
@ -553,11 +558,11 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
|
||||
int frag_row_idx =
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
+ cluster * ThreadMap::Delta::kCluster;
|
||||
|
||||
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||
@ -585,7 +590,7 @@ public:
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
|
||||
@ -655,7 +660,7 @@ public:
|
||||
}
|
||||
|
||||
thread_start_row_ += ThreadMap::Shape::kRow;
|
||||
|
||||
|
||||
if (state_[0] == ThreadMap::Count::kRow) {
|
||||
|
||||
state_[0] = 0;
|
||||
@ -663,7 +668,7 @@ public:
|
||||
byte_pointer_ += params_.advance_group;
|
||||
store_byte_pointer_ += params_.advance_group;
|
||||
|
||||
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
|
||||
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
|
||||
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
|
||||
|
||||
if (state_[1] == ThreadMap::Count::kGroup) {
|
||||
@ -673,7 +678,7 @@ public:
|
||||
byte_pointer_ += params_.advance_cluster;
|
||||
store_byte_pointer_ += params_.advance_cluster;
|
||||
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
|
||||
|
||||
if (state_[2] == ThreadMap::Count::kCluster) {
|
||||
@ -706,6 +711,10 @@ public:
|
||||
CUTLASS_DEVICE void set_mask(Mask const &mask) {
|
||||
mask_ = mask;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool enabled() {
|
||||
return (byte_pointer_ != nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -717,7 +726,7 @@ public:
|
||||
template <
|
||||
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
||||
typename Element_, ///< Element data type
|
||||
int InterleavedN ///< Number of Interleaved N
|
||||
int InterleavedN ///< Number of Interleaved N
|
||||
>
|
||||
class InterleavedPredicatedTileIterator {
|
||||
public:
|
||||
@ -751,14 +760,14 @@ public:
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const &layout):
|
||||
Params(Layout const &layout):
|
||||
Base(
|
||||
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
|
||||
make_InterleavedPredicatedTileIteratorDesc<Element, ThreadMap>()
|
||||
) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Base const &base) :
|
||||
Params(Base const &base) :
|
||||
Base(base) { }
|
||||
};
|
||||
|
||||
@ -862,8 +871,8 @@ public:
|
||||
}
|
||||
|
||||
// Initialize pointer
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.strided()) * LongIndex(params_.stride) +
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.strided()) * LongIndex(params_.stride) +
|
||||
LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess;
|
||||
|
||||
// Initialize internal state counter
|
||||
@ -891,7 +900,7 @@ public:
|
||||
bool guard = col_guard && mask_.predicates[iteration_contiguous_];
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
*frag_ptr,
|
||||
@ -1091,7 +1100,7 @@ private:
|
||||
/// Extent of the matrix tile in rows
|
||||
Index extent_row_;
|
||||
|
||||
/// Extent of the matrix tile in pq
|
||||
/// Extent of the matrix tile in pq
|
||||
Index extent_pq_;
|
||||
|
||||
/// A thread's starting row position (assuming steady-state predicates have
|
||||
@ -1133,7 +1142,7 @@ public:
|
||||
):
|
||||
params_(params) {
|
||||
MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
||||
|
||||
|
||||
extent_col_ = extent.c();
|
||||
extent_pq_ = extent.h() * extent.w();
|
||||
extent_row_ = extent.n() * extent_pq_;
|
||||
@ -1188,7 +1197,7 @@ public:
|
||||
reinterpret_cast<AccessType const *>(byte_pointer);
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
*frag_ptr,
|
||||
|
||||
354
include/cutlass/gemm/device/gemm_universal_with_broadcast.h
Normal file
354
include/cutlass/gemm/device/gemm_universal_with_broadcast.h
Normal file
@ -0,0 +1,354 @@
|
||||
/*! \file
|
||||
\gemm_universal which takes LinearCombinationResidualBlock as the epilogue output op.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_base.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!
|
||||
The universal GEMM with a broadcast epilogue.
|
||||
Supports
|
||||
*/
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_ = ElementC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_ = arch::OpClassSimt,
|
||||
/// Tag indicating architecture to tune for. This is the minimum SM that
|
||||
/// supports the intended feature. The device kernel can be built
|
||||
/// targeting any SM larger than this number.
|
||||
typename ArchTag_ = arch::Sm70,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::InstructionShape,
|
||||
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
|
||||
typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
|
||||
ElementC_, ElementAccumulator_, ElementAccumulator_,
|
||||
ElementC_, ElementC_, 16 / sizeof(ElementC_)>,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kStages,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentA,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::Operator,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA = ComplexTransform::kNone,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB = ComplexTransform::kNone
|
||||
>
|
||||
class GemmUniversalWithBroadcast :
|
||||
public GemmUniversalBase<
|
||||
typename kernel::DefaultGemmWithBroadcast<
|
||||
ElementA_,
|
||||
LayoutA_,
|
||||
TransformA,
|
||||
AlignmentA,
|
||||
ElementB_,
|
||||
LayoutB_,
|
||||
TransformB,
|
||||
AlignmentB,
|
||||
ElementC_,
|
||||
LayoutC_,
|
||||
ElementAccumulator_,
|
||||
OperatorClass_,
|
||||
ArchTag_,
|
||||
ThreadblockShape_,
|
||||
WarpShape_,
|
||||
InstructionShape_,
|
||||
EpilogueOutputOp_,
|
||||
ThreadblockSwizzle_,
|
||||
Stages,
|
||||
Operator_
|
||||
>::GemmKernel
|
||||
> {
|
||||
|
||||
public:
|
||||
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using OperatorClass = OperatorClass_;
|
||||
using ArchTag = ArchTag_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using EpilogueOutputOp = EpilogueOutputOp_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using Operator = Operator_;
|
||||
static int const kStages = Stages;
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp::kCount;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
|
||||
using Base = GemmUniversalBase<
|
||||
typename kernel::DefaultGemmWithBroadcast<
|
||||
ElementA_,
|
||||
LayoutA_,
|
||||
TransformA,
|
||||
AlignmentA,
|
||||
ElementB_,
|
||||
LayoutB_,
|
||||
TransformB,
|
||||
AlignmentB,
|
||||
ElementC_,
|
||||
LayoutC_,
|
||||
ElementAccumulator_,
|
||||
OperatorClass_,
|
||||
ArchTag_,
|
||||
ThreadblockShape_,
|
||||
WarpShape_,
|
||||
InstructionShape_,
|
||||
EpilogueOutputOp_,
|
||||
ThreadblockSwizzle_,
|
||||
Stages,
|
||||
Operator_
|
||||
>::GemmKernel
|
||||
>;
|
||||
|
||||
using Arguments = typename Base::Arguments;
|
||||
using GemmKernel = typename Base::GemmKernel;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parital specialization for column-major output exchanges problem size and operand.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for. This is the minimum SM that
|
||||
/// supports the intended feature. The device kernel can be built
|
||||
/// targeting any SM larger than this number.
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp_,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB>
|
||||
class GemmUniversalWithBroadcast<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
|
||||
layout::ColumnMajor, // partially specialized on LayoutC
|
||||
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
|
||||
WarpShape_, InstructionShape_, EpilogueOutputOp_,
|
||||
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
|
||||
Operator_, TransformA, TransformB> {
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = LayoutA_;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = layout::ColumnMajor;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using OperatorClass = OperatorClass_;
|
||||
using ArchTag = ArchTag_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using EpilogueOutputOp = EpilogueOutputOp_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using Operator = Operator_;
|
||||
static int const kStages = Stages;
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
|
||||
using UnderlyingOperator = typename GemmUniversalWithBroadcast<
|
||||
ElementB,
|
||||
typename layout::LayoutTranspose<LayoutB>::type,
|
||||
ElementA,
|
||||
typename layout::LayoutTranspose<LayoutA>::type,
|
||||
ElementC,
|
||||
layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
kAlignmentB,
|
||||
kAlignmentA,
|
||||
Operator,
|
||||
kTransformB,
|
||||
kTransformA
|
||||
>::Base;
|
||||
|
||||
using GemmKernel = typename UnderlyingOperator::GemmKernel;
|
||||
static int const kAlignmentC = EpilogueOutputOp::kCount;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename UnderlyingOperator::Arguments;
|
||||
|
||||
private:
|
||||
|
||||
UnderlyingOperator underlying_operator_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalWithBroadcast() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const &args) {
|
||||
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1) {
|
||||
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
|
||||
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
return underlying_operator_.update(to_underlying_arguments(args), workspace);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
return underlying_operator_.run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -52,7 +52,7 @@ namespace kernel {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
@ -92,7 +92,7 @@ public:
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(
|
||||
128 / sizeof_bits<ElementA>::value,
|
||||
128 / sizeof_bits<ElementA>::value,
|
||||
128 / sizeof_bits<ElementB>::value
|
||||
);
|
||||
|
||||
@ -115,7 +115,8 @@ public:
|
||||
|
||||
void const * ptr_A;
|
||||
void const * ptr_B;
|
||||
void const * ptr_C;
|
||||
void const * ptr_C1;
|
||||
void const * ptr_C2;
|
||||
void * ptr_D;
|
||||
|
||||
void * ptr_Vector;
|
||||
@ -123,14 +124,16 @@ public:
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_C2;
|
||||
int64_t batch_stride_D;
|
||||
int64_t batch_stride_Vector;
|
||||
int64_t batch_stride_Tensor;
|
||||
|
||||
typename LayoutA::Stride::Index lda;
|
||||
typename LayoutB::Stride::Index ldb;
|
||||
typename LayoutC::Stride::Index ldc;
|
||||
typename LayoutC::Stride::Index ldc1;
|
||||
typename LayoutC::Stride::Index ldc2;
|
||||
typename LayoutC::Stride::Index ldd;
|
||||
typename LayoutC::Stride::Index ldr;
|
||||
typename LayoutC::Stride::Index ldt;
|
||||
@ -138,11 +141,62 @@ public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { }
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C1(nullptr), ptr_C2(nullptr), ptr_D(nullptr) { }
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueOutputOp::Params epilogue,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C1,
|
||||
void const * ptr_C2,
|
||||
void * ptr_D,
|
||||
void * ptr_Vector,
|
||||
void * ptr_Tensor,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C1,
|
||||
int64_t batch_stride_C2,
|
||||
int64_t batch_stride_D,
|
||||
int64_t batch_stride_Vector,
|
||||
int64_t batch_stride_Tensor,
|
||||
typename LayoutA::Stride::Index lda,
|
||||
typename LayoutB::Stride::Index ldb,
|
||||
typename LayoutC::Stride::Index ldc1,
|
||||
typename LayoutC::Stride::Index ldc2,
|
||||
typename LayoutC::Stride::Index ldd,
|
||||
typename LayoutC::Stride::Index ldr,
|
||||
typename LayoutC::Stride::Index ldt
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue(epilogue),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D),
|
||||
ptr_Vector(ptr_Vector),
|
||||
ptr_Tensor(ptr_Tensor),
|
||||
batch_stride_A(batch_stride_A),
|
||||
batch_stride_B(batch_stride_B),
|
||||
batch_stride_C1(batch_stride_C1),
|
||||
batch_stride_C2(batch_stride_C2),
|
||||
batch_stride_D(batch_stride_D),
|
||||
batch_stride_Vector(batch_stride_Vector),
|
||||
batch_stride_Tensor(batch_stride_Tensor),
|
||||
lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
|
||||
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
|
||||
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
|
||||
CUTLASS_TRACE_HOST(" ldt: " << this->ldt);
|
||||
}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
@ -168,33 +222,18 @@ public:
|
||||
typename LayoutC::Stride::Index ldd,
|
||||
typename LayoutC::Stride::Index ldr,
|
||||
typename LayoutC::Stride::Index ldt
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue(epilogue),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
ptr_Vector(ptr_Vector),
|
||||
ptr_Tensor(ptr_Tensor),
|
||||
batch_stride_A(batch_stride_A),
|
||||
batch_stride_B(batch_stride_B),
|
||||
batch_stride_C(batch_stride_C),
|
||||
batch_stride_D(batch_stride_D),
|
||||
batch_stride_Vector(batch_stride_Vector),
|
||||
batch_stride_Tensor(batch_stride_Tensor),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction);
|
||||
CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor);
|
||||
CUTLASS_TRACE_HOST(" ldr: " << this->ldr);
|
||||
CUTLASS_TRACE_HOST(" ldt: " << this->ldt);
|
||||
}
|
||||
): Arguments(
|
||||
mode, problem_size, batch_count, epilogue,
|
||||
ptr_A, ptr_B, ptr_C1, nullptr, ptr_D, ptr_Vector, ptr_Tensor,
|
||||
batch_stride_A, batch_stride_B, batch_stride_C, 0, batch_stride_D,
|
||||
batch_stride_Vector, batch_stride_Tensor,
|
||||
lda, ldb, ldc, 0, ldd, ldr, ldt) {}
|
||||
|
||||
|
||||
/// Returns arguments for the transposed problem
|
||||
Arguments transposed_problem() const {
|
||||
Arguments args(*this);
|
||||
|
||||
|
||||
std::swap(args.problem_size.m(), args.problem_size.n());
|
||||
std::swap(args.ptr_A, args.ptr_B);
|
||||
std::swap(args.lda, args.ldb);
|
||||
@ -217,10 +256,11 @@ public:
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_C1;
|
||||
typename Epilogue::OutputTileIterator::Params params_C2;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::TensorTileIterator::Params params_Tensor;
|
||||
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
|
||||
@ -230,9 +270,10 @@ public:
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
void * ptr_C1;
|
||||
void * ptr_C2;
|
||||
void * ptr_D;
|
||||
|
||||
|
||||
void * ptr_Vector;
|
||||
typename LayoutC::Stride::Index ldr;
|
||||
|
||||
@ -240,7 +281,8 @@ public:
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_C2;
|
||||
int64_t batch_stride_D;
|
||||
int64_t batch_stride_Vector;
|
||||
int64_t batch_stride_Tensor;
|
||||
@ -256,21 +298,24 @@ public:
|
||||
swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_C(0),
|
||||
params_C1(0),
|
||||
params_C2(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_C1(nullptr),
|
||||
ptr_C2(nullptr),
|
||||
ptr_D(nullptr),
|
||||
ptr_Vector(nullptr),
|
||||
ldr(0),
|
||||
ptr_Tensor(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_C1(0),
|
||||
batch_stride_C2(0),
|
||||
batch_stride_D(0),
|
||||
batch_stride_Vector(0),
|
||||
batch_stride_Tensor(0),
|
||||
@ -288,7 +333,8 @@ public:
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(args.lda),
|
||||
params_B(args.ldb),
|
||||
params_C(args.ldc),
|
||||
params_C1(args.ldc1),
|
||||
params_C2(args.ldc2),
|
||||
params_D(args.ldd),
|
||||
params_Tensor(args.ldt),
|
||||
output_op(args.epilogue),
|
||||
@ -297,15 +343,17 @@ public:
|
||||
gemm_k_size(gemm_k_size),
|
||||
ptr_A(const_cast<void *>(args.ptr_A)),
|
||||
ptr_B(const_cast<void *>(args.ptr_B)),
|
||||
ptr_C(const_cast<void *>(args.ptr_C)),
|
||||
ptr_C1(const_cast<void *>(args.ptr_C1)),
|
||||
ptr_C2(const_cast<void *>(args.ptr_C2)),
|
||||
ptr_D(args.ptr_D),
|
||||
ptr_Vector(args.ptr_Vector),
|
||||
ptr_Vector(args.ptr_Vector),
|
||||
ldr(args.ldr),
|
||||
ptr_Tensor(args.ptr_Tensor),
|
||||
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_C1(args.batch_stride_C1),
|
||||
batch_stride_C2(args.batch_stride_C2),
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
batch_stride_Vector(args.batch_stride_Vector),
|
||||
batch_stride_Tensor(args.batch_stride_Tensor),
|
||||
@ -326,7 +374,8 @@ public:
|
||||
|
||||
ptr_A = const_cast<void *>(args.ptr_A);
|
||||
ptr_B = const_cast<void *>(args.ptr_B);
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_C1 = const_cast<void *>(args.ptr_C1);
|
||||
ptr_C2 = const_cast<void *>(args.ptr_C2);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
ptr_Vector = args.ptr_Vector;
|
||||
@ -335,7 +384,8 @@ public:
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_C1 = args.batch_stride_C1;
|
||||
batch_stride_C2 = args.batch_stride_C2;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
batch_stride_Vector = args.batch_stride_Vector;
|
||||
batch_stride_Tensor = args.batch_stride_Tensor;
|
||||
@ -364,7 +414,7 @@ public:
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemmWithFusedEpilogue() { }
|
||||
GemmWithFusedEpilogue() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
@ -458,7 +508,7 @@ public:
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
||||
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
||||
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
|
||||
|
||||
|
||||
@ -466,12 +516,12 @@ public:
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm ||
|
||||
if (params.mode == GemmUniversalMode::kGemm ||
|
||||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
@ -537,10 +587,10 @@ public:
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(
|
||||
gemm_k_iterations,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
gemm_k_iterations,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
accumulators);
|
||||
|
||||
//
|
||||
@ -563,28 +613,36 @@ public:
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ptr_C1);
|
||||
ElementC *ptr_C2 = static_cast<ElementC *>(params.ptr_C2);
|
||||
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
||||
typename Epilogue::ElementTensor *ptr_Tensor = static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);
|
||||
|
||||
// Define the reduction output pointer and move to the appropriate place
|
||||
typename Epilogue::ElementVector *ptr_Vector =
|
||||
typename Epilogue::ElementVector *ptr_Vector =
|
||||
static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);
|
||||
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
|
||||
|
||||
//
|
||||
// Special path when split-K not enabled.
|
||||
//
|
||||
//
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) {
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.params_C,
|
||||
ptr_C,
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
ptr_C1,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue::OutputTileIterator iterator_C2(
|
||||
params.params_C2,
|
||||
ptr_C2,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@ -610,9 +668,9 @@ public:
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Move to appropriate location for this output tile
|
||||
@ -625,7 +683,8 @@ public:
|
||||
ptr_Vector,
|
||||
iterator_D,
|
||||
accumulators,
|
||||
iterator_C,
|
||||
iterator_C1,
|
||||
iterator_C2,
|
||||
tensor_iterator,
|
||||
params.problem_size.mn(),
|
||||
threadblock_offset);
|
||||
@ -637,7 +696,7 @@ public:
|
||||
// Slower path when split-K or batching is needed
|
||||
//
|
||||
|
||||
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
@ -646,7 +705,7 @@ public:
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
@ -658,7 +717,10 @@ public:
|
||||
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
|
||||
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
|
||||
if (ptr_C2) {
|
||||
ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2;
|
||||
}
|
||||
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
if (ptr_Tensor) {
|
||||
ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor;
|
||||
@ -668,7 +730,10 @@ public:
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_C = static_cast<ElementC * const *>(params.ptr_C)[threadblock_tile_offset.k()];
|
||||
ptr_C1 = static_cast<ElementC * const *>(params.ptr_C1)[threadblock_tile_offset.k()];
|
||||
if (ptr_C2) {
|
||||
ptr_C2 = static_cast<ElementC * const *>(params.ptr_C2)[threadblock_tile_offset.k()];
|
||||
}
|
||||
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
|
||||
if (ptr_Tensor) {
|
||||
ptr_Tensor = static_cast<typename Epilogue::ElementTensor * const *>(params.ptr_Tensor)[threadblock_tile_offset.k()];
|
||||
@ -680,9 +745,16 @@ public:
|
||||
#endif
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.params_C,
|
||||
ptr_C,
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
ptr_C1,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue::OutputTileIterator iterator_C2(
|
||||
params.params_C2,
|
||||
ptr_C2,
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@ -711,18 +783,18 @@ public:
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C = iterator_D;
|
||||
iterator_C1 = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
@ -744,7 +816,8 @@ public:
|
||||
: ptr_Vector,
|
||||
iterator_D,
|
||||
accumulators,
|
||||
iterator_C,
|
||||
iterator_C1,
|
||||
iterator_C2,
|
||||
tensor_iterator,
|
||||
params.problem_size.mn(),
|
||||
threadblock_offset);
|
||||
@ -754,7 +827,7 @@ public:
|
||||
//
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
|
||||
if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
@ -766,7 +839,7 @@ public:
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -60,6 +60,13 @@ add_custom_target(
|
||||
test_unit_gemv_device
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_new
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_exp
|
||||
cutlass_test_unit_gemm_device_new
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_simt
|
||||
|
||||
@ -68,7 +75,7 @@ cutlass_test_unit_add_executable(
|
||||
|
||||
simt_sgemm_nt_sm80.cu
|
||||
simt_sgemm_tn_sm80.cu
|
||||
|
||||
|
||||
simt_cgemm_nn_sm50.cu
|
||||
simt_cgemm_nt_sm50.cu
|
||||
simt_cgemm_tn_sm50.cu
|
||||
@ -163,7 +170,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu
|
||||
gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu
|
||||
gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu
|
||||
gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu
|
||||
gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu
|
||||
gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu
|
||||
|
||||
gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu
|
||||
@ -213,7 +220,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu
|
||||
gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu
|
||||
gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu
|
||||
@ -232,7 +239,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu
|
||||
gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu
|
||||
|
||||
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
|
||||
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
|
||||
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
|
||||
gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
|
||||
gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu
|
||||
@ -253,7 +260,7 @@ cutlass_test_unit_add_executable(
|
||||
gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu
|
||||
gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu
|
||||
gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu
|
||||
@ -326,8 +333,8 @@ cutlass_test_unit_add_executable(
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu
|
||||
)
|
||||
|
||||
@ -403,7 +410,7 @@ add_dependencies(
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop
|
||||
|
||||
|
||||
gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu
|
||||
gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu
|
||||
|
||||
@ -430,10 +437,10 @@ cutlass_test_unit_add_executable(
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
## SYRK
|
||||
## SYRK
|
||||
# Syrk SM80 f64 tests
|
||||
syrk_f64n_f64t_tensor_op_f64_sm80.cu
|
||||
syrk_f64t_f64n_tensor_op_f64_sm80.cu
|
||||
syrk_f64n_f64t_tensor_op_f64_sm80.cu
|
||||
syrk_f64t_f64n_tensor_op_f64_sm80.cu
|
||||
|
||||
# Syrk SM80 f32 tests
|
||||
syrk_tf32n_f32t_tensor_op_f32_sm80.cu
|
||||
@ -452,7 +459,7 @@ cutlass_test_unit_add_executable(
|
||||
syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu
|
||||
syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu
|
||||
|
||||
## HERK
|
||||
## HERK
|
||||
# Herk SM80 complex f64 tests
|
||||
herk_cf64h_cf64n_tensor_op_f64_sm80.cu
|
||||
|
||||
@ -550,7 +557,7 @@ cutlass_test_unit_add_executable(
|
||||
hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu
|
||||
hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu
|
||||
hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu
|
||||
|
||||
|
||||
# Hemm SM80 complex f32 tests
|
||||
hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu
|
||||
hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu
|
||||
@ -581,4 +588,10 @@ cutlass_test_unit_add_executable(
|
||||
her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_exp
|
||||
|
||||
gemm_broadcast_test.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
474
test/unit/gemm/device/gemm_broadcast_test.cu
Normal file
474
test/unit/gemm/device/gemm_broadcast_test.cu
Normal file
@ -0,0 +1,474 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_elementwise.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
template<typename GemmElement, typename LayoutA, typename LayoutB, typename LayoutC>
|
||||
struct TestbedUtils {
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<GemmElement, LayoutA> tensor_A; // Input A
|
||||
cutlass::HostTensor<GemmElement, LayoutB> tensor_B; // Input B
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_C; // Input C
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_D1; // Input D
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_D2; // Input D
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y1; // Input Y
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y2; // Input Y
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y_ref;
|
||||
cutlass::HostTensor<GemmElement, LayoutC> tensor_Y_transpose_ref;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
TestbedUtils(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_D(init_D_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
EXPECT_TRUE(false) << "Not implemented";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initializes data structures
|
||||
void initialize(cutlass::gemm::GemmCoord problem_size) {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
tensor_A.resize(problem_size.mk());
|
||||
tensor_B.resize(problem_size.kn());
|
||||
tensor_C.resize({1, problem_size.n()});
|
||||
tensor_D1.resize(problem_size.mn());
|
||||
tensor_D2.resize(problem_size.mn());
|
||||
tensor_Y1.resize(problem_size.mn());
|
||||
tensor_Y2.resize(problem_size.mn());
|
||||
tensor_Y_ref.resize(problem_size.mn());
|
||||
tensor_Y_transpose_ref.resize(problem_size.mn());
|
||||
|
||||
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_D1.host_view(), init_D, seed + 2016));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_D2.host_view(), init_D, seed + 2015));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_Y1.host_view(), cutlass::Distribution::AllZeros, 0));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_Y2.host_view(), cutlass::Distribution::AllZeros, 0));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_Y_ref.host_view(), cutlass::Distribution::AllZeros, 0));
|
||||
EXPECT_TRUE(initialize_tensor(tensor_Y_transpose_ref.host_view(), cutlass::Distribution::AllZeros, 0));
|
||||
|
||||
// It is possible to randomly initialize to all zeros, so override this with non-zeros
|
||||
// in the upper left corner of each operand.
|
||||
tensor_A.host_view().at({0, 0}) = GemmElement(1);
|
||||
tensor_B.host_view().at({0, 0}) = GemmElement(1);
|
||||
tensor_C.host_view().at({0, 0}) = GemmElement(1);
|
||||
tensor_D1.host_view().at({0, 0}) = GemmElement(1);
|
||||
tensor_D2.host_view().at({0, 0}) = GemmElement(1);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D1.sync_device();
|
||||
tensor_D2.sync_device();
|
||||
}
|
||||
|
||||
/// Compares computed reference with device reference and outputs to a file if incorrect
|
||||
bool compare_reference(
|
||||
cutlass::gemm::GemmCoord problem_size, cutlass::HostTensor<GemmElement, LayoutC>& tensor_Y_ref, cutlass::HostTensor<GemmElement, LayoutC>& tensor_Y) {
|
||||
|
||||
tensor_Y_ref.sync_host();
|
||||
tensor_Y.sync_host();
|
||||
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y_ref.host_view()), 0);
|
||||
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y.host_view()), 0);
|
||||
|
||||
bool passed = true;
|
||||
float norm_diff = 0;
|
||||
|
||||
norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Y_ref.host_view(), tensor_Y.host_view(), float());
|
||||
passed = (norm_diff <= 0.1f);
|
||||
EXPECT_LT(norm_diff, 0.1f) << " tensor_Y is incorrect";
|
||||
|
||||
std::ofstream file("errors_testbed_gemm_broadcast_new.txt");
|
||||
|
||||
|
||||
file
|
||||
<< "problem: " << problem_size << "\n\n";
|
||||
|
||||
file
|
||||
<< "capacity: \n"
|
||||
<< "A: " << tensor_A.capacity()
|
||||
<< "\nB: " << tensor_B.capacity()
|
||||
<< "\nC: " << tensor_C.capacity()
|
||||
<< "\nD1: " << tensor_D1.capacity()
|
||||
<< "\nD2: " << tensor_D2.capacity()
|
||||
<< "\nY: " << tensor_Y.capacity()
|
||||
<< "\n\n"
|
||||
<< "\nY_ref: " << tensor_Y_ref.capacity()
|
||||
<< "\n\n";
|
||||
file
|
||||
<< "A =\n" << tensor_A.host_view()
|
||||
<< "\n\nB =\n" << tensor_B.host_view()
|
||||
<< "\n\nC =\n" << tensor_C.host_view()
|
||||
<< "\n\nD1 =\n" << tensor_D1.host_view()
|
||||
<< "\n\nD2 =\n" << tensor_D2.host_view()
|
||||
<< "\n\nY =\n" << tensor_Y.host_view()
|
||||
<< "\n\nY_ref =\n" << tensor_Y_ref.host_view();
|
||||
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
// End-to-end check of GemmUniversalWithBroadcast against plain GemmUniversal:
//   1. Y_ref           = A*B + broadcast(C)          (reference GEMM)
//   2. Y_transpose_ref = (B^T*A^T)^T + broadcast(C)  (same result via the transposed problem)
//   3. Y1 = (A*B + broadcast(C)) .* D1               (residual block, multiplies)
//   4. Y2 = ((A*B + broadcast(C)) .* D1) .+ D2       (residual block, multiplies then plus)
// The host reference is updated with TensorMul/TensorAdd between steps.
//
// NOTE(review): the test is guarded by CUTLASS_ARCH_MMA_SM80_SUPPORTED but all
// kernels below are tagged arch::Sm75 with a 16x8x8 instruction shape, while
// the test name advertises SM80 / 16x8x16 — confirm which target is intended.
TEST(SM80_Device_GemmWithBroadcast_CUSTOM_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) {
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  const int M = 1024;
  const int K = 10240;
  const int N = 512;
  cutlass::gemm::GemmCoord problem_size{M, N, K};
  TestbedUtils<cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor> utils;
  utils.initialize(problem_size);

  {
    // Create reference Gemm: Y_ref = alpha * A*B + beta * broadcast(C).
    using Gemm = cutlass::gemm::device::GemmUniversal<
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      2,
      1,
      1>;

    //
    // Initialize the GEMM operator
    //

    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      1 /* batch_count */,
      {cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
      utils.tensor_A.device_data(),
      utils.tensor_B.device_data(),
      utils.tensor_C.device_data(),
      utils.tensor_Y_ref.device_data(),
      0,   // batch stride A
      0,   // batch stride B
      0,   // batch stride C
      0,   // batch stride D
      K,   // lda
      K,   // ldb
      0,   // ldc == 0: broadcast the single row of C over all M rows
      N,   // ldd
    };

    Gemm gemm_op;
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);

    //
    // Run the GEMM
    //

    status = gemm_op();
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  }

  {
    // Compute the same product via the transposed problem (B^T * A^T with a
    // column-major output) — the result memory should match Y_ref exactly.
    cutlass::gemm::GemmCoord problem_size{N, M, K};
    // Create another reference Gemm.
    using Gemm = cutlass::gemm::device::GemmUniversal<
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::ColumnMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      2,
      1,
      1>;

    //
    // Initialize the GEMM operator
    //

    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      1 /* batch_count */,
      {cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
      utils.tensor_B.device_data(),
      utils.tensor_A.device_data(),
      utils.tensor_C.device_data(),
      utils.tensor_Y_transpose_ref.device_data(),
      0,
      0,
      0,
      0,
      K,
      K,
      0,
      N
    };

    Gemm gemm_op;
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    // Fix: report failures with cutlassGetStatusString for consistency with
    // the other scopes (original used `int(status)` here only).
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);

    //
    // Run the GEMM
    //
    status = gemm_op();
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  }

  utils.compare_reference(problem_size, utils.tensor_Y_ref, utils.tensor_Y_transpose_ref);

  {
    // Create GemmWithBroadcast: Y1 = (A*B + broadcast(C)) .* D1.
    using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombinationResidualBlock<
        ElementOutput, ElementAccumulator, ElementAccumulator,
        ElementAccumulator, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity>,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      2,
      1,
      1>;

    //
    // Initialize the GEMM operator
    //
    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      1 /* batch_count */,
      {cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
      utils.tensor_A.device_data(),
      utils.tensor_B.device_data(),
      utils.tensor_D1.device_data(),   // residual operand
      utils.tensor_Y1.device_data(),   // output
      utils.tensor_C.device_data(),    // broadcast vector
      nullptr,                         // no reduction output
      0,
      0,
      0,
      0,
      0,
      0,
      K,   // lda
      K,   // ldb
      N,   // ldc (residual D1)
      N,   // ldd
      0,
      0
    };

    Gemm gemm_op;
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);

    //
    // Run the GEMM
    //

    status = gemm_op();
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  }

  // Update the host reference in place: Y_ref .*= D1, then compare.
  utils.tensor_Y_ref.sync_host();
  cutlass::reference::host::TensorMul(utils.tensor_Y_ref.host_view(), utils.tensor_D1.host_view());
  utils.tensor_Y_ref.sync_device();
  utils.compare_reference(problem_size, utils.tensor_Y_ref, utils.tensor_Y1);

  {
    // Create GemmWithBroadcast with a second residual operand:
    // Y2 = ((A*B + broadcast(C)) .* D1) .+ D2.
    using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast<
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombinationResidualBlock<
        ElementOutput, ElementAccumulator, ElementAccumulator,
        ElementAccumulator, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity, cutlass::plus>,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      2,
      1,
      1>;

    //
    // Initialize the GEMM operator
    //
    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      1 /* batch_count */,
      {cutlass::half_t(1) /* alpha */, cutlass::half_t(1) /* beta */},
      utils.tensor_A.device_data(),
      utils.tensor_B.device_data(),
      utils.tensor_D1.device_data(),   // first residual operand
      utils.tensor_D2.device_data(),   // second residual operand
      utils.tensor_Y2.device_data(),   // output
      utils.tensor_C.device_data(),    // broadcast vector
      nullptr,                         // no reduction output
      0,
      0,
      0,
      0,
      0,
      0,
      0,
      K,   // lda
      K,   // ldb
      N,   // ldc (residual D1)
      N,   // ld residual D2
      N,   // ldd
      0,
      0
    };

    Gemm gemm_op;
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);

    //
    // Run the GEMM
    //

    status = gemm_op();
    EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status);
  }

  // Update the host reference again: Y_ref .+= D2, then compare.
  utils.tensor_Y_ref.sync_host();
  cutlass::reference::host::TensorAdd(utils.tensor_Y_ref.host_view(), utils.tensor_D2.host_view());
  utils.tensor_Y_ref.sync_device();
  utils.compare_reference(problem_size, utils.tensor_Y_ref, utils.tensor_Y2);
}
|
||||
|
||||
#endif
|
||||
@ -53,7 +53,7 @@ namespace detail {
|
||||
|
||||
/// Helper to apply a binary operator in place
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
@ -68,8 +68,8 @@ struct TensorFuncBinaryOp {
|
||||
|
||||
/// View of left-hand-side tensor
|
||||
TensorView<ElementD, LayoutD> view_d;
|
||||
TensorRef<ElementA, LayoutA> ref_a;
|
||||
TensorRef<ElementB, LayoutB> ref_b;
|
||||
TensorRef<ElementA, LayoutA> view_a;
|
||||
TensorRef<ElementB, LayoutB> view_b;
|
||||
BinaryFunc func;
|
||||
|
||||
//
|
||||
@ -82,8 +82,8 @@ struct TensorFuncBinaryOp {
|
||||
/// Constructor
|
||||
TensorFuncBinaryOp(
|
||||
TensorView<ElementD, LayoutD> const & view_d_,
|
||||
TensorRef<ElementA, LayoutA> const & ref_a_,
|
||||
TensorRef<ElementB, LayoutB> const & ref_b_,
|
||||
TensorRef<ElementA, LayoutA> const & view_a_,
|
||||
TensorRef<ElementB, LayoutB> const & view_b_,
|
||||
BinaryFunc func = BinaryFunc()
|
||||
):
|
||||
view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { }
|
||||
@ -118,7 +118,7 @@ void TensorAdd(
|
||||
) {
|
||||
|
||||
detail::TensorFuncBinaryOp<
|
||||
ElementD,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@ -129,7 +129,7 @@ void TensorAdd(
|
||||
|
||||
TensorForEach(
|
||||
d.extent(),
|
||||
func);
|
||||
func);
|
||||
}
|
||||
|
||||
/// Adds a tensor in place: d = d .+ a
|
||||
@ -164,7 +164,7 @@ void TensorSub(
|
||||
) {
|
||||
|
||||
detail::TensorFuncBinaryOp<
|
||||
ElementD,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@ -191,7 +191,7 @@ void TensorSub(
|
||||
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
||||
TensorRef<ElementA, LayoutA> a ///< A tensor reference
|
||||
) {
|
||||
|
||||
|
||||
TensorSub(d, d, a);
|
||||
}
|
||||
|
||||
@ -211,9 +211,9 @@ void TensorMul(
|
||||
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
||||
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
||||
) {
|
||||
|
||||
|
||||
detail::TensorFuncBinaryOp<
|
||||
ElementD,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@ -257,9 +257,9 @@ void TensorDiv(
|
||||
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
||||
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
||||
) {
|
||||
|
||||
|
||||
detail::TensorFuncBinaryOp<
|
||||
ElementD,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@ -284,7 +284,7 @@ void TensorDiv(
|
||||
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
||||
TensorRef<ElementA, LayoutA> a ///< A tensor reference
|
||||
) {
|
||||
TensorMul(d, d, a);
|
||||
TensorDiv(d, d, a);
|
||||
}
|
||||
|
||||
|
||||
@ -304,15 +304,15 @@ void TensorModulus(
|
||||
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
||||
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
||||
) {
|
||||
|
||||
|
||||
detail::TensorFuncBinaryOp<
|
||||
ElementD,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
cutlass::modulus<ElementD>
|
||||
cutlass::divides<ElementD>
|
||||
> func(d, a, b);
|
||||
|
||||
TensorForEach(
|
||||
@ -331,7 +331,7 @@ void TensorModulus(
|
||||
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
||||
TensorRef<ElementA, LayoutA> a ///< A tensor reference
|
||||
) {
|
||||
TensorMul(d, d, a);
|
||||
TensorModulus(d, d, a);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user