Replaced GoogleTest copy with submodule. Added updates to support intra-threadblock reductions. Added tests for same.

This commit is contained in:
akerr
2018-06-11 11:47:15 -07:00
parent 2c496c3e9e
commit 374882be53
40 changed files with 3279 additions and 336 deletions

View File

@ -33,7 +33,7 @@
#define CUTLASS_MAJOR 1
#define CUTLASS_MINOR 0
#define CUTLASS_PATCH 0
#define CUTLASS_PATCH 1
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
#ifdef __NVCC__

View File

@ -184,7 +184,7 @@ struct FragmentIterator {
/// The shape of the fragment.
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
/// The linear strides for iterations.
typedef typename ShapeStrides<FragmentShape>::Shape Strides;
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
/// Ctor.
template <typename OtherFragment_>
@ -242,7 +242,7 @@ struct FragmentConstIterator {
/// The shape of the fragment.
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
/// The linear strides for iterations.
typedef typename ShapeStrides<FragmentShape>::Shape IterationsStrides;
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
/// Ctor.
template <typename OtherFragment_>

View File

@ -49,21 +49,29 @@ struct FragmentMultiplyAdd {
CUTLASS_DEVICE FragmentMultiplyAdd() {}
/// Multiply : d = a*b.
template <typename Fragment_>
CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const& b, Fragment_& d) {
for (int j = 0; j < Fragment_::kElements; ++j) {
d[j] = a * b[j];
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) {
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = a * b[j * kReduction + 0];
for (int k = 1; k < kReduction; ++k) {
d[j] += a * b[j * kReduction + k];
}
}
}
/// Multiply : d = a*b + c.
template <typename Fragment_>
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply_add(Scalar_ a,
Fragment_ const& b,
Fragment_ const& c,
Fragment_& d) {
for (int j = 0; j < Fragment_::kElements; ++j) {
d[j] = a * b[j] + c[j];
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = a * b[j * kReduction + 0] + c[j];
for (int k = 1; k < kReduction; ++k) {
d[j] += a * b[j * kReduction + k];
}
}
}
};
@ -74,7 +82,7 @@ struct FragmentMultiplyAdd {
template <>
struct FragmentMultiplyAdd<half> {
/// The shape of the instruction.
typedef Shape<1, 1, 1, 1> InstructionShape;
typedef Shape<1, 1, 2, 1> InstructionShape;
/// The type for A.
typedef half ScalarA;
/// The type for B.
@ -86,38 +94,48 @@ struct FragmentMultiplyAdd<half> {
CUTLASS_DEVICE FragmentMultiplyAdd() {}
/// Multiply : d = a*b.
template <typename Fragment_>
CUTLASS_DEVICE void multiply(half a, Fragment_ const& b, Fragment_& d) {
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
// Assemble a half2 from a.
__half2 const a_half2 = __half2half2(a);
// The input.
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
// Assemble a half2 from a.
__half2 const a_half2 = __half2half2(a);
for (int i = 0; i < Fragment_::kElements / 2; ++i) {
d_half2[i] = __hmul2(a_half2, b_half2[i]);
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
for (int k = 1; k < kReduction; ++k) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
}
}
#endif
}
/// Multiply : d = a*b + c.
template <typename Fragment_>
CUTLASS_DEVICE void multiply_add(half a, Fragment_ const& b, Fragment_ const& c, Fragment_& d) {
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply_add(half a,
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
// Assemble a half2 from a.
__half2 const a_half2 = __half2half2(a);
// The inputs.
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
// Assemble a half2 from a.
__half2 const a_half2 = __half2half2(a);
for (int i = 0; i < Fragment_::kElements / 2; ++i) {
d_half2[i] = __hfma2(a_half2, b_half2[i], c_half2[i]);
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
for (int k = 1; k < kReduction; ++k) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
}
}
#endif
}

View File

@ -39,6 +39,8 @@ struct ClearAccumulators {
/// The shared storage.
struct SharedStorage {};
/// Ctor.
CUTLASS_DEVICE ClearAccumulators() {}
/// Ctor.
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}

View File

@ -40,7 +40,7 @@ namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm_>
__global__ void gemm_kernel(typename Gemm_::Params params) {
__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) {
// Declare shared memory.
__shared__ typename Gemm_::SharedStorage shared_storage;
@ -193,6 +193,71 @@ struct Gemm {
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
: params(params_), shared_storage(shared_storage_) {}
/// Consume a single iteration of the loop.
template <bool kIsLastIteration>
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream,
typename Traits::SharedLoadStream& shared_load_stream,
typename Traits::MultiplyAdd::Accumulators& accumulators,
Index outer_k) {
// If that's the last "load iteration" update the predicates.
if (!kIsLastIteration) {
global_stream.move_to_residue<false>(outer_k);
}
// Load data for the next iteration of the main loop.
if (!kIsLastIteration) {
global_stream.copy();
}
// The unrolling steps for the main loop.
int const kUnrollingSteps =
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
CUTLASS_PRAGMA_UNROLL
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
// Trigger the copy from shared memory for the next A/B values.
shared_load_stream.copy(step + 1);
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(step);
// Do the math on the fragments of the current iteration.
typename Traits::MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
shared_load_stream.fragment_b(step),
accumulators,
accumulators);
}
// Make sure the data from shared memory has been entirely consumed.
Traits::shared_load_fence(true);
// Commit the data in shared memory for A/B.
if (!kIsLastIteration) {
global_stream.commit();
}
// Make sure the data is in shared memory.
Traits::shared_store_fence(true);
// Trigger the loads for the next iteration (if needed).
if (!kIsLastIteration) {
// Move to the next stage for the load (if it makes sense).
shared_load_stream.inc_stage();
// Trigger the copy from shared memory for the next loop iteration.
shared_load_stream.copy(0);
}
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(kUnrollingSteps - 1);
// Do the math on the fragments of the current iteration.
typename Traits::MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
shared_load_stream.fragment_b(kUnrollingSteps - 1),
accumulators,
accumulators);
}
/// Do the GEMM.
CUTLASS_DEVICE void multiply_add() {
// Swizzle the IDs of the block (to enable better cache behavior).
@ -212,16 +277,11 @@ struct Gemm {
// Create the accumulator clear.
ClearAccumulators clear(shared_storage.main_loop.clear);
/// Define the mainloop iteration size
typedef typename Traits::MultiplyAdd MultiplyAdd;
// By how much we unroll the main loop.
Index const kUnroll = static_cast<Index>(MultiplyAdd::AccumulatorsPerWarp::kD);
Index const kUnroll = static_cast<Index>(Traits::OutputTile::kD);
// If we do not have enough steps in the main loop, trigger the residue code.
if (params.k < kUnroll) {
global_stream.residue(params.k, true);
}
global_stream.move_to_residue<true>(params.k);
// Fetch the fragments for A and B from global memory.
global_stream.copy();
@ -232,9 +292,12 @@ struct Gemm {
// Make sure the data is in shared memory.
Traits::shared_store_fence(false);
// Rollback to the beginning of the GEMM-K dimension. It may have no impact.
global_stream.rollback();
// The unrolling steps for the main loop.
int const kUnrollingSteps =
MultiplyAdd::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
// Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");
@ -246,59 +309,21 @@ struct Gemm {
shared_load_stream.copy(0);
// Allocate the accumulators.
typename MultiplyAdd::Accumulators accumulators;
typename Traits::MultiplyAdd::Accumulators accumulators;
// Clear the accumulators.
clear.clear(accumulators);
// The loop index.
Index outer_k = params.k - kUnroll;
// Enter the main loop and iterate.
typedef typename Traits::Index Index;
for (Index outer_k = params.k - kUnroll; outer_k > -kUnroll; outer_k -= kUnroll) {
// If that's the last "load iteration" update the predicates.
int const is_residue = outer_k <= kUnroll;
if (is_residue) {
global_stream.residue(outer_k);
}
for (; outer_k > 0; outer_k -= kUnroll) {
consume_tile<false>(global_stream, shared_load_stream, accumulators, outer_k);
}
// Load data for the next iteration of the main loop.
global_stream.copy();
CUTLASS_PRAGMA_UNROLL
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
// Trigger the copy from shared memory for the next A/B values.
shared_load_stream.copy(step + 1);
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(step);
// Do the math on the fragments of the current iteration.
MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
shared_load_stream.fragment_b(step),
accumulators,
accumulators);
}
// Make sure the data from shared memory has been entirely consumed.
Traits::shared_load_fence(true);
// Commit the data in shared memory for A/B.
global_stream.commit();
// Make sure the data is in shared memory.
Traits::shared_store_fence(true);
// Move to the next stage for the load (if it makes sense).
shared_load_stream.inc_stage();
// Trigger the copy from shared memory for the next loop iteration.
shared_load_stream.copy(0);
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(kUnrollingSteps - 1);
// Do the math on the fragments of the current iteration.
MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
shared_load_stream.fragment_b(kUnrollingSteps - 1),
accumulators,
accumulators);
// Residual loop.
for (; outer_k > -kUnroll; outer_k -= kUnroll) {
consume_tile<true>(global_stream, shared_load_stream, accumulators, outer_k);
}
// Epilogue.

View File

@ -117,6 +117,7 @@ struct GemmEpilogue {
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
Accumulators& accumulators) {
// The problem size.
Coord<3> const bounds = cutlass::make_Coord(0, n, m);
// The functor.
@ -153,6 +154,18 @@ struct GemmEpilogue {
GlobalStoreIteratorD global_store_iterator(
params.iterator_d, bounds, block, pointer_offset, predicate_offset);
// The transformer to transform before storing to shared memory.
SharedStoreTransformerD shared_store_transformer;
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
// The iterator to store to shared memory.
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
shared_storage.shared_stream.store);
// The iterator to load from shared memory. TODO: Use a stream.
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
shared_storage.shared_stream.load);
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
// Load the C matrix into fragment.
@ -166,20 +179,13 @@ struct GemmEpilogue {
// Copy the accumulators to shared memory.
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
SharedStoreTransformerD shared_store_transformer;
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
shared_storage.shared_stream.store);
shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
// Make sure the data is in shared memory.
shared_store_fence();
// Copy the accumulators back to registers from shared memory.
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
shared_storage.shared_stream.load);
typename SharedLoadIteratorD::Fragment fetched_d;
shared_iterator_load(shared_load_iterator, fetched_d);

View File

@ -84,8 +84,9 @@ struct GlobalLoadStreamBase {
typename StoreIterator::Params store_iterator;
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld) {
int error_code = load_iterator.initialize(pointer, ld);
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) {
int error_code = load_iterator.initialize(desc, pointer, ld);
if (error_code) {
return error_code;
}
@ -128,6 +129,9 @@ struct GlobalLoadStreamBase {
store_iterator.inc_stage();
}
/// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
/// Execute the residue code.
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
load_iterator.residue(k);
@ -136,6 +140,9 @@ struct GlobalLoadStreamBase {
}
}
/// Rollback to the beginning of the GEMM-k dimension.
CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
/// The iterator.
LoadIterator load_iterator;
/// The fragment to fetch from shared memory.

View File

@ -195,7 +195,8 @@ struct GemmGlobalIteratorAb
struct Params : public BaseParams {
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr, Index stride_h) {
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) {
Index inc_d = 0;
Index inc_advance = 0;
// Move by some columns for each iteration in the H dimension.
@ -220,16 +221,75 @@ struct GemmGlobalIteratorAb
(Base::Iterations::kH - 1) * inc_h;
}
Base::Params::initialize(ptr, 0, stride_h, 0, inc_d, inc_h, 0, inc_advance);
// The dimensions of the tile.
int const kH = TileTraits_::Tile::kH;
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
// Move to the residue.
Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW;
// The jump in the gemm-k dimension.
Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1;
// Compute the offset to the residue and how to "come" back.
Index const kResidue = desc.k % kBlock;
if (kResidue > 0) {
move_to_residue_offset = (desc.k - kResidue) * stride;
} else {
move_to_residue_offset = (desc.k - kBlock) * stride;
}
Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance);
return 0;
}
// The extra offset to control moving to the residue.
Index move_to_residue_offset;
};
/// Offset of an individual lane from the start of the tile
Coord<4> thread_offset;
/// The parameters
Params params;
/// Ctor.
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
const Coord<3>& bounds,
const Coord<3>& block,
ThreadOffset thread_offset_func = ThreadOffset())
: params(_params) {
thread_offset = thread_offset_func();
// The column.
Index block_h = thread_offset[1];
// The contiguous dimension.
Index block_w = thread_offset[2];
// Add the blocks indices.
if (kAdvance == IteratorAdvance::kH) {
block_h += block[1];
block_w += block[2];
} else {
block_h += block[2];
block_w += block[1];
}
// Setup the pointer.
params.pointer += (block_h * params.stride_h + block_w);
// Initialize predicates
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
}
/// The accessor.
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
}
/// Increment the pointer in the H dimension.
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
/// Increment the pointer in the D dimension.
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
/// Increment the pointer to move to the next iteration.
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
/// Initialize the predicates.
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
// Setup the masks to control loads.
predicates.fill(0);
@ -263,46 +323,29 @@ struct GemmGlobalIteratorAb
}
}
/// Ctor.
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
const Coord<3>& bounds,
const Coord<3>& block,
ThreadOffset thread_offset_func = ThreadOffset())
: params(_params) {
thread_offset = thread_offset_func();
// The column.
Index block_h = thread_offset[1];
// The contiguous dimension.
Index block_w = thread_offset[2];
/// Move to residue portion.
CUTLASS_DEVICE void move_to_residue(Index k) {
// Store the pointer and the predicates.
stored_pointer = params.pointer;
stored_predicates = predicates;
// Add the blocks indices.
if (kAdvance == IteratorAdvance::kH) {
block_h += block[1];
block_w += block[2];
// Move the pointer to the residue.
params.pointer += params.move_to_residue_offset;
} else {
block_h += block[2];
block_w += block[1];
// The dimensions of the tile.
int const kH = TileTraits_::Tile::kH;
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
// The unrolling factor.
int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
// Clear the predicates for the residue. TODO: We can do something smarter.
int const kResidue = (int)(k % (Index)kUnroll);
if (kResidue > 0) {
residue(kResidue);
}
// Setup the pointer.
params.pointer += (block_h * params.stride_h + block_w);
// Initialize predicates
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
}
/// Increment the pointer in the H dimension.
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
/// Increment the pointer in the D dimension.
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
/// Increment the pointer to move to the next iteration.
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
/// Returns the current pointer
CUTLASS_HOST_DEVICE
Scalar const* data() const { return params.pointer; }
/// That's the residue! Update the predicates.
CUTLASS_DEVICE void residue(Index k) {
// The coordinates of the thread.
@ -332,14 +375,26 @@ struct GemmGlobalIteratorAb
}
}
/// Rollback to beginning of first tile and initialize predicates.
CUTLASS_DEVICE void rollback() {
params.pointer = stored_pointer;
predicates = stored_predicates;
}
/// Is the iterator valid?
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
return predicates[bit];
}
/// Offset of an individual lane from the start of the tile
Coord<4> thread_offset;
/// The parameters
Params params;
/// The pointer.
typename Base::Scalar const* stored_pointer;
/// The predicates.
PredicateVector predicates;
PredicateVector predicates, stored_predicates;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -439,6 +494,13 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
this->params.predicate_offset -= (h + pred_offset);
}
/// The accessor.
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
}
/// Increment the pointer in the C dimension.
CUTLASS_DEVICE void inc_c() {}
/// Increment the pointer in the W dimension.
@ -456,18 +518,19 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
this->params.predicate_offset -= params.predicate_inc_advance;
}
/// The accessor.
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
value, params.pointer, imm);
}
/// Test the validity of the iterator.
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
return predicates.at(w) && params.predicate_offset > 0;
}
/// Returns the raw pointer
CUTLASS_HOST_DEVICE
Pointer data() { return params.pointer; }
CUTLASS_HOST_DEVICE
Pointer const data() const { return params.pointer; }
/// The predicates for the row.
cutlass::PredicateVector<Base::Iterations::kW> predicates;
};

View File

@ -104,8 +104,7 @@ struct GemmSharedStoreWithSkewTileAbTraits {
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides;
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
return make_Coord(0, 0, offset, 0);
}
@ -164,22 +163,25 @@ struct GemmSharedLoadTileATraits {
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
ImmediateOffsetStrides;
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// Extract the warp.
int const warp = threadIdx.x / kWarpSize % Warps::kW;
// Compute the row offset for each thread
int const lane = (threadIdx.x & 0x0e) / 2;
int const warp = threadIdx.x / kWarpSize;
// Extract the slice.
int const slice = warp / (Warps::kH * Warps::kW);
// Compute the row offset for each warp.
int const warp_row = warp % Warps::kW;
// Compute the row offset for each thread.
int const lane_row = (threadIdx.x & 0x0e) / 2;
// The offset.
int const offset = (warp * ThreadsPerWarp::kW + lane) * kAccessSize;
int const offset =
slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
// Embed the offset in a 4D coordinate vector.
return make_Coord(0, 0, offset, 0);
}
};
@ -231,23 +233,27 @@ struct GemmSharedLoadTileBTraits {
/// The number of iterations needed to load/store the tile.
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<TileWithSkew::kW, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
ImmediateOffsetStrides;
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
// The position of the warp.
int const warp = threadIdx.x / (Warps::kW * kWarpSize);
// Compute the column offset for each thread
int const lane = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// Extract the warp.
int const warp = threadIdx.x / kWarpSize;
// Extract the slice.
int const slice = warp / (Warps::kH * Warps::kW);
// The warp in the slice.
int const warp_in_slice = warp % (Warps::kH * Warps::kW);
// Compute the row offset for each warp.
int const warp_col = warp_in_slice / Warps::kW;
// Compute the row offset for each thread.
int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
// The offset.
int const offset = (warp * ThreadsPerWarp::kH + lane) * kAccessSize;
int const offset =
slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
// Embed the offset in a 4D coordinate.
return make_Coord(0, 0, offset, 0);
}
};
@ -297,28 +303,26 @@ struct GemmSharedStoreTileDTraits {
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
// We issue STS.128 in the epilogue to store the accumulators to shared memory. When we use
// STS.128, we have to guarantee that threads in groups of 8 do not have bank conflicts (i.e
// they write to different banks).
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// The warp.
int const warp = threadIdx.x / kWarpSize;
// The position of the warp in the 2D tile.
int const warp_row = warp % Warps::kW;
int const warp_col = warp / Warps::kW;
// We assume that the elements are distributed in a warp as 4 columns of 8 elements. The
// columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15],
// col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
// Odd threads go to the second half of shared memory.
int const row = threadIdx.x & 0x01;
int const warp_id = (threadIdx.x >> 5);
int const warp_row = (warp_id % Warps::kW);
int const warp_col = (warp_id / Warps::kW);
int hi_halfwarp_offset = OutputTile::kW * ((threadIdx.x >> 4) & 1);
int lo_halfwarp_offset = (((threadIdx.x >> 1) & 0x7) + warp_row * ThreadsPerWarp::kW);
int col = kAccessSize * lo_halfwarp_offset +
warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW + hi_halfwarp_offset;
int offset = row * kScalarsPerRow + col;
return make_Coord(0, 0, offset, 0);
int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
// Embed the offset in a 4D coords.
return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
}
};
};
@ -357,32 +361,39 @@ struct GemmSharedLoadTileDTraits {
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
/// The tile.
/// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank
/// conflicts in the epilogue.
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
// Compute the number of iterations per warp in the Tile::kH dimension.
static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;
// As shown above, the shared memory tile is composed of 2 rows and each row is made of
// As explained above, the shared memory tile is composed of 2 rows and each row is made of
// kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
// back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
// However, in some cases, we have only 1 iteration per warp. In that case, we must define the
// shape as Shape<1, 1, ...>. The following code does that.
// shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
// to keep the number of elements to reduce for split-K.
static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
// As soon as we know kIterationsH, it is trivial to compute kIterationsD:
static int const kIterationsD = kIterationsInHPerWarp / kIterationsH;
// If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
// threads to grab the other elements.
static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
/// The number of iterations needed to store the tile.
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize> Iterations;
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize> Delta;
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK>
ImmediateOffsetStrides;
/// The strides in each dimension between different loads/stores.
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize> ImmediateOffsetStrides;
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// Each warp works on a different column.
int const h = threadIdx.x / kWarpSize;
// Compute the row.

View File

@ -74,7 +74,9 @@ template <
/// The number of scalars per LDS for D.
int kScalarsPerLdsD_,
/// The number of stages in shared memory to do single/double/triple-buffering.
int kStages_>
int kStages_,
/// Do we do the residue in the prologue?
bool kResidueInPrologue_ = false>
struct GemmConfig {
//
@ -129,6 +131,9 @@ struct GemmConfig {
/// The number of stages in shared memory to implement double, triple, more-buffering.
static int const kStages = kStages_;
/// Do we do the residue in the prologue?
static bool const kResidueInPrologue = kResidueInPrologue_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -229,8 +234,12 @@ struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
/// The number of scalars in 4B.
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
/// The skew for A.
static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
GlobalTileTraits::Threads::kW * kScalarsIn4B;
/// The traits class to build the iterator to store data to shared memory for A^T.
typedef GemmSharedStoreWithSkewTileAbTraits<
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer is float.
MultiplyAddScalar,
// The tile has size KxM in GEMM's terminology.
@ -242,9 +251,8 @@ struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
// The number of scalars per STS.
GemmConfig_::kScalarsPerStsA,
// The skew to avoid bank conflicts added in the tile W dimension.
128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
GlobalTileTraits::Threads::kW * kScalarsIn4B>
SharedStoreTileTraits;
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for A^T.
typedef GemmSharedLoadTileATraits<
@ -302,8 +310,12 @@ struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
/// The number of scalars in 4B.
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
/// The skew for B.
static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
GlobalTileTraits::Threads::kW * kScalarsIn4B;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreWithSkewTileAbTraits<
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer is float.
MultiplyAddScalar,
// The tile has size KxN in GEMM's terminology.
@ -315,9 +327,8 @@ struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
// The number of scalars per STS.
GemmConfig_::kScalarsPerStsB,
// The skew to avoid bank conflicts added in the tile W dimension.
128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
GlobalTileTraits::Threads::kW * kScalarsIn4B>
SharedStoreTileTraits;
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for B^N.
typedef GemmSharedLoadTileBTraits<
@ -405,6 +416,60 @@ struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Dispatches the handling of the K-dimension "residue" (the leftover K % kUnroll portion) for
/// the GEMM main loop. This primary template is selected when
/// GemmTraits_::kResidueInPrologue is true: the residue is consumed in the prologue so that
/// every subsequent main-loop iteration works on a complete tile.
template <typename GemmTraits_, bool kResidueInPrologue_ = GemmTraits_::kResidueInPrologue>
struct GemmResidue {
  /// Move both global-memory load streams to the residue portion of the K dimension.
  /// Only acts when called from the prologue (kIsPrologue == true); afterwards the main loop
  /// is residue-free.
  template <bool kIsPrologue>
  static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
                                             typename GemmTraits_::GlobalLoadStreamB& stream_b,
                                             typename GemmTraits_::Index k) {
    // The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have
    // complete main loops after that. It helps simplify the logic in the main loop.
    if (kIsPrologue) {
      stream_a.move_to_residue(k);
      stream_b.move_to_residue(k);
    }
  }

  /// Rollback to beginning of first tile and initialize predicates. Needed because the streams
  /// were moved forward to the residue in the prologue.
  static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
                                      typename GemmTraits_::GlobalLoadStreamB& stream_b) {
    stream_a.rollback();
    stream_b.rollback();
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization for kResidueInPrologue == false: the residue is treated at the tail of the
/// main loop (the CUTLASS 1.0.0 code path), so no rollback of the load streams is required.
template <typename GemmTraits_>
struct GemmResidue<GemmTraits_, false> {
  /// Move to residue portion.
  template <bool kIsPrologue>
  static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
                                             typename GemmTraits_::GlobalLoadStreamB& stream_b,
                                             typename GemmTraits_::Index k) {
    // The index.
    typedef typename GemmTraits_::Index Index;
    // By how much we unroll the main loop (the K extent of one tile).
    Index const kUnroll = static_cast<Index>(GemmTraits_::OutputTile::kD);

    // Call the residue code. That's the same path as CUTLASS 1.0.0. In the prologue, only a
    // problem smaller than one unroll (k < kUnroll) needs the residue; in the main loop, the
    // last iteration (k <= kUnroll) runs it. The second argument is the streams' skip_clear
    // flag: true in the prologue, false otherwise.
    if (kIsPrologue && k < kUnroll) {
      stream_a.residue(k, true);
      stream_b.residue(k, true);
    } else if (k <= kUnroll) {
      stream_a.residue(k, false);
      stream_b.residue(k, false);
    }
  }

  /// Rollback is a no-op in this code path: the streams are never moved ahead of the first
  /// tile, so there is nothing to undo.
  static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
                                      typename GemmTraits_::GlobalLoadStreamB& stream_b) {}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The GEMM configuration.
typename GemmConfig_,
@ -426,10 +491,24 @@ template <
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Scalar> >
struct GemmTraits {
/// This class.
typedef GemmTraits<GemmConfig_,
GlobalLoadStreamA_,
GlobalLoadStreamB_,
SharedLoadStreamA_,
SharedLoadStreamB_,
Epilogue_,
BlockSwizzle_,
Index_,
ClearAccumulators_>
This_;
/// The configuration.
typedef GemmConfig_ GemmConfig;
/// The output tile.
typedef typename GemmConfig::OutputTile OutputTile;
/// Is the residue treated in the prologue?
static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue;
/// The stream to load A from global memory to shared memory.
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
@ -450,18 +529,6 @@ struct GemmTraits {
/// The iterator for B to load from shared memory.
typedef SharedLoadStreamB_ SharedLoadStreamB;
/// The shared storage for A.
typedef typename GlobalLoadStreamA::SharedStoreStorage SharedStoreStorageA;
// Btw, make sure we did not messed up with the size of the storage.
static_assert(sizeof(SharedStoreStorageA) == sizeof(typename SharedLoadStreamA::SharedStorage),
"");
/// The shared storage for B.
typedef typename GlobalLoadStreamB::SharedStoreStorage SharedStoreStorageB;
// Btw, make sure we did not messed up with the size of the storage.
static_assert(sizeof(SharedStoreStorageB) == sizeof(typename SharedLoadStreamB::SharedStorage),
"");
/// The multiply-add functor.
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
/// The epilogue.
@ -502,14 +569,15 @@ struct GemmTraits {
// Initialize the iterator for A.
int error_code =
global_stream_a.initialize(reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
global_stream_a.initialize(desc, reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
if (error_code) {
return error_code;
}
// Initialize the iterator for B.
error_code = global_stream_b.initialize(reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
error_code =
global_stream_b.initialize(desc, reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
if (error_code) {
return error_code;
@ -574,12 +642,15 @@ struct GemmTraits {
stream_b.commit();
}
/// Execute the residue code.
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
stream_a.residue(k, skip_clear);
stream_b.residue(k, skip_clear);
/// Move to residue portion.
template <bool kIsPrologue>
CUTLASS_DEVICE void move_to_residue(Index k) {
GemmResidue<This_>::move_to_residue<kIsPrologue>(stream_a, stream_b, k);
}
/// Rollback to beginning of first tile and initialize predicates.
CUTLASS_DEVICE void rollback() { GemmResidue<This_>::rollback(stream_a, stream_b); }
/// The stream for A.
GlobalLoadStreamA stream_a;
/// The stream for B.

View File

@ -147,8 +147,11 @@ struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
/// The skew.
static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
/// The traits class to build the iterator to store data to shared memory for A^T.
typedef GemmSharedStoreWithSkewTileAbTraits<
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer.
half,
// The tile has size KxM in GEMM's terminology.
@ -160,8 +163,8 @@ struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
// The number of scalars per STS (STS.32 or STS.128, etc).
2,
// The skew to avoid bank conflicts added in the tile W dimension.
128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
SharedStoreTileTraits;
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for A^T.
typedef GemmSharedLoadTileATraits<
@ -212,8 +215,11 @@ struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
/// The skew for B.
static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreWithSkewTileAbTraits<
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer.
half,
// The tile has size KxN in GEMM's terminology.
@ -225,8 +231,8 @@ struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
// The number of scalars per STS (STS.32 or STS.128, etc).
2,
// The skew to avoid bank conflicts added in the tile W dimension.
128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
SharedStoreTileTraits;
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for B^N.
typedef GemmSharedLoadTileBTraits<
@ -261,7 +267,7 @@ template <
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_,
/// The number of accumulators per thread.
typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
/// The number of halfs loaded in one LDG for A.
int kScalarsPerLdgA_ = 2,
/// The number of halfs loaded in one LDG for B.

View File

@ -47,19 +47,19 @@ template <GemmOperand::Kind kOperand_,
typename Tile_,
typename Threads_,
int kAccessSize_>
struct IgemmContiguousGlobalTileTraits : public GemmGlobalTileTraits<
// Which GEMM operand?
kOperand_,
// The layout.
kLayout_,
// The scalar.
Scalar_,
// The tile.
Tile_,
// The threads.
Threads_,
// The number of scalars per LDG/STG.
kAccessSize_> {
struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
// Which GEMM operand?
kOperand_,
// The layout.
kLayout_,
// The scalar.
Scalar_,
// The tile.
Tile_,
// The threads.
Threads_,
// The number of scalars per LDG/STG.
kAccessSize_> {
/// The base class.
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
/// The threads.
@ -91,5 +91,71 @@ struct IgemmContiguousGlobalTileTraits : public GemmGlobalTileTraits<
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Deprecated. Please use IgemmGlobalTileTraits instead. Kept as an empty derived class so that
/// code written against the old name continues to compile unchanged.
template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Threads_,
          int kAccessSize_>
struct IgemmContiguousGlobalTileTraits
    : public IgemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Global-memory iterator for IGEMM A/B operands. Extends GemmGlobalIteratorAb with a byte mask
/// that zeroes out the int8 elements of an access which fall beyond the K-dimension residue, so
/// a partially valid access does not feed garbage bytes into the computation.
template <typename TileTraits_, typename Index_ = int>
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
  /// The base class.
  typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
  /// The functor to compute the thread offset.
  typedef typename TileTraits_::ThreadOffset ThreadOffset;

  /// Constructor. Computes how many elements of this thread's last access are valid and builds
  /// the corresponding byte mask (all-ones when the whole access is valid).
  CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
                                       const Coord<3>& bounds,
                                       const Coord<3>& block,
                                       ThreadOffset thread_offset_func = ThreadOffset())
      : Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) {
    // The number of elements read in a single iteration.
    int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
    // The residue.
    int const kResidue = (int)(bounds[1] % kBlock);
    // Compute the number of elements that are valid for this thread.
    int const left = kResidue - Base::thread_offset[2];
    // NOTE(review): the "< 4" bound assumes a 32-bit access holding 4 int8 values; if left >= 4
    // the whole access is valid and the default all-ones mask is kept — TODO confirm that
    // TileTraits_::kAccessSize is 4 for all instantiations of this iterator.
    if (left > 0 && left < 4) {
      mask_ = (1u << (8 * left)) - 1u;
    }
  }

  /// The accessor: loads through the base class, then clears the bytes past the residue when
  /// iterating over the residue portion.
  CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
    Base::get(value, d, h, w, c);
    if (in_residue_) {
      reinterpret_cast<uint32_t&>(value) &= mask_;
    }
  }

  /// Move to residue portion; subsequent get() calls apply the byte mask.
  CUTLASS_DEVICE void move_to_residue(typename Base::Index k) {
    Base::move_to_residue(k);
    in_residue_ = true;
  }

  /// Move back to the beginning of the first tile and stop masking.
  CUTLASS_DEVICE void rollback() {
    Base::rollback();
    in_residue_ = false;
  }

  /// Are we in the residue?
  bool in_residue_;
  /// The mask to clean up the values.
  uint32_t mask_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -87,7 +87,9 @@ struct IgemmConfig
/// The number of scalars per LDS for D.
1,
/// The number of stages in shared memory.
2> {};
2,
/// Enable the code path that deals with the residue in epilogue.
true> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -125,17 +127,19 @@ struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
/// The number of scalars per LDS for D.
4,
/// The number of stages in shared memory.
2> {};
2,
/// Enable the code path that deals with the residue in epilogue.
true> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
/// The base config.
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
@ -144,7 +148,7 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
static int const kScalarsPerStsA = 16;
/// The traits class to build the iterator to load data from global memory for A^N.
typedef IgemmContiguousGlobalTileTraits<
typedef IgemmGlobalTileTraits<
GemmOperand::kA,
// The layout.
MatrixLayout::kColumnMajor,
@ -155,9 +159,12 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
4>
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
// The iterator.
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
/// The traits class to build the iterator to store data to shared memory for A^N.
typedef GemmSharedStoreTileAbTraits<
// The pointer is float.
@ -173,13 +180,149 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
/// IGEMM tile traits for A stored row-major (A^T from the kernel's perspective). The
/// shared-memory tile is stored with a skew to avoid bank conflicts, and the global iterator is
/// the IGEMM one that masks the int8 K-residue.
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// The input scalar.
  typedef int8_t Scalar;
  /// The scalar stored in shared memory.
  typedef int8_t MultiplyAddScalar;

  /// The number of scalars per LDG/STS/LDS for A.
  static int const kScalarsPerStsA = 16;

  /// The traits class to build the iterator to load data from global memory for A^T.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kA,
      // The layout.
      MatrixLayout::kRowMajor,
      // The pointer is int8 const.
      int8_t const,
      // The tile has size MxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgA>
      GlobalTileTraits;

  // The iterator (applies the int8 residue mask — see IgemmGlobalIteratorAb).
  typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for A^N.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer is int8.
      int8_t,
      // The tile has size KxM in GEMM's terminology (K packed by groups of 4 int8).
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS.
      kScalarsPerStsA,
      // The skew to avoid bank conflicts added in the tile W dimension.
      16>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for A^N.
  typedef GemmSharedLoadTileATraits<
      // The pointer is int8 const.
      int8_t const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      16,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// IGEMM tile traits for B. The primary template falls back to the generic GEMM traits for the
/// given layout; int8-specific layouts are handled by the specializations below.
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
/// IGEMM tile traits for B stored column-major (B^T from the kernel's perspective). Mirrors the
/// row-major A specialization: skewed shared-memory tile plus the residue-masking global
/// iterator for int8.
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// The input scalar.
  typedef int8_t Scalar;
  /// The scalar stored in shared memory.
  typedef int8_t MultiplyAddScalar;

  /// The number of scalars per LDG/STS/LDS for B.
  static int const kScalarsPerStsB = 16;

  /// The traits class to build the iterator to load data from global memory for B^T.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kB,
      // The layout.
      MatrixLayout::kColumnMajor,
      // The pointer is int8 const.
      int8_t const,
      // The tile has size NxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  // The iterator (applies the int8 residue mask — see IgemmGlobalIteratorAb).
  typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for B^N.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer is int8.
      int8_t,
      // The tile has size KxN in GEMM's terminology (K packed by groups of 4 int8).
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS.
      kScalarsPerStsB,
      // The skew to avoid bank conflicts added in the tile W dimension.
      16>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for B^N.
  typedef GemmSharedLoadTileBTraits<
      // The pointer is int8 const.
      int8_t const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      16,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
/// The base config.
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
@ -188,7 +331,7 @@ struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
static int const kScalarsPerStsB = 16;
/// The traits class to build the iterator to load data from global memory for B^T.
typedef IgemmContiguousGlobalTileTraits<
typedef IgemmGlobalTileTraits<
GemmOperand::kB,
// The layout.
MatrixLayout::kRowMajor,
@ -199,9 +342,12 @@ struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
4>
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
// The iterator.
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreTileAbTraits<
// The pointer is float.
@ -266,13 +412,13 @@ struct IgemmTraitsHelper {
/// The IGEMM config.
typedef IgemmConfig<OutputTile_, ScalarD_, AccumulatorsPerThread_> GemmConfig;
/// The GEMM config for A.
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
/// The GEMM config for B.
typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig, Index_> GemmTileTraitsHelperB;
/// The iterator to load A from global memory.
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
GlobalLoadIteratorA;
typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
/// The default transformer for A.
typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
@ -287,8 +433,8 @@ struct IgemmTraitsHelper {
GlobalLoadStreamA;
/// The iterator to load B from global memory.
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
GlobalLoadIteratorB;
typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
// The default transformer for B.
typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
GlobalLoadIteratorB>::Transformer GlobalTransformerB;

View File

@ -1,4 +1,3 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
@ -61,17 +60,17 @@ struct LinearScaling {
CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {}
/// Evaluate the functor.
template <typename Fragment_>
CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_& output) {
template <typename FragmentA_, typename FragmentB_>
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
FragmentMultiplyAdd mad;
mad.multiply(alpha, accum, output);
}
/// Evaluate the functor.
template <typename Fragment_>
CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_ const& old, Fragment_& output) {
template <typename FragmentA_, typename FragmentB_>
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
FragmentMultiplyAdd mad;
Fragment_ tmp;
FragmentB_ tmp;
mad.multiply(beta, old, tmp);
mad.multiply_add(alpha, accum, tmp, output);
}

View File

@ -164,6 +164,13 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
this->params.predicate_offset -= (h + pred_offset);
}
/// The accessor.
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
}
/// Increment the pointer in the C dimension.
CUTLASS_DEVICE void inc_c() {}
/// Increment the pointer in the W dimension.
@ -181,18 +188,19 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
params.predicate_offset -= params.predicate_inc_advance;
}
/// The accessor.
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
value, params.pointer, imm);
}
/// Test the predicate.
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
return predicates.at(w) && params.predicate_offset > 0;
}
/// Returns the raw pointer
CUTLASS_HOST_DEVICE
Pointer data() { return params.pointer; }
CUTLASS_HOST_DEVICE
Pointer const data() const { return params.pointer; }
/// The predicates for the row.
cutlass::PredicateVector<Base::Iterations::kW> predicates;
};

View File

@ -45,14 +45,12 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
if (iterator.valid(d, h, w, c)) {
int const offset =
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
0, 0, w, c);
Load<typename Fragment::Element, InputIterator::Tile::kC, InputIterator::kMemorySpace>::
load(reinterpret_cast<typename InputIterator::AccessType &>(
frag_iterator.at(d, h, w, c)),
iterator.data(),
offset);
iterator.get(reinterpret_cast<typename InputIterator::AccessType &>(
frag_iterator.at(d, h, w, c)),
d,
h,
w,
c);
}
}
if (w < InputIterator::Iterations::kW - 1) {
@ -196,17 +194,12 @@ CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &frag
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
if (iterator.valid(d, h, w, 0)) {
int const offset =
ComputeOffsetFromStrides<typename OutputIterator::ImmediateOffsetStrides>::get(
d, h, w, 0);
Store<typename Fragment::Element,
OutputIterator::Tile::kC,
OutputIterator::kMemorySpace>::
store(reinterpret_cast<typename OutputIterator::AccessType &>(
frag_iterator.at(d, h, w, 0)),
iterator.data(),
offset);
iterator.set(reinterpret_cast<typename OutputIterator::AccessType const &>(
frag_iterator.at(d, h, w, 0)),
d,
h,
w,
0);
}
if (w < OutputIterator::Iterations::kW - 1) {
iterator.inc_w();

View File

@ -106,6 +106,29 @@ struct Load<double, 2, Memory_, true, 16> {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10
// WAR bug in NVCC (< 10) where the upper and lower half of the register end up being the same:
// perform the 16-byte load as two 8-byte (int2) loads instead of a single 16-byte access.
template <MemorySpace::Kind Memory_>
struct Load<half, 8, Memory_, true, 16> {
  /// The output type: a vector of 8 halfs (packed into four 32-bit registers).
  typedef typename Vectorize<half, 8>::Type AccessType;

  /// The load function. 'offset' is in units of half elements.
  static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
    // First 8 bytes (4 halfs).
    int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
    dst.registers[0] = tmp.x;
    dst.registers[1] = tmp.y;
    // Second 8 bytes: 4 half elements further along.
    tmp = reinterpret_cast<int2 const*>(&pointer[offset + 4])[0];
    dst.registers[2] = tmp.x;
    dst.registers[3] = tmp.y;
  }
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
struct Load<Scalar_, Lanes_, Memory_, true, 16> {
/// The output type.

View File

@ -150,9 +150,13 @@ struct ShapeMin {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Shape_>
template <typename Shape_, int kElementsPerAccess>
struct ShapeStrides {
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC, Shape_::kW * Shape_::kC, Shape_::kC, 1> Shape;
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
Shape_::kW * Shape_::kC,
Shape_::kC,
kElementsPerAccess>
Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -73,7 +73,11 @@ struct IteratorFragment {
* @brief A template defining \ref tile_traits_concept
* @concept{tile_traits_concept}
*/
template <typename Tile_, typename Delta_, typename Iterations_, typename ThreadOffset_>
template <typename Tile_,
typename Delta_,
typename Iterations_,
typename ThreadOffset_,
int kAccessSize>
struct TileTraits {
/// Shape of the tile
typedef Tile_ Tile;
@ -501,6 +505,13 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
CUTLASS_HOST_DEVICE
Scalar const *data() const { return params.pointer; }
/// The accessor.
CUTLASS_DEVICE void get(AccessType &value, int d, int h, int w, int c) const {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
Load<Scalar, Base::kAccessSize, kMemorySpace>::load(value, params.pointer, imm);
}
/// Increment in the D dimension
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
@ -829,6 +840,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
}
}
/// The accessor.
CUTLASS_DEVICE void set(AccessType const &value, int d, int h, int w, int c) {
int const imm =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
Store<Scalar, Base::kAccessSize, kMemorySpace>::store(value, params.pointer, imm);
}
public:
/// Stores a fragment and advances to the next tile.
template <typename Fragment, typename PredicateIterator>

View File

@ -299,7 +299,7 @@ typedef integral_constant<bool, true> true_type;
/// The type used as a compile-time boolean with false value.
typedef integral_constant<bool, false> false_type;
#if (!defined(_MSC_VER) && (__cplusplus < 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
#if (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
/// std::bool_constant
template <bool V>

View File

@ -47,6 +47,7 @@ set(CUTLASS_UNIT_TEST_SOURCES
core/tile_iterator.cu
gemm/dgemm.cu
gemm/hgemm_128x128x8.cu
gemm/hgemm_128x128x16.cu
gemm/hgemm_128x32x8.cu
gemm/hgemm_128x64x8.cu
gemm/igemm_128x128x32.cu
@ -54,12 +55,19 @@ set(CUTLASS_UNIT_TEST_SOURCES
gemm/igemm_128x32x32.cu
gemm/igemm_128x128x32_float.cu
gemm/igemm_128x128x32_int8.cu
gemm/igemm_32x32x128.cu
gemm/sgemm_128x128x8.cu
gemm/sgemm_128x128x16.cu
gemm/sgemm_128x64x8.cu
gemm/sgemm_128x64x16.cu
gemm/sgemm_128x32x8.cu
gemm/sgemm_128x32x16.cu
gemm/sgemm_64x128x8.cu
gemm/sgemm_64x128x16.cu
gemm/sgemm_64x64x8.cu
gemm/sgemm_64x64x16.cu
gemm/sgemm_64x32x8.cu
gemm/sgemm_64x32x16.cu
gemm/wmma_gemm.cu
)

View File

@ -26,9 +26,26 @@
\brief CUTLASS Unit Tests
*/
#include <cuda_runtime_api.h>
#include <gtest/gtest.h>
/// Sets the default GoogleTest filter based on the compute capability of device 0, so that
/// tests requiring hardware features the device lacks are skipped. The default can still be
/// overwritten by --gtest_filter on the command line, since InitGoogleTest parses the command
/// line after this function runs.
void set_gtest_flag() {
  cudaDeviceProp deviceProperties;
  cudaError_t error = cudaGetDeviceProperties(&deviceProperties, 0);
  if (error != cudaSuccess) {
    // Fix: the original ignored this error and read an uninitialized struct. If we cannot
    // query the device, conservatively disable every architecture-dependent test.
    ::testing::GTEST_FLAG(filter) = "-*Igemm*:*Hgemm*:*mma*";
    return;
  }
  int const deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
  if (deviceMajorMinor < 53) {
    // Pre-SM53: no native fp16 arithmetic — skip HGEMM, IGEMM and WMMA tests.
    ::testing::GTEST_FLAG(filter) = "-*Igemm*:*Hgemm*:*mma*";
  } else if (deviceMajorMinor < 61) {
    // Pre-SM61: no int8 dot-product support — skip IGEMM and WMMA tests.
    ::testing::GTEST_FLAG(filter) = "-*Igemm*:*mma*";
  } else if (deviceMajorMinor < 70) {
    // Pre-SM70: no tensor cores — skip WMMA tests.
    ::testing::GTEST_FLAG(filter) = "-*mma*";
  }
}
/// Test-harness entry point: install the device-dependent default filter, then hand the
/// command line to GoogleTest and run every registered test.
int main(int argc, char* argv[]) {
  set_gtest_flag();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

View File

@ -104,6 +104,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_nt) {
////////////////////////////////////////////////////////////////////////////////////////////////////
//Sliced-K configuration
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K DGEMM, threadblock tile Shape<16, 32, 64>; problem exactly one tile (64x32, K=16),
// A column-major x B row-major ("nt").
TEST(Dgemm_64x32x16, dgemm_64x32x16_nt) {
  typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<16, 32, 64> > GemmTraits;
  run_gemm<GemmTraits>(64, 32, 16);
}
// Same sliced-K configuration on a multi-tile problem (256x128, K=64), "nt" layouts.
TEST(Dgemm_64x32x16, dgemm_256x128x64_nt) {
  typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<16, 32, 64> > GemmTraits;
  run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Dgemm_64x64x16, dgemm_64x64x16_nt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 64, 16);
}
TEST(Dgemm_64x64x16, dgemm_256x128x64_nt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Single-tile sliced-K DGEMM, NT layouts, 128x32 tile.
// Fix: test was named dgemm_128x32x8_nt but runs K=16; renamed to
// dgemm_128x32x16_nt to match the actual problem size and the naming of the
// NN counterpart (dgemm_128x32x16_nn).
TEST(Dgemm_128x32x16, dgemm_128x32x16_nt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(128, 32, 16);
}
// Multi-tile problem for the 128x32 sliced-K configuration, NT layouts.
TEST(Dgemm_128x32x16, dgemm_256x64x64_nt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(256, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// DGEMM Column-Column
//
@ -182,6 +240,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_nn) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K configuration
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K DGEMM tests, NN layouts (A and B column-major).
// 64x32 tile: single-tile problem.
TEST(Dgemm_64x32x16, dgemm_64x32x16_nn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 32, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_64x32x16, dgemm_256x128x64_nn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x64 tile: single-tile problem.
TEST(Dgemm_64x64x16, dgemm_64x64x16_nn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 64, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_64x64x16, dgemm_256x128x64_nn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// 128x32 tile: single-tile problem.
TEST(Dgemm_128x32x16, dgemm_128x32x16_nn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(128, 32, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_128x32x16, dgemm_256x64x64_nn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(256, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// DGEMM Row-Column
//
@ -260,6 +376,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_tn) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K configuration
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K DGEMM tests, TN layouts (A row-major, B column-major).
// 64x32 tile: single-tile problem.
TEST(Dgemm_64x32x16, dgemm_64x32x16_tn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 32, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_64x32x16, dgemm_256x128x64_tn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x64 tile: single-tile problem.
TEST(Dgemm_64x64x16, dgemm_64x64x16_tn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 64, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_64x64x16, dgemm_256x128x64_tn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Single-tile sliced-K DGEMM, TN layouts, 128x32 tile.
// Fix: test was named dgemm_128x32x8_tn but runs K=16; renamed to
// dgemm_128x32x16_tn to match the actual problem size and the naming of the
// NN counterpart (dgemm_128x32x16_nn).
TEST(Dgemm_128x32x16, dgemm_128x32x16_tn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(128, 32, 16);
}
// Multi-tile problem for the 128x32 sliced-K configuration, TN layouts.
TEST(Dgemm_128x32x16, dgemm_256x64x64_tn) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(256, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// DGEMM Row-Row
//
@ -338,3 +512,62 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_tt) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K configuration
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sliced-K DGEMM tests, TT layouts (A and B row-major).
// 64x32 tile: single-tile problem.
TEST(Dgemm_64x32x16, dgemm_64x32x16_tt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 32, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 32, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_64x32x16, dgemm_256x128x64_tt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 32, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x64 tile: single-tile problem.
TEST(Dgemm_64x64x16, dgemm_64x64x16_tt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(64, 64, 16);
}
// Same configuration, multi-tile problem.
TEST(Dgemm_64x64x16, dgemm_256x128x64_tt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 64, 64> > GemmTraits;
run_gemm<GemmTraits>(256, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Single-tile sliced-K DGEMM, TT layouts, 128x32 tile.
// Fix: test was named dgemm_128x32x8_tt but runs K=16; renamed to
// dgemm_128x32x16_tt to match the actual problem size and the naming of the
// NN counterpart (dgemm_128x32x16_nn).
TEST(Dgemm_128x32x16, dgemm_128x32x16_tt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(128, 32, 16);
}
// Multi-tile problem for the 128x32 sliced-K configuration, TT layouts.
TEST(Dgemm_128x32x16, dgemm_256x64x64_tt) {
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 32, 128> > GemmTraits;
run_gemm<GemmTraits>(256, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -24,12 +24,18 @@
**************************************************************************************************/
#include <cutlass/cutlass.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmTraits_>
static void run_gemm(
int m,
int n,
int k,
int lda,
int ldb,
int ldc,
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1),
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
@ -51,6 +57,9 @@ static void run_gemm(
testbed(m,
n,
k,
lda,
ldb,
ldc,
cutlass::convert(GemmTraits_::kLayoutA),
cutlass::convert(GemmTraits_::kLayoutB),
alpha,
@ -88,3 +97,22 @@ static void run_gemm(
ASSERT_TRUE(testbed.verify_with_host());
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Packed-layout convenience overload of run_gemm(): derives dense leading
/// dimensions from the operand layouts of GemmTraits_ and forwards to the
/// strided overload. A column-major operand is packed with lda = m / ldb = k;
/// a row-major operand with lda = k / ldb = n. C is column-major, so ldc = m.
template <typename GemmTraits_>
static void run_gemm(
    int m,
    int n,
    int k,
    typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
        typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1),
    typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
        typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0)) {
  // The leading dimension of a densely packed matrix is its contiguous extent.
  int lda;
  if (GemmTraits_::kLayoutA == cutlass::MatrixLayout::kColumnMajor) {
    lda = m;
  } else {
    lda = k;
  }
  int ldb;
  if (GemmTraits_::kLayoutB == cutlass::MatrixLayout::kColumnMajor) {
    ldb = k;
  } else {
    ldb = n;
  }
  run_gemm<GemmTraits_>(m, n, k, lda, ldb, m, alpha, beta);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -44,6 +44,8 @@
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <cutlass::GemmOperand::Kind kOperand_,
cutlass::MatrixLayout::Kind kLayout_,
typename Scalar_,
@ -53,6 +55,8 @@ struct WmmaMatrix;
namespace test {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct GemmTestbedTraits : public cutlass::TypeTraits<T> {};
@ -68,6 +72,8 @@ struct GemmTestbedTraits<cutlass::WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaS
static inline double to_print(double x) { return x; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename AType, typename BType, typename CType, typename Accumulator, typename Scalar>
struct GemmTestbed {
//
@ -219,11 +225,11 @@ struct GemmTestbed {
typedef cutlass::Coord<cutlass::HostTensor<T>::Rank> Coord_t;
size_t matrix_stride = layout == CUBLAS_OP_N ? columns * ldm : rows * ldm;
// TODO: Remove that (int) cast.
Coord_t stride = cutlass::make_Coord(
rows * columns, layout == CUBLAS_OP_N ? 1 : ldm, layout == CUBLAS_OP_N ? ldm : 1, 1);
(int)matrix_stride, layout == CUBLAS_OP_N ? 1 : ldm, layout == CUBLAS_OP_N ? ldm : 1, 1);
Coord_t size = cutlass::make_Coord(1, rows, columns, 1);
tensor.reset(stride, size);
}
@ -231,11 +237,13 @@ struct GemmTestbed {
// Methods
//
/// Constructs a workspace for verifying GEMM, assumes
/// dense packing.
/// Constructs a workspace for verifying GEMM.
GemmTestbed(int M_,
int N_,
int K_,
int lda,
int ldb,
int ldc,
cublasOperation_t layout_a,
cublasOperation_t layout_b,
Scalar alpha_ = Scalar(1),
@ -248,33 +256,6 @@ struct GemmTestbed {
throw cutlass::cuda_exception("Failed to create CUBLAS handle");
}
resize(A, M_, K_, layout_a);
resize(B, K_, N_, layout_b);
resize(C_initial, M_, N_, layout_c);
resize(ref_host, M_, N_, layout_c);
resize(ref_cublas, M_, N_, layout_c);
resize(computed, M_, N_, layout_c);
}
/// Constructs a workspace for verifying GEMM with arbitrary strides
GemmTestbed(int M_,
int N_,
int K_,
int ldc,
cublasOperation_t layout_a,
int lda,
cublasOperation_t layout_b,
int ldb,
Scalar alpha_ = Scalar(1),
Scalar beta_ = Scalar(0),
cublasGemmAlgo_t algorithm_ = CUBLAS_GEMM_DEFAULT,
cublasOperation_t layout_c = CUBLAS_OP_N)
: alpha(alpha_), beta(beta_), algorithm(algorithm_) {
status = cublasCreate(&handle);
if (status != CUBLAS_STATUS_SUCCESS) {
throw cutlass::cuda_exception("Failed to create CUBLAS handle");
}
resize(A, M_, K_, layout_a, lda);
resize(B, K_, N_, layout_b, ldb);
resize(C_initial, M_, N_, layout_c, ldc);
@ -515,6 +496,8 @@ struct GemmTestbed {
} // namespace test
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) {
switch (layout) {
@ -527,4 +510,6 @@ inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) {
}
return CUBLAS_OP_N;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}

View File

@ -0,0 +1,347 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <tools/util/half.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/hgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
// HGEMM NT (A column-major, B row-major) tests with traits Shape<16, 128, 128>.
// Minimal smoke test on a tiny problem.
TEST(Hgemm_128x128x16, hgemm_2x2x2_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(2, 2, 2);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// K smaller than the tile's K extent.
TEST(Hgemm_128x128x16, hgemm_128x128x8_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Exactly one tile.
TEST(Hgemm_128x128x16, hgemm_128x128x16_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Ragged K (17 is not a multiple of 16).
TEST(Hgemm_128x128x16, hgemm_128x128x17_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Several K iterations.
TEST(Hgemm_128x128x16, hgemm_128x128x64_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along M.
TEST(Hgemm_128x128x16, hgemm_256x128x16_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along N.
TEST(Hgemm_128x128x16, hgemm_128x256x16_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along both M and N.
TEST(Hgemm_128x128x16, hgemm_256x256x16_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// HGEMM NN (A and B column-major) tests with traits Shape<16, 128, 128>.
// Exactly one tile.
TEST(Hgemm_128x128x16, hgemm_128x128x16_nn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Ragged K (18 is not a multiple of 16).
TEST(Hgemm_128x128x16, hgemm_128x128x18_nn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 18);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Several K iterations.
TEST(Hgemm_128x128x16, hgemm_128x128x64_nn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along M.
TEST(Hgemm_128x128x16, hgemm_256x128x16_nn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along N.
TEST(Hgemm_128x128x16, hgemm_128x256x16_nn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along both M and N.
TEST(Hgemm_128x128x16, hgemm_256x256x16_nn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// HGEMM TN (A row-major, B column-major) tests with traits Shape<16, 128, 128>.
// Exactly one tile.
TEST(Hgemm_128x128x16, hgemm_128x128x16_tn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Ragged K (18 is not a multiple of 16).
TEST(Hgemm_128x128x16, hgemm_128x128x18_tn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 18);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Several K iterations.
TEST(Hgemm_128x128x16, hgemm_128x128x64_tn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along M.
TEST(Hgemm_128x128x16, hgemm_256x128x16_tn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along N.
TEST(Hgemm_128x128x16, hgemm_128x256x16_tn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along both M and N.
TEST(Hgemm_128x128x16, hgemm_256x256x16_tn) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// HGEMM TT (A and B row-major) tests with traits Shape<16, 128, 128>.
// Exactly one tile.
TEST(Hgemm_128x128x16, hgemm_128x128x16_tt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Ragged K (18 is not a multiple of 16).
TEST(Hgemm_128x128x16, hgemm_128x128x18_tt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 18);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Several K iterations.
TEST(Hgemm_128x128x16, hgemm_128x128x64_tt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along M.
TEST(Hgemm_128x128x16, hgemm_256x128x16_tt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along N.
TEST(Hgemm_128x128x16, hgemm_128x256x16_tt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiple tiles along both M and N.
TEST(Hgemm_128x128x16, hgemm_256x256x16_tt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(256, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Epilogue scaling and ragged-size tests for the Shape<16, 128, 128> HGEMM.
// alpha = 2, beta = 0.
TEST(Hgemm_128x128x16, hgemm_128x128x16_alpha2_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16, cutlass::half_t(2), cutlass::half_t(0));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// alpha = 1, beta = 1 (exercises the C accumulation path).
TEST(Hgemm_128x128x16, hgemm_128x128x16_beta1_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16, cutlass::half_t(1), cutlass::half_t(1));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// alpha = 2, beta = 1.
TEST(Hgemm_128x128x16, hgemm_128x128x16_alpha2_beta1_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(128, 128, 16, cutlass::half_t(2), cutlass::half_t(1));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Wide-load variant with dimensions not multiples of the tile extents.
TEST(Hgemm_128x128x16, hgemm_120x112x64_ldg8_nt) {
// Load 8 halfs per LDG for A/B.
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 128, 128>,
cutlass::gemm::LinearScaling<half>,
cutlass::Shape<8, 8, 16>,
8, 8>
HgemmTraits;
run_gemm<HgemmTraits>(120, 112, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Large ragged problem (no dimension a multiple of the tile).
TEST(Hgemm_128x128x16, hgemm_508x252x120_ragged_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(508, 252, 120);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Small ragged problem.
TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(124, 126, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Small ragged problem combined with epilogue scaling (alpha = 2, beta = 1).
TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_alpha2_beta1_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<16, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1));
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -345,7 +345,7 @@ TEST(Hgemm_128x128x8, hgemm_128x128x16_alpha2_beta1_nt) {
TEST(Hgemm_128x128x8, hgemm_120x112x64_ldg8_nt) {
// Load 8 halfs per LDG for A/B.
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 128, 128>,
cutlass::gemm::LinearScaling<half>,
cutlass::Shape<8, 8, 16>,
@ -367,7 +367,7 @@ TEST(Hgemm_128x128x8, hgemm_508x252x120_ragged_nt) {
TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(124, 126, 32);
@ -377,7 +377,7 @@ TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_nt) {
TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_alpha2_beta1_nt) {
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 128, 128> >
HgemmTraits;
run_gemm<HgemmTraits>(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1));

View File

@ -25,7 +25,6 @@
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/igemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,238 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/igemm_traits.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
// IGEMM NT (A column-major, B row-major) tests for the 32x32 threadblock
// configuration, Shape<128, 32, 32> with int accumulation and a
// Shape<32, 8, 4> instruction/thread shape. K grows from below one iteration
// up to the full 128 extent.
TEST(Igemm_32x32x128, igemm_32x32x4_nt) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 4);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x8_nt) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x32_nt) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Full K extent of the configuration.
TEST(Igemm_32x32x128, igemm_32x32x128_nt) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 128);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// IGEMM NN (A and B column-major) tests for the same 32x32 configuration,
// sweeping K from 4 to the full 128 extent.
TEST(Igemm_32x32x128, igemm_32x32x4_nn) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 4);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x8_nn) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x32_nn) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Full K extent of the configuration.
TEST(Igemm_32x32x128, igemm_32x32x128_nn) {
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<128, 32, 32>,
int,
cutlass::gemm::LinearScaling<int>,
cutlass::Shape<32, 8, 4> >
IgemmTraits;
run_gemm<IgemmTraits>(32, 32, 128);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x4_tn) {
  // TN layouts (A row-major, B column-major), K = 4.
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 4);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x8_tn) {
  // TN layouts, K = 8.
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x15_tn) {
  // TN layouts with K = 15 and explicit leading dimensions (16, 16, 32).
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 15, 16, 16, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x32_tn) {
  // TN layouts, K = 32.
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x128_tn) {
  // TN layouts, K = 128 (full tile depth).
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 128);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x8_tt) {
  // TT layouts (both row-major), K = 8.
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x32_tt) {
  // TT layouts, K = 32.
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Igemm_32x32x128, igemm_32x32x128_tt) {
  // TT layouts, K = 128 (full tile depth).
  typedef cutlass::gemm::IgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<128, 32, 32>,
      int,
      cutlass::gemm::LinearScaling<int>,
      cutlass::Shape<32, 8, 4> > Traits;
  run_gemm<Traits>(32, 32, 128);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,410 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x16_nt) {
  // NT layouts (A column-major, B row-major), 128x128x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x81x1_nt) {
  // NT layouts, N not a multiple of the tile, K = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 81, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x16_nt) {
  // NT layouts, 128x112x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x17_nt) {
  // NT layouts, odd K = 17 exercises the residual K iteration.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x73x16_nt) {
  // NT layouts, N = 73 (ragged N edge).
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 73, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_97x112x64_nt) {
  // NT layouts, ragged M and deeper K = 64.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(97, 112, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x112x16_nt) {
  // NT layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x240x16_nt) {
  // NT layouts, N spans two tiles (ragged second tile).
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 240, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x240x16_nt) {
  // NT layouts, multi-tile M and N.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 240, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x16_nn) {
  // NN layouts (both column-major), tile-aligned 128x128x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x1_nn) {
  // NN layouts, minimal K = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_79x112x16_nn) {
  // NN layouts, ragged M = 79.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(79, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x81x17_nn) {
  // NN layouts, ragged N and odd K.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 81, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x16_nn) {
  // NN layouts, 128x112x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x73x64_nn) {
  // NN layouts, ragged N with deeper K = 64.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 73, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x112x16_nn) {
  // NN layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x256x16_nn) {
  // NN layouts, N spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x256x16_nn) {
  // NN layouts, 2x2 grid of tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x16_tn) {
  // TN layouts (A row-major, B column-major), tile-aligned problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x1_tn) {
  // TN layouts, minimal K = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_127x112x16_tn) {
  // TN layouts, M one short of the tile.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(127, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_21x112x17_tn) {
  // TN layouts, small ragged M and odd K.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(21, 112, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x73x16_tn) {
  // TN layouts, ragged N = 73.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 73, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x81x64_tn) {
  // TN layouts, ragged N with deeper K = 64.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 81, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x112x16_tn) {
  // TN layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_47x256x16_tn) {
  // TN layouts, small M with multi-tile N.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(47, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_211x256x16_tn) {
  // TN layouts, ragged multi-tile M with multi-tile N.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(211, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x16_tt) {
  // TT layouts (both row-major), tile-aligned problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x1_tt) {
  // TT layouts, minimal K = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_109x112x16_tt) {
  // TT layouts, ragged M = 109.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(109, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x17_tt) {
  // TT layouts, odd K = 17.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x16_tt) {
  // TT layouts, 128x112x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_123x112x64_tt) {
  // TT layouts, ragged M with deeper K = 64.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(123, 112, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x112x16_tt) {
  // TT layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 112, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x256x16_tt) {
  // TT layouts, N spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_256x256x16_tt) {
  // TT layouts, 2x2 grid of tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(256, 256, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_120x112x64_ldg4_nt) {
  // NT layouts with wider loads: the trailing 4, 4 traits arguments request
  // 4 floats per LDG for A and B.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128>,
      cutlass::gemm::LinearScaling<float>,
      cutlass::Shape<8, 8, 8>,
      4,
      4> Traits;
  run_gemm<Traits>(120, 112, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x128x16_alpha2_nt) {
  // NT layouts with alpha = 2, beta = 0.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 128, 16, 2.f, 0.f);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x16_beta1_nt) {
  // NT layouts with alpha = 1, beta = 1 (accumulates into C).
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 16, 1.f, 1.f);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x128x16, sgemm_128x112x16_alpha2_beta1_nt) {
  // NT layouts with alpha = 2, beta = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 128, 128> > Traits;
  run_gemm<Traits>(128, 112, 16, 2.f, 1.f);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -334,7 +334,7 @@ TEST(Sgemm_128x128x8, sgemm_256x256x16_tt) {
TEST(Sgemm_128x128x8, sgemm_120x112x64_ldg4_nt) {
// Load 4 floats per LDG for A/B.
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 128, 128>,
cutlass::gemm::LinearScaling<float>,
cutlass::Shape<8, 8, 8>,

View File

@ -0,0 +1,294 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x1_nt) {
  // NT layouts (A column-major, B row-major), minimal K = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x16_nt) {
  // NT layouts, tile-aligned 128x32x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x17_nt) {
  // NT layouts, odd K = 17.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x32_nt) {
  // NT layouts, K = 32.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x32x16_nt) {
  // NT layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x64x16_nt) {
  // NT layouts, N spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x64x16_nt) {
  // NT layouts, 2x2 grid of tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x1_nn) {
  // NN layouts (both column-major), minimal K = 1.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x16_nn) {
  // NN layouts, tile-aligned 128x32x16 problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x17_nn) {
  // NN layouts, odd K = 17.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x32_nn) {
  // NN layouts, K = 32.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x32x16_nn) {
  // NN layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x64x16_nn) {
  // NN layouts, N spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x64x16_nn) {
  // NN layouts, 2x2 grid of tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x1_tn) {
  // TN layouts (A row-major, B column-major), minimal K = 1.
  // Fix copy-paste defect: this test previously used the 128x128 tile
  // (Shape<16, 128, 128>) and ran run_gemm(128, 128, 1), exactly duplicating
  // sgemm_128x128x1_tn instead of exercising the 128x32 tile that both the
  // test name and the Sgemm_128x32x16 suite target.
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::Shape<16, 32, 128> >
      SgemmTraits;
  run_gemm<SgemmTraits>(128, 32, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x16_tn) {
  // TN layouts (A row-major, B column-major), tile-aligned problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x17_tn) {
  // TN layouts, odd K = 17.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x32_tn) {
  // TN layouts, K = 32.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x32x16_tn) {
  // TN layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x64x16_tn) {
  // TN layouts, N spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x64x16_tn) {
  // TN layouts, 2x2 grid of tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x1_tt) {
  // TT layouts (both row-major), minimal K = 1.
  // Fix copy-paste defect: this test previously used the 128x128 tile
  // (Shape<16, 128, 128>) and ran run_gemm(128, 128, 1), exactly duplicating
  // sgemm_128x128x1_tt instead of exercising the 128x32 tile that both the
  // test name and the Sgemm_128x32x16 suite target.
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<16, 32, 128> >
      SgemmTraits;
  run_gemm<SgemmTraits>(128, 32, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x16_tt) {
  // TT layouts (both row-major), tile-aligned problem.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x17_tt) {
  // TT layouts, odd K = 17.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x32x32_tt) {
  // TT layouts, K = 32.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 32, 32);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x32x16_tt) {
  // TT layouts, M spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_128x64x16_tt) {
  // TT layouts, N spans two tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x32x16, sgemm_256x64x16_tt) {
  // TT layouts, 2x2 grid of tiles.
  typedef cutlass::gemm::SgemmTraits<
      cutlass::MatrixLayout::kRowMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<16, 32, 128> > Traits;
  run_gemm<Traits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,285 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x1_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x17_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x64_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x64x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x128x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x128x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x1_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x8_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x17_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x64_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x64x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x128x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x128x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x1_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x17_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x64_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x64x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x128x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x128x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x1_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 128, 128> > SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x17_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x64x64_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_128x128x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_128x64x16, sgemm_256x128x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 64, 128> >
SgemmTraits;
run_gemm<SgemmTraits>(256, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -333,10 +333,10 @@ TEST(Sgemm_128x64x8, sgemm_256x128x16_tt) {
TEST(Sgemm_128x64x8, sgemm_128x64x64_8x4_accumulators_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 64, 128>,
cutlass::gemm::LinearScaling<float>,
cutlass::Shape<8, 4, 8> >
cutlass::Shape<8, 8, 8> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 64);
}
@ -345,7 +345,7 @@ TEST(Sgemm_128x64x8, sgemm_128x64x64_8x4_accumulators_nt) {
TEST(Sgemm_128x64x8, sgemm_128x64x64_4x8_accumulators_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 64, 128>,
cutlass::gemm::LinearScaling<float>,
cutlass::Shape<8, 8, 4> >

View File

@ -0,0 +1,43 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x128x16, sgemm_64x128x64_4x8_accumulators_nt) {
  // NT layout (A column-major, B row-major) with an explicit threadblock tile
  // Shape<16, 128, 64>, a linear-scaling epilogue, and accumulator shape
  // Shape<8, 8, 4> — presumably 4x8 accumulators per thread (per the test name;
  // confirm against SgemmTraits). Problem size: m=64, n=128, k=64.
  typedef cutlass::Shape<16, 128, 64> Tile;
  typedef cutlass::Shape<8, 8, 4> Accumulators;
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     Tile,
                                     cutlass::gemm::LinearScaling<float>,
                                     Accumulators>
      Traits;
  run_gemm<Traits>(64, 128, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -32,7 +32,7 @@
TEST(Sgemm_64x128x8, sgemm_64x128x64_4x8_accumulators_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
cutlass::Shape<8, 128, 64>,
cutlass::gemm::LinearScaling<float>,
cutlass::Shape<8, 8, 4> >

View File

@ -0,0 +1,277 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x1_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x17_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x64_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x32x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x64x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x64x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x1_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x17_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x64_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x32x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x64x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x64x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x17_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x64_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x32x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x64x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x64x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x64x1_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> > SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x32x17_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 32, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x32x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 32, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_64x64x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x32x16, sgemm_128x64x16_tt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<16, 32, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,294 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass_unit_test.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>
#include <tools/test/unit/gemm/gemm_testbed.h>
#include <tools/test/unit/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x1_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x17_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x64_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x64x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x128x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x128x16_nt) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x1_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x17_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x64_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x64x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x128x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x128x16_nn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x1_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 64, 64> > SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x17_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x64_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x64x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x128x16_tn) {
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
cutlass::Shape<8, 64, 64> >
SgemmTraits;
run_gemm<SgemmTraits>(64, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x128x16_tn) {
  // "tn": A row-major (transposed), B column-major.
  // Problem size 128x128x16 requires a grid of several 64x64 threadblock
  // tiles (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x1_tt) {
  // "tt": both A and B are row-major (both transposed).
  // Threadblock tile Shape<8, 64, 64>; problem size 64x64x1 (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(64, 64, 1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x16_tt) {
  // "tt": both A and B are row-major (both transposed).
  // Threadblock tile Shape<8, 64, 64>; problem size 64x64x16 (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(64, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x17_tt) {
  // "tt": both A and B are row-major (both transposed).
  // K=17 is deliberately not a multiple of the tile depth — exercises the
  // residual-K path. Problem size 64x64x17 (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(64, 64, 17);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x64x64_tt) {
  // "tt": both A and B are row-major (both transposed).
  // Threadblock tile Shape<8, 64, 64>; problem size 64x64x64 (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(64, 64, 64);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x64x16_tt) {
  // "tt": both A and B are row-major (both transposed).
  // Problem size 128x64x16 spans multiple 64-wide threadblock tiles in one
  // dimension (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(128, 64, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_64x128x16_tt) {
  // "tt": both A and B are row-major (both transposed).
  // Problem size 64x128x16 spans multiple 64-wide threadblock tiles in one
  // dimension (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(64, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(Sgemm_64x64x16, sgemm_128x128x16_tt) {
  // "tt": both A and B are row-major (both transposed).
  // Problem size 128x128x16 requires a grid of several 64x64 threadblock
  // tiles (per test name).
  typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
                                     cutlass::MatrixLayout::kRowMajor,
                                     cutlass::Shape<8, 64, 64> >
      GemmTraits;
  run_gemm<GemmTraits>(128, 128, 16);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -63,6 +63,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_nt) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nt) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
@ -76,9 +77,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nt) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nt) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kRowMajor,
@ -91,6 +94,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nt) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -124,6 +128,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_nn) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nn) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
@ -137,9 +142,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nn) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nn) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
cutlass::MatrixLayout::kColumnMajor,
@ -152,6 +159,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nn) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -185,6 +193,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_tt) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tt) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
@ -198,9 +207,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tt) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tt) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kRowMajor,
@ -213,6 +224,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tt) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -246,6 +258,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_tn) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tn) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
@ -259,9 +272,11 @@ TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tn) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100
TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tn) {
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
cutlass::MatrixLayout::kColumnMajor,
@ -274,6 +289,7 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tn) {
WmmaGemmTraits;
run_gemm<WmmaGemmTraits>(256, 256, 128);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -109,6 +109,9 @@ class HostTensor : public HostTensorView<T> {
host_.clear();
host_.resize(_capacity);
for (size_t i = 0; i < _capacity; ++i) {
host_[i] = T((int)0xdeadbeef);
}
device_.reset(_device_memory, _capacity);
Base::reset(TensorRef_t(host_.data(), _stride), _size);