Replaced GoogleTest copy with submodule. Added updates to support intra-threadblock reductions. Added tests for same.
This commit is contained in:
@ -33,7 +33,7 @@
|
||||
|
||||
#define CUTLASS_MAJOR 1
|
||||
#define CUTLASS_MINOR 0
|
||||
#define CUTLASS_PATCH 0
|
||||
#define CUTLASS_PATCH 1
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
#ifdef __NVCC__
|
||||
|
||||
@ -184,7 +184,7 @@ struct FragmentIterator {
|
||||
/// The shape of the the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape>::Shape Strides;
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
@ -242,7 +242,7 @@ struct FragmentConstIterator {
|
||||
/// The shape of the the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape>::Shape IterationsStrides;
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
|
||||
@ -49,21 +49,29 @@ struct FragmentMultiplyAdd {
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const& b, Fragment_& d) {
|
||||
for (int j = 0; j < Fragment_::kElements; ++j) {
|
||||
d[j] = a * b[j];
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename Fragment_>
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(Scalar_ a,
|
||||
Fragment_ const& b,
|
||||
Fragment_ const& c,
|
||||
Fragment_& d) {
|
||||
for (int j = 0; j < Fragment_::kElements; ++j) {
|
||||
d[j] = a * b[j] + c[j];
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0] + c[j];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -74,7 +82,7 @@ struct FragmentMultiplyAdd {
|
||||
template <>
|
||||
struct FragmentMultiplyAdd<half> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The type for B.
|
||||
@ -86,38 +94,48 @@ struct FragmentMultiplyAdd<half> {
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void multiply(half a, Fragment_ const& b, Fragment_& d) {
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The input.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
|
||||
for (int i = 0; i < Fragment_::kElements / 2; ++i) {
|
||||
d_half2[i] = __hmul2(a_half2, b_half2[i]);
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void multiply_add(half a, Fragment_ const& b, Fragment_ const& c, Fragment_& d) {
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(half a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The inputs.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
|
||||
for (int i = 0; i < Fragment_::kElements / 2; ++i) {
|
||||
d_half2[i] = __hfma2(a_half2, b_half2[i], c_half2[i]);
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -39,6 +39,8 @@ struct ClearAccumulators {
|
||||
/// The shared storage.
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators() {}
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ namespace gemm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Gemm_>
|
||||
__global__ void gemm_kernel(typename Gemm_::Params params) {
|
||||
__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) {
|
||||
// Declare shared memory.
|
||||
__shared__ typename Gemm_::SharedStorage shared_storage;
|
||||
|
||||
@ -193,6 +193,71 @@ struct Gemm {
|
||||
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
|
||||
: params(params_), shared_storage(shared_storage_) {}
|
||||
|
||||
/// Consume a single iteration of the loop.
|
||||
template <bool kIsLastIteration>
|
||||
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream,
|
||||
typename Traits::SharedLoadStream& shared_load_stream,
|
||||
typename Traits::MultiplyAdd::Accumulators& accumulators,
|
||||
Index outer_k) {
|
||||
// If that's the last "load iteration" update the predicates.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.move_to_residue<false>(outer_k);
|
||||
}
|
||||
|
||||
// Load data for the next iteration of the main loop.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.copy();
|
||||
}
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
|
||||
// Trigger the copy from shared memory for the next A/B values.
|
||||
shared_load_stream.copy(step + 1);
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(step);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
|
||||
shared_load_stream.fragment_b(step),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
// Make sure the data from shared memory has been entirely consumed.
|
||||
Traits::shared_load_fence(true);
|
||||
|
||||
// Commit the data in shared memory for A/B.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.commit();
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(true);
|
||||
|
||||
// Trigger the loads for the next iteration (if needed).
|
||||
if (!kIsLastIteration) {
|
||||
// Move to the next stage for the load (if it makes sense).
|
||||
shared_load_stream.inc_stage();
|
||||
// Trigger the copy from shared memory for the next loop iteration.
|
||||
shared_load_stream.copy(0);
|
||||
}
|
||||
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(kUnrollingSteps - 1);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
|
||||
shared_load_stream.fragment_b(kUnrollingSteps - 1),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
/// Do the GEMM.
|
||||
CUTLASS_DEVICE void multiply_add() {
|
||||
// Swizzle the IDs of the block (to enable better cache behavior).
|
||||
@ -212,16 +277,11 @@ struct Gemm {
|
||||
// Create the accumulator clear.
|
||||
ClearAccumulators clear(shared_storage.main_loop.clear);
|
||||
|
||||
/// Define the mainloop iteration size
|
||||
typedef typename Traits::MultiplyAdd MultiplyAdd;
|
||||
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(MultiplyAdd::AccumulatorsPerWarp::kD);
|
||||
Index const kUnroll = static_cast<Index>(Traits::OutputTile::kD);
|
||||
|
||||
// If we do not have enough steps in the main loop, trigger the residue code.
|
||||
if (params.k < kUnroll) {
|
||||
global_stream.residue(params.k, true);
|
||||
}
|
||||
global_stream.move_to_residue<true>(params.k);
|
||||
|
||||
// Fetch the fragments for A and B from global memory.
|
||||
global_stream.copy();
|
||||
@ -232,9 +292,12 @@ struct Gemm {
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(false);
|
||||
|
||||
// Rollback to the beginning of the GEMM-K dimension. It may have no impact.
|
||||
global_stream.rollback();
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
MultiplyAdd::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
// Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
|
||||
static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");
|
||||
@ -246,59 +309,21 @@ struct Gemm {
|
||||
shared_load_stream.copy(0);
|
||||
|
||||
// Allocate the accumulators.
|
||||
typename MultiplyAdd::Accumulators accumulators;
|
||||
typename Traits::MultiplyAdd::Accumulators accumulators;
|
||||
// Clear the accumulators.
|
||||
clear.clear(accumulators);
|
||||
|
||||
// The loop index.
|
||||
Index outer_k = params.k - kUnroll;
|
||||
|
||||
// Enter the main loop and iterate.
|
||||
typedef typename Traits::Index Index;
|
||||
for (Index outer_k = params.k - kUnroll; outer_k > -kUnroll; outer_k -= kUnroll) {
|
||||
// If that's the last "load iteration" update the predicates.
|
||||
int const is_residue = outer_k <= kUnroll;
|
||||
if (is_residue) {
|
||||
global_stream.residue(outer_k);
|
||||
}
|
||||
for (; outer_k > 0; outer_k -= kUnroll) {
|
||||
consume_tile<false>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
// Load data for the next iteration of the main loop.
|
||||
global_stream.copy();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
|
||||
// Trigger the copy from shared memory for the next A/B values.
|
||||
shared_load_stream.copy(step + 1);
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(step);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
|
||||
shared_load_stream.fragment_b(step),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
// Make sure the data from shared memory has been entirely consumed.
|
||||
Traits::shared_load_fence(true);
|
||||
|
||||
// Commit the data in shared memory for A/B.
|
||||
global_stream.commit();
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(true);
|
||||
|
||||
// Move to the next stage for the load (if it makes sense).
|
||||
shared_load_stream.inc_stage();
|
||||
// Trigger the copy from shared memory for the next loop iteration.
|
||||
shared_load_stream.copy(0);
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(kUnrollingSteps - 1);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
|
||||
shared_load_stream.fragment_b(kUnrollingSteps - 1),
|
||||
accumulators,
|
||||
accumulators);
|
||||
// Residual loop.
|
||||
for (; outer_k > -kUnroll; outer_k -= kUnroll) {
|
||||
consume_tile<true>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
// Epilogue.
|
||||
|
||||
@ -117,6 +117,7 @@ struct GemmEpilogue {
|
||||
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
|
||||
Accumulators& accumulators) {
|
||||
|
||||
// The problem size.
|
||||
Coord<3> const bounds = cutlass::make_Coord(0, n, m);
|
||||
|
||||
// The functor.
|
||||
@ -153,6 +154,18 @@ struct GemmEpilogue {
|
||||
GlobalStoreIteratorD global_store_iterator(
|
||||
params.iterator_d, bounds, block, pointer_offset, predicate_offset);
|
||||
|
||||
// The transformer to transform before storing to shared memory.
|
||||
SharedStoreTransformerD shared_store_transformer;
|
||||
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
|
||||
|
||||
// The iterator to store to shared memory.
|
||||
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
|
||||
shared_storage.shared_stream.store);
|
||||
|
||||
// The iterator to load from shared memory. TODO: Use a stream.
|
||||
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
|
||||
shared_storage.shared_stream.load);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
// Load the C matrix into fragment.
|
||||
@ -166,20 +179,13 @@ struct GemmEpilogue {
|
||||
// Copy the accumulators to shared memory.
|
||||
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
|
||||
|
||||
SharedStoreTransformerD shared_store_transformer;
|
||||
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
|
||||
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
|
||||
|
||||
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
|
||||
shared_storage.shared_stream.store);
|
||||
shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
shared_store_fence();
|
||||
|
||||
// Copy the accumulators back to registers from shared memory.
|
||||
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
|
||||
shared_storage.shared_stream.load);
|
||||
typename SharedLoadIteratorD::Fragment fetched_d;
|
||||
shared_iterator_load(shared_load_iterator, fetched_d);
|
||||
|
||||
|
||||
@ -84,8 +84,9 @@ struct GlobalLoadStreamBase {
|
||||
typename StoreIterator::Params store_iterator;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld) {
|
||||
int error_code = load_iterator.initialize(pointer, ld);
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) {
|
||||
int error_code = load_iterator.initialize(desc, pointer, ld);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
@ -128,6 +129,9 @@ struct GlobalLoadStreamBase {
|
||||
store_iterator.inc_stage();
|
||||
}
|
||||
|
||||
/// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
load_iterator.residue(k);
|
||||
@ -136,6 +140,9 @@ struct GlobalLoadStreamBase {
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to the beginning of the GEMM-k dimension.
|
||||
CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
|
||||
|
||||
/// The iterator.
|
||||
LoadIterator load_iterator;
|
||||
/// The fragment to fetch from shared memory.
|
||||
|
||||
@ -195,7 +195,8 @@ struct GemmGlobalIteratorAb
|
||||
|
||||
struct Params : public BaseParams {
|
||||
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
|
||||
CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr, Index stride_h) {
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) {
|
||||
Index inc_d = 0;
|
||||
Index inc_advance = 0;
|
||||
// Move by some columns for each iteration in the H dimension.
|
||||
@ -220,16 +221,75 @@ struct GemmGlobalIteratorAb
|
||||
(Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
Base::Params::initialize(ptr, 0, stride_h, 0, inc_d, inc_h, 0, inc_advance);
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// Move to the residue.
|
||||
Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
// The jump in the gemm-k dimension.
|
||||
Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1;
|
||||
|
||||
// Compute the offset to the residue and how to "come" back.
|
||||
Index const kResidue = desc.k % kBlock;
|
||||
if (kResidue > 0) {
|
||||
move_to_residue_offset = (desc.k - kResidue) * stride;
|
||||
} else {
|
||||
move_to_residue_offset = (desc.k - kBlock) * stride;
|
||||
}
|
||||
|
||||
Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The extra offset to control moving to the residue.
|
||||
Index move_to_residue_offset;
|
||||
};
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// The column.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
|
||||
// Add the blocks indices.
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_h += block[1];
|
||||
block_w += block[2];
|
||||
|
||||
} else {
|
||||
block_h += block[2];
|
||||
block_w += block[1];
|
||||
}
|
||||
|
||||
// Setup the pointer.
|
||||
params.pointer += (block_h * params.stride_h + block_w);
|
||||
|
||||
// Initialize predicates
|
||||
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Initialize the predicates.
|
||||
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
|
||||
// Setup the masks to control loads.
|
||||
predicates.fill(0);
|
||||
@ -263,46 +323,29 @@ struct GemmGlobalIteratorAb
|
||||
}
|
||||
}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// The column.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
// Store the pointer and the predicates.
|
||||
stored_pointer = params.pointer;
|
||||
stored_predicates = predicates;
|
||||
|
||||
// Add the blocks indices.
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_h += block[1];
|
||||
block_w += block[2];
|
||||
// Move the pointer to the residue.
|
||||
params.pointer += params.move_to_residue_offset;
|
||||
|
||||
} else {
|
||||
block_h += block[2];
|
||||
block_w += block[1];
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// The unrolling factor.
|
||||
int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
|
||||
// Clear the predicates for the residue. TODO: We can do something smarter.
|
||||
int const kResidue = (int)(k % (Index)kUnroll);
|
||||
if (kResidue > 0) {
|
||||
residue(kResidue);
|
||||
}
|
||||
|
||||
// Setup the pointer.
|
||||
params.pointer += (block_h * params.stride_h + block_w);
|
||||
|
||||
// Initialize predicates
|
||||
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
|
||||
}
|
||||
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const* data() const { return params.pointer; }
|
||||
|
||||
/// That's the residue! Update the predicates.
|
||||
CUTLASS_DEVICE void residue(Index k) {
|
||||
// The coordinates of the thread.
|
||||
@ -332,14 +375,26 @@ struct GemmGlobalIteratorAb
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
params.pointer = stored_pointer;
|
||||
predicates = stored_predicates;
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
return predicates[bit];
|
||||
}
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// The pointer.
|
||||
typename Base::Scalar const* stored_pointer;
|
||||
/// The predicates.
|
||||
PredicateVector predicates;
|
||||
PredicateVector predicates, stored_predicates;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -439,6 +494,13 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
@ -456,18 +518,19 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
this->params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Test the validity of the iterator.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// Returns the raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pointer data() { return params.pointer; }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pointer const data() const { return params.pointer; }
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
};
|
||||
|
||||
@ -104,8 +104,7 @@ struct GemmSharedStoreWithSkewTileAbTraits {
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides;
|
||||
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
@ -164,22 +163,25 @@ struct GemmSharedLoadTileATraits {
|
||||
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
ImmediateOffsetStrides;
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Extract the warp.
|
||||
int const warp = threadIdx.x / kWarpSize % Warps::kW;
|
||||
// Compute the row offset for each thread
|
||||
int const lane = (threadIdx.x & 0x0e) / 2;
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// Extract the slice.
|
||||
int const slice = warp / (Warps::kH * Warps::kW);
|
||||
// Compute the row offset for each warp.
|
||||
int const warp_row = warp % Warps::kW;
|
||||
// Compute the row offset for each thread.
|
||||
int const lane_row = (threadIdx.x & 0x0e) / 2;
|
||||
// The offset.
|
||||
int const offset = (warp * ThreadsPerWarp::kW + lane) * kAccessSize;
|
||||
|
||||
int const offset =
|
||||
slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
|
||||
// Embed the offset in a 4D coordinate vector.
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
@ -231,23 +233,27 @@ struct GemmSharedLoadTileBTraits {
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
ImmediateOffsetStrides;
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The position of the warp.
|
||||
int const warp = threadIdx.x / (Warps::kW * kWarpSize);
|
||||
|
||||
// Compute the column offset for each thread
|
||||
int const lane = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Extract the warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// Extract the slice.
|
||||
int const slice = warp / (Warps::kH * Warps::kW);
|
||||
// The warp in the slice.
|
||||
int const warp_in_slice = warp % (Warps::kH * Warps::kW);
|
||||
// Compute the row offset for each warp.
|
||||
int const warp_col = warp_in_slice / Warps::kW;
|
||||
// Compute the row offset for each thread.
|
||||
int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
|
||||
// The offset.
|
||||
int const offset = (warp * ThreadsPerWarp::kH + lane) * kAccessSize;
|
||||
|
||||
int const offset =
|
||||
slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
|
||||
// Embed the offset in a 4D coordinate.
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
@ -297,28 +303,26 @@ struct GemmSharedStoreTileDTraits {
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// We issue STS.128 in the epilogue to store the accumulators to shared memory. When we use
|
||||
// STS.128, we have to guarantee that threads in groups of 8 do not have bank conflicts (i.e
|
||||
// they write to different banks).
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// The warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
|
||||
// The position of the warp in the 2D tile.
|
||||
int const warp_row = warp % Warps::kW;
|
||||
int const warp_col = warp / Warps::kW;
|
||||
|
||||
// We assume that the elements are distributed in a warps as 4 columns of 8 elements. The
|
||||
// columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15],
|
||||
// col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
|
||||
int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
|
||||
int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
|
||||
|
||||
// Odd threads go to the second half of shared memory.
|
||||
int const row = threadIdx.x & 0x01;
|
||||
|
||||
int const warp_id = (threadIdx.x >> 5);
|
||||
|
||||
int const warp_row = (warp_id % Warps::kW);
|
||||
int const warp_col = (warp_id / Warps::kW);
|
||||
|
||||
int hi_halfwarp_offset = OutputTile::kW * ((threadIdx.x >> 4) & 1);
|
||||
int lo_halfwarp_offset = (((threadIdx.x >> 1) & 0x7) + warp_row * ThreadsPerWarp::kW);
|
||||
|
||||
int col = kAccessSize * lo_halfwarp_offset +
|
||||
warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW + hi_halfwarp_offset;
|
||||
|
||||
int offset = row * kScalarsPerRow + col;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
|
||||
lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
|
||||
// Embed the offset in a 4D coords.
|
||||
return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
@ -357,32 +361,39 @@ struct GemmSharedLoadTileDTraits {
|
||||
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
|
||||
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
|
||||
|
||||
/// The tile.
|
||||
/// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank
|
||||
/// conflicts in the epilogue.
|
||||
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
|
||||
|
||||
// Compute the number of iterations per warp in the Tile::kH dimension.
|
||||
static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;
|
||||
|
||||
// As shown above, the shared memory tile is composed of 2 rows and each rows is made of
|
||||
// As explained above, the shared memory tile is composed of 2 rows and each rows is made of
|
||||
// kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
|
||||
// back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
|
||||
// However, in some cases, we have only 1 iteration per warp. In that case, we must define the
|
||||
// shape as Shape<1, 1, ...>. The following code does that.
|
||||
// shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
|
||||
// to keep the number of elements to reduce for split-K.
|
||||
static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
|
||||
// As soon as we know kIterationsH, it is trivial to compute kIterationsD:
|
||||
static int const kIterationsD = kIterationsInHPerWarp / kIterationsH;
|
||||
|
||||
// If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
|
||||
// threads to grab the other elements.
|
||||
static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
|
||||
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize> Iterations;
|
||||
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize> Delta;
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK>
|
||||
ImmediateOffsetStrides;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize> ImmediateOffsetStrides;
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Each warp works on a different column.
|
||||
int const h = threadIdx.x / kWarpSize;
|
||||
// Compute the row.
|
||||
|
||||
@ -74,7 +74,9 @@ template <
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_,
|
||||
/// The number of stages in shared memory to do single/double/triple-buffering.
|
||||
int kStages_>
|
||||
int kStages_,
|
||||
/// Do we do the residue in the prologue?
|
||||
bool kResidueInPrologue_ = false>
|
||||
|
||||
struct GemmConfig {
|
||||
//
|
||||
@ -129,6 +131,9 @@ struct GemmConfig {
|
||||
|
||||
/// The number of stages in shared memory to implement double, triple, more-buffering.
|
||||
static int const kStages = kStages_;
|
||||
|
||||
/// Do we do the residue in the prologue?
|
||||
static bool const kResidueInPrologue = kResidueInPrologue_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -229,8 +234,12 @@ struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
|
||||
/// The number of scalars in 4B.
|
||||
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
|
||||
/// The skew for A.
|
||||
static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
@ -242,9 +251,8 @@ struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsA,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B>
|
||||
SharedStoreTileTraits;
|
||||
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^T.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
@ -302,8 +310,12 @@ struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
|
||||
/// The number of scalars in 4B.
|
||||
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
@ -315,9 +327,8 @@ struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsB,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B>
|
||||
SharedStoreTileTraits;
|
||||
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
@ -405,6 +416,60 @@ struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_, bool kResidueInPrologue_ = GemmTraits_::kResidueInPrologue>
|
||||
struct GemmResidue {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have
|
||||
// complete main loops after that. It helps simplify the logic in the main loop.
|
||||
if (kIsPrologue) {
|
||||
stream_a.move_to_residue(k);
|
||||
stream_b.move_to_residue(k);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {
|
||||
stream_a.rollback();
|
||||
stream_b.rollback();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
struct GemmResidue<GemmTraits_, false> {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The index.
|
||||
typedef typename GemmTraits_::Index Index;
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(GemmTraits_::OutputTile::kD);
|
||||
|
||||
// Call the residue code. That's the same path as CUTLASS 1.0.0.
|
||||
if (kIsPrologue && k < kUnroll) {
|
||||
stream_a.residue(k, true);
|
||||
stream_b.residue(k, true);
|
||||
} else if (k <= kUnroll) {
|
||||
stream_a.residue(k, false);
|
||||
stream_b.residue(k, false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM configuration.
|
||||
typename GemmConfig_,
|
||||
@ -426,10 +491,24 @@ template <
|
||||
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Scalar> >
|
||||
|
||||
struct GemmTraits {
|
||||
/// This class.
|
||||
typedef GemmTraits<GemmConfig_,
|
||||
GlobalLoadStreamA_,
|
||||
GlobalLoadStreamB_,
|
||||
SharedLoadStreamA_,
|
||||
SharedLoadStreamB_,
|
||||
Epilogue_,
|
||||
BlockSwizzle_,
|
||||
Index_,
|
||||
ClearAccumulators_>
|
||||
This_;
|
||||
|
||||
/// The configuration.
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig::OutputTile OutputTile;
|
||||
/// Is the residue treated in the prologue?
|
||||
static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue;
|
||||
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
|
||||
@ -450,18 +529,6 @@ struct GemmTraits {
|
||||
/// The iterator for B to load from shared memory.
|
||||
typedef SharedLoadStreamB_ SharedLoadStreamB;
|
||||
|
||||
/// The shared storage for A.
|
||||
typedef typename GlobalLoadStreamA::SharedStoreStorage SharedStoreStorageA;
|
||||
// Btw, make sure we did not messed up with the size of the storage.
|
||||
static_assert(sizeof(SharedStoreStorageA) == sizeof(typename SharedLoadStreamA::SharedStorage),
|
||||
"");
|
||||
|
||||
/// The shared storage for B.
|
||||
typedef typename GlobalLoadStreamB::SharedStoreStorage SharedStoreStorageB;
|
||||
// Btw, make sure we did not messed up with the size of the storage.
|
||||
static_assert(sizeof(SharedStoreStorageB) == sizeof(typename SharedLoadStreamB::SharedStorage),
|
||||
"");
|
||||
|
||||
/// The multiply-add functor.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The epilogue.
|
||||
@ -502,14 +569,15 @@ struct GemmTraits {
|
||||
|
||||
// Initialize the iterator for A.
|
||||
int error_code =
|
||||
global_stream_a.initialize(reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
|
||||
global_stream_a.initialize(desc, reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Initialize the iterator for B.
|
||||
error_code = global_stream_b.initialize(reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
|
||||
error_code =
|
||||
global_stream_b.initialize(desc, reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
@ -574,12 +642,15 @@ struct GemmTraits {
|
||||
stream_b.commit();
|
||||
}
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
stream_a.residue(k, skip_clear);
|
||||
stream_b.residue(k, skip_clear);
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
GemmResidue<This_>::move_to_residue<kIsPrologue>(stream_a, stream_b, k);
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() { GemmResidue<This_>::rollback(stream_a, stream_b); }
|
||||
|
||||
/// The stream for A.
|
||||
GlobalLoadStreamA stream_a;
|
||||
/// The stream for B.
|
||||
|
||||
@ -147,8 +147,11 @@ struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer.
|
||||
half,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
@ -160,8 +163,8 @@ struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
2,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
|
||||
SharedStoreTileTraits;
|
||||
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^T.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
@ -212,8 +215,11 @@ struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer.
|
||||
half,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
@ -225,8 +231,8 @@ struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
2,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
|
||||
SharedStoreTileTraits;
|
||||
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
@ -261,7 +267,7 @@ template <
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
|
||||
@ -47,19 +47,19 @@ template <GemmOperand::Kind kOperand_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmContiguousGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
// Which GEMM operand?
|
||||
kOperand_,
|
||||
// The layout.
|
||||
kLayout_,
|
||||
// The scalar.
|
||||
Scalar_,
|
||||
// The tile.
|
||||
Tile_,
|
||||
// The threads.
|
||||
Threads_,
|
||||
// The number of scalars per LDG/STG.
|
||||
kAccessSize_> {
|
||||
struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
// Which GEMM operand?
|
||||
kOperand_,
|
||||
// The layout.
|
||||
kLayout_,
|
||||
// The scalar.
|
||||
Scalar_,
|
||||
// The tile.
|
||||
Tile_,
|
||||
// The threads.
|
||||
Threads_,
|
||||
// The number of scalars per LDG/STG.
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
|
||||
/// The threads.
|
||||
@ -91,5 +91,71 @@ struct IgemmContiguousGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Deprecated. Please use IgemmGlobalTileTraits instead.
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmContiguousGlobalTileTraits
|
||||
: public IgemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
|
||||
/// The functor to compute the thread offset.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Constructor.
|
||||
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) {
|
||||
// The number of elements read in a single iteration.
|
||||
int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
// The residue.
|
||||
int const kResidue = (int)(bounds[1] % kBlock);
|
||||
|
||||
// Compute the number of elements that are valid.
|
||||
int const left = kResidue - Base::thread_offset[2];
|
||||
if (left > 0 && left < 4) {
|
||||
mask_ = (1u << (8 * left)) - 1u;
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::get(value, d, h, w, c);
|
||||
if (in_residue_) {
|
||||
reinterpret_cast<uint32_t&>(value) &= mask_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(typename Base::Index k) {
|
||||
Base::move_to_residue(k);
|
||||
in_residue_ = true;
|
||||
}
|
||||
|
||||
/// Move back to the beginning of the first tile.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
Base::rollback();
|
||||
in_residue_ = false;
|
||||
}
|
||||
|
||||
/// Are we in the residue?
|
||||
bool in_residue_;
|
||||
/// The mask to clean up the values.
|
||||
uint32_t mask_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -87,7 +87,9 @@ struct IgemmConfig
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -125,17 +127,19 @@ struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
|
||||
/// The number of scalars per LDS for D.
|
||||
4,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
@ -144,7 +148,7 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
static int const kScalarsPerStsA = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^N.
|
||||
typedef IgemmContiguousGlobalTileTraits<
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
@ -155,9 +159,12 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
4>
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
@ -173,13 +180,149 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef int8_t Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef int8_t MultiplyAddScalar;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerStsA = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size NxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
kScalarsPerStsA,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
16>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^N.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
16,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef int8_t Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef int8_t MultiplyAddScalar;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerStsB = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size NxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
kScalarsPerStsB,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
16>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
16,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
@ -188,7 +331,7 @@ struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
static int const kScalarsPerStsB = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef IgemmContiguousGlobalTileTraits<
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
@ -199,9 +342,12 @@ struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
4>
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
@ -266,13 +412,13 @@ struct IgemmTraitsHelper {
|
||||
/// The IGEMM config.
|
||||
typedef IgemmConfig<OutputTile_, ScalarD_, AccumulatorsPerThread_> GemmConfig;
|
||||
/// The GEMM config for A.
|
||||
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
|
||||
typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig, Index_> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
|
||||
|
||||
/// The default transformer for A.
|
||||
typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
|
||||
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
|
||||
@ -287,8 +433,8 @@ struct IgemmTraitsHelper {
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
|
||||
|
||||
// The default transformer for B.
|
||||
typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
|
||||
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -61,17 +60,17 @@ struct LinearScaling {
|
||||
CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_& output) {
|
||||
template <typename FragmentA_, typename FragmentB_>
|
||||
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
|
||||
FragmentMultiplyAdd mad;
|
||||
mad.multiply(alpha, accum, output);
|
||||
}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_ const& old, Fragment_& output) {
|
||||
template <typename FragmentA_, typename FragmentB_>
|
||||
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
|
||||
FragmentMultiplyAdd mad;
|
||||
Fragment_ tmp;
|
||||
FragmentB_ tmp;
|
||||
mad.multiply(beta, old, tmp);
|
||||
mad.multiply_add(alpha, accum, tmp, output);
|
||||
}
|
||||
|
||||
@ -164,6 +164,13 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
@ -181,18 +188,19 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Test the predicate.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// Returns the raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pointer data() { return params.pointer; }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pointer const data() const { return params.pointer; }
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
};
|
||||
|
||||
@ -45,14 +45,12 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
if (iterator.valid(d, h, w, c)) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
0, 0, w, c);
|
||||
Load<typename Fragment::Element, InputIterator::Tile::kC, InputIterator::kMemorySpace>::
|
||||
load(reinterpret_cast<typename InputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
iterator.data(),
|
||||
offset);
|
||||
iterator.get(reinterpret_cast<typename InputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < InputIterator::Iterations::kW - 1) {
|
||||
@ -196,17 +194,12 @@ CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &frag
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
if (iterator.valid(d, h, w, 0)) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename OutputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, 0);
|
||||
|
||||
Store<typename Fragment::Element,
|
||||
OutputIterator::Tile::kC,
|
||||
OutputIterator::kMemorySpace>::
|
||||
store(reinterpret_cast<typename OutputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, 0)),
|
||||
iterator.data(),
|
||||
offset);
|
||||
iterator.set(reinterpret_cast<typename OutputIterator::AccessType const &>(
|
||||
frag_iterator.at(d, h, w, 0)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
0);
|
||||
}
|
||||
if (w < OutputIterator::Iterations::kW - 1) {
|
||||
iterator.inc_w();
|
||||
|
||||
@ -106,6 +106,29 @@ struct Load<double, 2, Memory_, true, 16> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10
|
||||
// WAR bug in NVCC where the upper and lower half of the register end up being the same
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<half, 8, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<half, 8>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
|
||||
int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
|
||||
tmp = reinterpret_cast<int2 const*>(&pointer[offset + 4])[0];
|
||||
dst.registers[2] = tmp.x;
|
||||
dst.registers[3] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
|
||||
@ -150,9 +150,13 @@ struct ShapeMin {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape_>
|
||||
template <typename Shape_, int kElementsPerAccess>
|
||||
struct ShapeStrides {
|
||||
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC, Shape_::kW * Shape_::kC, Shape_::kC, 1> Shape;
|
||||
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
|
||||
Shape_::kW * Shape_::kC,
|
||||
Shape_::kC,
|
||||
kElementsPerAccess>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -73,7 +73,11 @@ struct IteratorFragment {
|
||||
* @brief A template defining \ref tile_traits_concept
|
||||
* @concept{tile_traits_concept}
|
||||
*/
|
||||
template <typename Tile_, typename Delta_, typename Iterations_, typename ThreadOffset_>
|
||||
template <typename Tile_,
|
||||
typename Delta_,
|
||||
typename Iterations_,
|
||||
typename ThreadOffset_,
|
||||
int kAccessSize>
|
||||
struct TileTraits {
|
||||
/// Shape of the tile
|
||||
typedef Tile_ Tile;
|
||||
@ -501,6 +505,13 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const *data() const { return params.pointer; }
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(AccessType &value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar, Base::kAccessSize, kMemorySpace>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment in the D dimension
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
|
||||
@ -829,6 +840,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(AccessType const &value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar, Base::kAccessSize, kMemorySpace>::store(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Stores a fragment and advances to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
|
||||
@ -299,7 +299,7 @@ typedef integral_constant<bool, true> true_type;
|
||||
/// The type used as a compile-time boolean with false value.
|
||||
typedef integral_constant<bool, false> false_type;
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
|
||||
#if (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
|
||||
|
||||
/// std::bool_constant
|
||||
template <bool V>
|
||||
|
||||
@ -47,6 +47,7 @@ set(CUTLASS_UNIT_TEST_SOURCES
|
||||
core/tile_iterator.cu
|
||||
gemm/dgemm.cu
|
||||
gemm/hgemm_128x128x8.cu
|
||||
gemm/hgemm_128x128x16.cu
|
||||
gemm/hgemm_128x32x8.cu
|
||||
gemm/hgemm_128x64x8.cu
|
||||
gemm/igemm_128x128x32.cu
|
||||
@ -54,12 +55,19 @@ set(CUTLASS_UNIT_TEST_SOURCES
|
||||
gemm/igemm_128x32x32.cu
|
||||
gemm/igemm_128x128x32_float.cu
|
||||
gemm/igemm_128x128x32_int8.cu
|
||||
gemm/igemm_32x32x128.cu
|
||||
gemm/sgemm_128x128x8.cu
|
||||
gemm/sgemm_128x128x16.cu
|
||||
gemm/sgemm_128x64x8.cu
|
||||
gemm/sgemm_128x64x16.cu
|
||||
gemm/sgemm_128x32x8.cu
|
||||
gemm/sgemm_128x32x16.cu
|
||||
gemm/sgemm_64x128x8.cu
|
||||
gemm/sgemm_64x128x16.cu
|
||||
gemm/sgemm_64x64x8.cu
|
||||
gemm/sgemm_64x64x16.cu
|
||||
gemm/sgemm_64x32x8.cu
|
||||
gemm/sgemm_64x32x16.cu
|
||||
gemm/wmma_gemm.cu
|
||||
)
|
||||
|
||||
|
||||
@ -26,9 +26,26 @@
|
||||
\brief CUTLASS Unit Tests
|
||||
*/
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
void set_gtest_flag() {
|
||||
// Default flags can be overwritten by --gtest_filter from commandline
|
||||
cudaDeviceProp deviceProperties;
|
||||
cudaGetDeviceProperties(&deviceProperties, 0);
|
||||
|
||||
int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
|
||||
|
||||
if (deviceMajorMinor < 53)
|
||||
::testing::GTEST_FLAG(filter) = "-*Igemm*:*Hgemm*:*mma*";
|
||||
else if (deviceMajorMinor < 61)
|
||||
::testing::GTEST_FLAG(filter) = "-*Igemm*:*mma*";
|
||||
else if (deviceMajorMinor < 70)
|
||||
::testing::GTEST_FLAG(filter) = "-*mma*";
|
||||
}
|
||||
|
||||
int main(int argc, char* arg[]) {
|
||||
set_gtest_flag();
|
||||
::testing::InitGoogleTest(&argc, arg);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
||||
@ -104,6 +104,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_nt) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//Sliced-K configuration
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_64x32x16_nt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_256x128x64_nt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_64x64x16_nt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_256x128x64_nt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_128x32x8_nt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_256x64x64_nt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// DGEMM Column-Column
|
||||
//
|
||||
@ -182,6 +240,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_nn) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Sliced-K configuration
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_64x32x16_nn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_256x128x64_nn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_64x64x16_nn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_256x128x64_nn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_128x32x16_nn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_256x64x64_nn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// DGEMM Row-Column
|
||||
//
|
||||
@ -260,6 +376,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_tn) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Sliced-K configuration
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_64x32x16_tn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_256x128x64_tn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_64x64x16_tn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_256x128x64_tn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_128x32x8_tn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_256x64x64_tn) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// DGEMM Row-Row
|
||||
//
|
||||
@ -338,3 +512,62 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_tt) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Sliced-K configuration
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_64x32x16_tt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x32x16, dgemm_256x128x64_tt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_64x64x16_tt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_64x64x16, dgemm_256x128x64_tt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 64, 64> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_128x32x8_tt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
TEST(Dgemm_128x32x16, dgemm_256x64x64_tt) {
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 32, 128> > GemmTraits;
|
||||
run_gemm<GemmTraits>(256, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@ -24,12 +24,18 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
static void run_gemm(
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
|
||||
@ -51,6 +57,9 @@ static void run_gemm(
|
||||
testbed(m,
|
||||
n,
|
||||
k,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
cutlass::convert(GemmTraits_::kLayoutA),
|
||||
cutlass::convert(GemmTraits_::kLayoutB),
|
||||
alpha,
|
||||
@ -88,3 +97,22 @@ static void run_gemm(
|
||||
ASSERT_TRUE(testbed.verify_with_host());
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
static void run_gemm(
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0)) {
|
||||
int lda = GemmTraits_::kLayoutA == cutlass::MatrixLayout::kColumnMajor ? m : k;
|
||||
int ldb = GemmTraits_::kLayoutB == cutlass::MatrixLayout::kColumnMajor ? k : n;
|
||||
|
||||
run_gemm<GemmTraits_>(m, n, k, lda, ldb, m, alpha, beta);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -44,6 +44,8 @@
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <cutlass::GemmOperand::Kind kOperand_,
|
||||
cutlass::MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
@ -53,6 +55,8 @@ struct WmmaMatrix;
|
||||
|
||||
namespace test {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct GemmTestbedTraits : public cutlass::TypeTraits<T> {};
|
||||
|
||||
@ -68,6 +72,8 @@ struct GemmTestbedTraits<cutlass::WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaS
|
||||
static inline double to_print(double x) { return x; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AType, typename BType, typename CType, typename Accumulator, typename Scalar>
|
||||
struct GemmTestbed {
|
||||
//
|
||||
@ -219,11 +225,11 @@ struct GemmTestbed {
|
||||
|
||||
typedef cutlass::Coord<cutlass::HostTensor<T>::Rank> Coord_t;
|
||||
|
||||
size_t matrix_stride = layout == CUBLAS_OP_N ? columns * ldm : rows * ldm;
|
||||
// TODO: Remove that (int) cast.
|
||||
Coord_t stride = cutlass::make_Coord(
|
||||
rows * columns, layout == CUBLAS_OP_N ? 1 : ldm, layout == CUBLAS_OP_N ? ldm : 1, 1);
|
||||
|
||||
(int)matrix_stride, layout == CUBLAS_OP_N ? 1 : ldm, layout == CUBLAS_OP_N ? ldm : 1, 1);
|
||||
Coord_t size = cutlass::make_Coord(1, rows, columns, 1);
|
||||
|
||||
tensor.reset(stride, size);
|
||||
}
|
||||
|
||||
@ -231,11 +237,13 @@ struct GemmTestbed {
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a workspace for verifying GEMM, assumes
|
||||
/// dense packing.
|
||||
/// Constructs a workspace for verifying GEMM.
|
||||
GemmTestbed(int M_,
|
||||
int N_,
|
||||
int K_,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
cublasOperation_t layout_a,
|
||||
cublasOperation_t layout_b,
|
||||
Scalar alpha_ = Scalar(1),
|
||||
@ -248,33 +256,6 @@ struct GemmTestbed {
|
||||
throw cutlass::cuda_exception("Failed to create CUBLAS handle");
|
||||
}
|
||||
|
||||
resize(A, M_, K_, layout_a);
|
||||
resize(B, K_, N_, layout_b);
|
||||
resize(C_initial, M_, N_, layout_c);
|
||||
resize(ref_host, M_, N_, layout_c);
|
||||
resize(ref_cublas, M_, N_, layout_c);
|
||||
resize(computed, M_, N_, layout_c);
|
||||
}
|
||||
|
||||
/// Constructs a workspace for verifying GEMM with arbitrary strides
|
||||
GemmTestbed(int M_,
|
||||
int N_,
|
||||
int K_,
|
||||
int ldc,
|
||||
cublasOperation_t layout_a,
|
||||
int lda,
|
||||
cublasOperation_t layout_b,
|
||||
int ldb,
|
||||
Scalar alpha_ = Scalar(1),
|
||||
Scalar beta_ = Scalar(0),
|
||||
cublasGemmAlgo_t algorithm_ = CUBLAS_GEMM_DEFAULT,
|
||||
cublasOperation_t layout_c = CUBLAS_OP_N)
|
||||
: alpha(alpha_), beta(beta_), algorithm(algorithm_) {
|
||||
status = cublasCreate(&handle);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
throw cutlass::cuda_exception("Failed to create CUBLAS handle");
|
||||
}
|
||||
|
||||
resize(A, M_, K_, layout_a, lda);
|
||||
resize(B, K_, N_, layout_b, ldb);
|
||||
resize(C_initial, M_, N_, layout_c, ldc);
|
||||
@ -515,6 +496,8 @@ struct GemmTestbed {
|
||||
|
||||
} // namespace test
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) {
|
||||
switch (layout) {
|
||||
@ -527,4 +510,6 @@ inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) {
|
||||
}
|
||||
return CUBLAS_OP_N;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
}
|
||||
|
||||
347
tools/test/unit/gemm/hgemm_128x128x16.cu
Normal file
347
tools/test/unit/gemm/hgemm_128x128x16.cu
Normal file
@ -0,0 +1,347 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <tools/util/half.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/hgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_2x2x2_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(2, 2, 2);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x8_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 8);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x17_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x64_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x128x16_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x256x16_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x256x16_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_nn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x18_nn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 18);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x64_nn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x128x16_nn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x256x16_nn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x256x16_nn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_tn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x18_tn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 18);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x64_tn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x128x16_tn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x256x16_tn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x256x16_tn) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_tt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x18_tt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 18);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x64_tt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x128x16_tt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x256x16_tt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_256x256x16_tt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(256, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_alpha2_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16, cutlass::half_t(2), cutlass::half_t(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_beta1_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16, cutlass::half_t(1), cutlass::half_t(1));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_alpha2_beta1_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(128, 128, 16, cutlass::half_t(2), cutlass::half_t(1));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_120x112x64_ldg8_nt) {
|
||||
// Load 8 halfs per LDG for A/B.
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
cutlass::Shape<8, 8, 16>,
|
||||
8, 8>
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(120, 112, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_508x252x120_ragged_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(508, 252, 120);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(124, 126, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_alpha2_beta1_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1));
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -345,7 +345,7 @@ TEST(Hgemm_128x128x8, hgemm_128x128x16_alpha2_beta1_nt) {
|
||||
TEST(Hgemm_128x128x8, hgemm_120x112x64_ldg8_nt) {
|
||||
// Load 8 halfs per LDG for A/B.
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 128, 128>,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
cutlass::Shape<8, 8, 16>,
|
||||
@ -367,7 +367,7 @@ TEST(Hgemm_128x128x8, hgemm_508x252x120_ragged_nt) {
|
||||
|
||||
TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(124, 126, 32);
|
||||
@ -377,7 +377,7 @@ TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_nt) {
|
||||
|
||||
TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_alpha2_beta1_nt) {
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
run_gemm<HgemmTraits>(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1));
|
||||
|
||||
@ -25,7 +25,6 @@
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/igemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
238
tools/test/unit/gemm/igemm_32x32x128.cu
Normal file
238
tools/test/unit/gemm/igemm_32x32x128.cu
Normal file
@ -0,0 +1,238 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/igemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x4_nt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 4);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x8_nt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 8);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x32_nt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x128_nt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 128);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x4_nn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 4);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x8_nn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 8);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x32_nn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x128_nn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 128);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x4_tn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 4);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x8_tn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 8);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x15_tn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 15, 16, 16, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x32_tn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x128_tn) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 128);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x8_tt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 8);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x32_tt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Igemm_32x32x128, igemm_32x32x128_tt) {
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<128, 32, 32>,
|
||||
int,
|
||||
cutlass::gemm::LinearScaling<int>,
|
||||
cutlass::Shape<32, 8, 4> >
|
||||
IgemmTraits;
|
||||
run_gemm<IgemmTraits>(32, 32, 128);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
410
tools/test/unit/gemm/sgemm_128x128x16.cu
Normal file
410
tools/test/unit/gemm/sgemm_128x128x16.cu
Normal file
@ -0,0 +1,410 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/sgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x81x1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 81, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x17_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x73x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 73, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_97x112x64_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(97, 112, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x112x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x240x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 240, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x240x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 240, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x1_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_79x112x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(79, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x81x17_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 81, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x73x64_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 73, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x112x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x256x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x256x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x1_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_127x112x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(127, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_21x112x17_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(21, 112, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x73x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 73, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x81x64_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 81, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x112x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_47x256x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(47, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_211x256x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(211, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x1_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_109x112x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(109, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x17_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_123x112x64_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(123, 112, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x112x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 112, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x256x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_256x256x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 256, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_120x112x64_ldg4_nt) {
|
||||
// Load 4 floats per LDG for A/B.
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
cutlass::gemm::LinearScaling<float>,
|
||||
cutlass::Shape<8, 8, 8>,
|
||||
4,
|
||||
4>
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(120, 112, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x128x16_alpha2_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16, 2.f, 0.f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x16_beta1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 16, 1.f, 1.f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x128x16, sgemm_128x112x16_alpha2_beta1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 16, 2.f, 1.f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -334,7 +334,7 @@ TEST(Sgemm_128x128x8, sgemm_256x256x16_tt) {
|
||||
TEST(Sgemm_128x128x8, sgemm_120x112x64_ldg4_nt) {
|
||||
// Load 4 floats per LDG for A/B.
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 128, 128>,
|
||||
cutlass::gemm::LinearScaling<float>,
|
||||
cutlass::Shape<8, 8, 8>,
|
||||
|
||||
294
tools/test/unit/gemm/sgemm_128x32x16.cu
Normal file
294
tools/test/unit/gemm/sgemm_128x32x16.cu
Normal file
@ -0,0 +1,294 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/sgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x17_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x32_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x32x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x1_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x17_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x32_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x32x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x1_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x17_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x32_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x32x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x1_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x17_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x32x32_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 32);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x32x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_128x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x32x16, sgemm_256x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
285
tools/test/unit/gemm/sgemm_128x64x16.cu
Normal file
285
tools/test/unit/gemm/sgemm_128x64x16.cu
Normal file
@ -0,0 +1,285 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/sgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x17_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x64_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x128x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x128x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x1_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x8_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x17_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x64_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x128x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x128x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x1_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x17_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x64_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x128x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x128x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x1_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x17_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x64x64_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_128x128x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_128x64x16, sgemm_256x128x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(256, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -333,10 +333,10 @@ TEST(Sgemm_128x64x8, sgemm_256x128x16_tt) {
|
||||
|
||||
TEST(Sgemm_128x64x8, sgemm_128x64x64_8x4_accumulators_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 64, 128>,
|
||||
cutlass::gemm::LinearScaling<float>,
|
||||
cutlass::Shape<8, 4, 8> >
|
||||
cutlass::Shape<8, 8, 8> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 64);
|
||||
}
|
||||
@ -345,7 +345,7 @@ TEST(Sgemm_128x64x8, sgemm_128x64x64_8x4_accumulators_nt) {
|
||||
|
||||
TEST(Sgemm_128x64x8, sgemm_128x64x64_4x8_accumulators_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 64, 128>,
|
||||
cutlass::gemm::LinearScaling<float>,
|
||||
cutlass::Shape<8, 8, 4> >
|
||||
|
||||
43
tools/test/unit/gemm/sgemm_64x128x16.cu
Normal file
43
tools/test/unit/gemm/sgemm_64x128x16.cu
Normal file
@ -0,0 +1,43 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/sgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x128x16, sgemm_64x128x64_4x8_accumulators_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 64>,
|
||||
cutlass::gemm::LinearScaling<float>,
|
||||
cutlass::Shape<8, 8, 4> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 128, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -32,7 +32,7 @@
|
||||
|
||||
TEST(Sgemm_64x128x8, sgemm_64x128x64_4x8_accumulators_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<8, 128, 64>,
|
||||
cutlass::gemm::LinearScaling<float>,
|
||||
cutlass::Shape<8, 8, 4> >
|
||||
|
||||
277
tools/test/unit/gemm/sgemm_64x32x16.cu
Normal file
277
tools/test/unit/gemm/sgemm_64x32x16.cu
Normal file
@ -0,0 +1,277 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/sgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x17_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x64_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x32x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x1_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x17_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x64_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x32x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x17_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x64_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x32x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x64x1_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x32x17_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 32, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x32x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 32, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_64x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x32x16, sgemm_128x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
294
tools/test/unit/gemm/sgemm_64x64x16.cu
Normal file
294
tools/test/unit/gemm/sgemm_64x64x16.cu
Normal file
@ -0,0 +1,294 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cutlass_unit_test.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/sgemm_traits.h>
|
||||
#include <tools/test/unit/gemm/gemm_testbed.h>
|
||||
#include <tools/test/unit/gemm/gemm.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x1_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x17_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x64_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x64x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x128x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x128x16_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x1_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x17_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x64_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x64x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x128x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x128x16_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x1_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 64, 64> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x17_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x64_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x64x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x128x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x128x16_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x1_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> > SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x17_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 17);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x64x64_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 64, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x64x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 64, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_64x128x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(64, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_64x64x16, sgemm_128x128x16_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 128, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -63,6 +63,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_nt) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nt) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
@ -76,9 +77,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nt) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nt) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
@ -91,6 +94,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nt) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -124,6 +128,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_nn) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nn) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
@ -137,9 +142,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nn) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nn) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
@ -152,6 +159,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nn) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -185,6 +193,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_tt) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tt) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
@ -198,9 +207,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tt) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tt) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
@ -213,6 +224,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tt) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -246,6 +258,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_tn) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tn) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
@ -259,9 +272,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tn) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
|
||||
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tn) {
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
@ -274,6 +289,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tn) {
|
||||
WmmaGemmTraits;
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -109,6 +109,9 @@ class HostTensor : public HostTensorView<T> {
|
||||
|
||||
host_.clear();
|
||||
host_.resize(_capacity);
|
||||
for (size_t i = 0; i < _capacity; ++i) {
|
||||
host_[i] = T((int)0xdeadbeef);
|
||||
}
|
||||
device_.reset(_device_memory, _capacity);
|
||||
|
||||
Base::reset(TensorRef_t(host_.data(), _stride), _size);
|
||||
|
||||
Reference in New Issue
Block a user