CUTLASS 2.2 (#96)

Adds support for NVIDIA Ampere Architecture features. CUDA 11 Toolkit recommended.
This commit is contained in:
Andrew Kerr
2020-06-08 16:17:35 -07:00
committed by GitHub
parent e33d90b361
commit 86931fef85
584 changed files with 51080 additions and 3373 deletions

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -50,6 +50,7 @@ target_link_libraries(
)
set(CUTLASS_INSTALL_TESTS ON CACHE BOOL "Install test executables")
# Free-form command/environment prefix prepended when invoking unit test
# executables; this is a string-valued cache entry, not a boolean.
set(CUTLASS_TEST_EXECUTION_ENVIRONMENT "" CACHE STRING "Environment in which to invoke unit test executables")
function(cutlass_test_unit_add_executable)
@ -76,7 +77,7 @@ function(cutlass_test_unit_add_executable)
add_custom_target(
${NAME_STEM}
COMMAND
$<TARGET_FILE:${NAME}>
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${NAME}>
DEPENDS
${NAME}
)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,6 +71,7 @@ void FilterArchitecture() {
{ "SM61*", 61, kMaxDevice},
{ "SM70*", 70, 75},
{ "SM75*", 75, kMaxDevice},
{ "SM80*", 80, kMaxDevice},
{ 0, 0, false }
};

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -24,6 +24,8 @@ cutlass_test_unit_add_executable(
cutlass_test_unit_core
array.cu
half.cu
bfloat16.cu
tfloat32.cu
complex.cu
predicate_vector.cu
tensor_ref.cu

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -228,6 +228,14 @@ TEST(Array, Float16x8) {
}
#endif
// Exercises Array<cutlass::bfloat16_t, 8> via the TestArray harness.
TEST(Array, FloatBF16x8) {
TestArray<cutlass::bfloat16_t, 8>().run();
}
// Exercises Array<cutlass::tfloat32_t, 4> via the TestArray harness.
TEST(Array, FloatTF32x4) {
TestArray<cutlass::tfloat32_t, 4>().run();
}
// Exercises Array<float, 4> via the TestArray harness.
TEST(Array, Float32x4) {
TestArray<float, 4>().run();
}

209
test/unit/core/bfloat16.cu Normal file
View File

@ -0,0 +1,209 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Unit tests for the bfloat16_t numeric type: host/device conversion,
    arithmetic, and round-to-nearest-even behavior.
*/
#include "../common/cutlass_unit_test.h"
#include "cutlass/array.h"
#include "cutlass/core_io.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Elementwise device-side cast: writes bfloat16_t(input[i]) for each of the
/// N input elements. One thread per element; extra threads exit early.
__global__ void convert_bf16_f32(cutlass::bfloat16_t *output, float const *input, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= N) {
    return;
  }
  output[idx] = static_cast<cutlass::bfloat16_t>(input[idx]);
}
/// Converts pairs of floats to bfloat16_t through NumericArrayConverter.
/// Each thread handles two consecutive elements; N is the total element count.
__global__ void convert_and_pack_bf16(cutlass::bfloat16_t *output, float const *input, int N) {
  int pair_idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (2 * pair_idx >= N) {
    return;
  }

  cutlass::Array<float, 2> const *src =
      reinterpret_cast<cutlass::Array<float, 2> const *>(input + 2 * pair_idx);
  cutlass::Array<cutlass::bfloat16_t, 2> *dst =
      reinterpret_cast<cutlass::Array<cutlass::bfloat16_t, 2> *>(output + 2 * pair_idx);

  cutlass::NumericArrayConverter<cutlass::bfloat16_t, float, 2> convert;
  *dst = convert(*src);
}
/// Round-trips a ramp of integer-valued floats through bfloat16_t on the
/// device: phase 1 uses the scalar conversion kernel, phase 2 the packed
/// (two-elements-per-thread) kernel. Errors accumulate across both phases.
TEST(bfloat16_t, device_conversion) {
  using T = cutlass::bfloat16_t;
  using S = float;

  int const N = 256;

  cutlass::HostTensor<T, cutlass::layout::RowMajor> destination({N, 1});
  cutlass::HostTensor<S, cutlass::layout::RowMajor> source({N, 1});

  // Ramp centered on zero; these integer values are exactly representable
  // in bfloat16, so conversion must be lossless.
  for (int i = 0; i < N; ++i) {
    source.at({i, 0}) = float(i - 128);
    destination.at({i, 0}) = T(0);
  }

  source.sync_device();
  destination.sync_device();

  // Phase 1: one element per thread.
  convert_bf16_f32<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N);
  ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error.";

  destination.sync_host();

  int errors = 0;

  for (int i = 0; i < N; ++i) {
    T result = destination.at({i, 0});
    S reference = source.at({i, 0});

    if (S(result) != reference) {
      ++errors;
      if (errors < 10) {
        std::cerr << "Basic conversion error - [" << i << "] - got " << result << ", expected " << reference << "\n";
      }
    }

    // Clear so phase 2 cannot pass on stale phase-1 results.
    destination.at({i, 0}) = T(0);
  }

  destination.sync_device();

  // Phase 2: two elements per thread via NumericArrayConverter.
  convert_and_pack_bf16<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N);
  ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error.";

  destination.sync_host();

  for (int i = 0; i < N; ++i) {
    T result = destination.at({i, 0});
    S reference = source.at({i, 0});

    if (S(result) != reference) {
      ++errors;
      if (errors < 10) {
        std::cerr << "Convert and pack error - [" << i << "] - got " << result << ", expected " << reference << "\n";
      }
    }
  }

  EXPECT_EQ(errors, 0);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Host
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Host-side conversion round trips: small integers are exactly representable
/// in bfloat16, so both int -> bf16 and float -> bf16 must be lossless.
TEST(bfloat16_t, host_conversion) {
  for (int value = -128; value < 128; ++value) {
    float as_float = static_cast<float>(value);

    cutlass::bfloat16_t from_int = static_cast<cutlass::bfloat16_t>(value);
    cutlass::bfloat16_t from_float = static_cast<cutlass::bfloat16_t>(as_float);

    EXPECT_TRUE(static_cast<int>(from_int) == value);
    EXPECT_TRUE(static_cast<float>(from_float) == as_float);
  }

  // Exercise the user-defined literal operator.
  EXPECT_TRUE(cutlass::bfloat16_t(7) == 7_bf16);
  EXPECT_TRUE(7 == static_cast<int>(7_bf16));
}
/// Host-side addition: sums of small integers stay exactly representable
/// in bfloat16, so x + y must equal the integer sum exactly.
TEST(bfloat16_t, host_arithmetic) {
  for (int lhs = -100; lhs < 100; ++lhs) {
    for (int rhs = -100; rhs < 100; ++rhs) {
      cutlass::bfloat16_t a = static_cast<cutlass::bfloat16_t>(lhs);
      cutlass::bfloat16_t b = static_cast<cutlass::bfloat16_t>(rhs);

      EXPECT_TRUE(static_cast<int>(a + b) == (lhs + rhs));
    }
  }
}
/// Verifies that float -> bfloat16_t conversion rounds to nearest-even.
/// In the per-case comments: M = LSB of the retained mantissa, R = round bit
/// (first discarded bit), S = OR of the remaining (sticky) bits;
/// "rtz" = result equals truncation, "+inf" = magnitude rounds up.
TEST(bfloat16_t, host_round) {
  struct {
    uint32_t f32_bits;   // input: raw IEEE-754 binary32 bit pattern
    uint16_t expected;   // expected raw bfloat16_t bit pattern
  } tests[] = {
    {0x40040000, 0x4004},  // M=0, R=0, S=0 => rtz
    {0x40048000, 0x4004},  // M=0, R=1, S=0 => rtz (tie rounds to even)
    {0x40040001, 0x4004},  // M=0, R=0, S=1 => rtz
    {0x4004c000, 0x4005},  // M=0, R=1, S=1 => +inf
    {0x4004a000, 0x4005},  // M=0, R=1, S=1 => +inf
    {0x40050000, 0x4005},  // M=1, R=0, S=0 => rtz
    {0x40054000, 0x4005},  // M=1, R=0, S=1 => rtz
    {0x40058000, 0x4006},  // M=1, R=1, S=0 => +inf (tie rounds to even)
    {0x40058001, 0x4006},  // M=1, R=1, S=1 => +inf
    {0x7f800000, 0x7f80},  // +inf
    {0xff800000, 0xff80},  // -inf
    {0x7fffffff, 0x7fff},  // canonical NaN
    {0x7ff00001, 0x7fff},  // NaN -> canonical NaN
    {0xfff00010, 0x7fff},  // NaN -> canonical NaN
    {0, 0}                 // sentinel: terminates the loop below
  };

  bool running = true;
  for (int i = 0; running; ++i) {
    // NOTE(review): reinterpret_cast type-punning violates strict aliasing;
    // std::memcpy would be the well-defined alternative.
    float f32 = reinterpret_cast<float const &>(tests[i].f32_bits);
    cutlass::bfloat16_t bf16 = cutlass::bfloat16_t(f32);

    bool passed = (tests[i].expected == bf16.raw());

    EXPECT_TRUE(passed)
      << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits
      << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << bf16.raw();

    if (!tests[i].f32_bits) {
      running = false;
    }
  }
}
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Device
//
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -411,3 +411,13 @@ TEST(Functional, multiply_add_f16x17) {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Multiply-add functional test for bfloat16_t with a 16-element vector.
TEST(Functional, multiply_add_bf16x16) {
Functional_multiply_add_TxN<cutlass::bfloat16_t, 16>();
}
// N=17 is deliberately not a power of two — presumably exercises a partial /
// tail path in Functional_multiply_add_TxN; confirm against that harness.
TEST(Functional, multiply_add_bf16x17) {
Functional_multiply_add_TxN<cutlass::bfloat16_t, 17>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

197
test/unit/core/tfloat32.cu Normal file
View File

@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Unit tests for the tfloat32_t numeric type: host conversion,
    arithmetic, and rounding behavior under multiple rounding styles.
*/
#include "../common/cutlass_unit_test.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/util/device_memory.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Host
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Host-side conversion round trips: integers in [-1024, 1024) are exactly
/// representable in TF32, so both int -> tf32 and float -> tf32 must be lossless.
TEST(tfloat32_t, host_conversion) {
  for (int value = -1024; value < 1024; ++value) {
    float as_float = static_cast<float>(value);

    cutlass::tfloat32_t from_int = static_cast<cutlass::tfloat32_t>(value);
    cutlass::tfloat32_t from_float = static_cast<cutlass::tfloat32_t>(as_float);

    EXPECT_TRUE(static_cast<int>(from_int) == value);
    EXPECT_TRUE(static_cast<float>(from_float) == as_float);
  }

  // Exercise the user-defined literal operator.
  EXPECT_TRUE(cutlass::tfloat32_t(7) == 7_tf32);
  EXPECT_TRUE(7 == static_cast<int>(7_tf32));
}
/// Host-side addition: sums of small integers stay exactly representable
/// in TF32, so x + y must equal the integer sum exactly.
TEST(tfloat32_t, host_arithmetic) {
  for (int lhs = -100; lhs < 100; ++lhs) {
    for (int rhs = -100; rhs < 100; ++rhs) {
      cutlass::tfloat32_t a = static_cast<cutlass::tfloat32_t>(lhs);
      cutlass::tfloat32_t b = static_cast<cutlass::tfloat32_t>(rhs);

      EXPECT_TRUE(static_cast<int>(a + b) == (lhs + rhs));
    }
  }
}
/// Verifies float -> tfloat32_t conversion with FloatRoundStyle::round_to_nearest.
/// In the per-case comments: M = LSB of the retained mantissa, R = round bit,
/// S = OR of the sticky bits; "rtz" = truncation, "+inf" = magnitude rounds up.
TEST(tfloat32_t, host_round_nearest) {
  struct {
    uint32_t f32_bits;   // input: raw IEEE-754 binary32 bit pattern
    uint32_t expected;   // expected raw tfloat32_t bit pattern
  } tests[] = {
    {0x40000000, 0x40000000},  // M=0, R=0, S=0 => rtz
    {0x40001000, 0x40000000},  // M=0, R=1, S=0 => rtz
    {0x40000001, 0x40000000},  // M=0, R=0, S=1 => rtz
    {0x40001001, 0x40002000},  // M=0, R=1, S=1 => +inf
    {0x40002000, 0x40002000},  // M=1, R=0, S=0 => rtz
    {0x40002001, 0x40002000},  // M=1, R=0, S=1 => rtz
    {0x40003000, 0x40004000},  // M=1, R=1, S=0 => +inf
    {0x40003001, 0x40004000},  // M=1, R=1, S=1 => +inf
    {0x7f800000, 0x7f800000},  // +inf
    {0xff800000, 0xff800000},  // -inf
    {0x7fffffff, 0x7fffffff},  // canonical NaN to canonical NaN
    {0x7f800001, 0x7fffffff},  // NaN to canonical NaN
    {0xff800001, 0x7fffffff},  // NaN to canonical NaN
    {0, 0}                     // sentinel: terminates the loop below
  };

  // The converter is loop-invariant; construct it once rather than per
  // iteration (this also matches the style of host_round_half_ulp).
  cutlass::NumericConverter<
    cutlass::tfloat32_t,
    float,
    cutlass::FloatRoundStyle::round_to_nearest> converter;

  bool running = true;
  for (int i = 0; running; ++i) {
    float f32 = reinterpret_cast<float const &>(tests[i].f32_bits);

    cutlass::tfloat32_t tf32 = converter(f32);

    // note, we must explicitly truncate the low-order bits since they are not defined in TF32.
    if (cutlass::isfinite(tf32)) {
      tf32.storage &= 0xffffe000;
    }

    bool passed = (tests[i].expected == tf32.raw());

    EXPECT_TRUE(passed)
      << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits
      << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw();

    if (!tests[i].f32_bits) {
      running = false;
    }
  }
}
namespace test {
namespace core {

/// Device-side single-element float -> tfloat32_t conversion using
/// FloatRoundStyle::round_half_ulp_truncate. Launch with one thread.
__global__ void convert_tf32_half_ulp(cutlass::tfloat32_t *out, float const *in) {
  cutlass::NumericConverter<
    cutlass::tfloat32_t,
    float,
    cutlass::FloatRoundStyle::round_half_ulp_truncate> converter;

  out[0] = converter(in[0]);
}

} // namespace core
} // namespace test
/// Verifies float -> tfloat32_t conversion with round_half_ulp_truncate
/// (add half an ulp, then truncate). In the per-case comments: M = LSB of the
/// retained mantissa, R = round bit, S = OR of the sticky bits;
/// "rtz" = truncation, "+inf" = magnitude rounds up.
TEST(tfloat32_t, host_round_half_ulp) {
  struct {
    uint32_t f32_bits;   // input: raw IEEE-754 binary32 bit pattern
    uint32_t expected;   // expected raw tfloat32_t bit pattern
  } tests[] = {
    {0x40001fff, 0x40002000},  // all round/sticky bits set => +inf
    {0x40000000, 0x40000000},  // M=0, R=0, S=0 => rtz
    {0x40001000, 0x40002000},  // M=0, R=1, S=0 => +inf - this differs from RNE
    {0x40000001, 0x40000000},  // M=0, R=0, S=1 => rtz
    {0x40001001, 0x40002000},  // M=0, R=1, S=1 => +inf
    {0x40002000, 0x40002000},  // M=1, R=0, S=0 => rtz
    {0x40002001, 0x40002000},  // M=1, R=0, S=1 => rtz
    {0x40003000, 0x40004000},  // M=1, R=1, S=0 => +inf
    {0x40003001, 0x40004000},  // M=1, R=1, S=1 => +inf
    {0x7f800000, 0x7f800000},  // +inf
    {0xff800000, 0xff800000},  // -inf
    {0x7fffffff, 0x7fffffff},  // canonical NaN to canonical NaN
    {0x7f800001, 0x7f800001},  // NaN to NaN (payload preserved, unlike the RNE test)
    {0xff800001, 0xff800001},  // NaN to NaN
    {0, 0}                     // sentinel: terminates the loop below
  };

  // Converter constructed once and reused across all cases.
  cutlass::NumericConverter<
    cutlass::tfloat32_t,
    float,
    cutlass::FloatRoundStyle::round_half_ulp_truncate> convert;

  bool running = true;
  for (int i = 0; running; ++i) {
    // NOTE(review): reinterpret_cast type-punning violates strict aliasing;
    // std::memcpy would be the well-defined alternative.
    float f32 = reinterpret_cast<float const &>(tests[i].f32_bits);
    cutlass::tfloat32_t tf32 = convert(f32);

    // note, for this test, we must explicitly truncate the low-order bits since they are not
    // defined in TF32.
    if (cutlass::isfinite(tf32)) {
      tf32.storage &= 0xffffe000;
    }

    bool passed = (tests[i].expected == tf32.raw());

    EXPECT_TRUE(passed)
      << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits
      << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw();

    if (!tests[i].f32_bits) {
      running = false;
    }
  }
}
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Device
//
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -758,6 +758,65 @@ TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x16) {
EXPECT_TRUE(passed);
}
/// int8 tensor-op epilogue: 64x128 threadblock tile, 64x64 warp tile, k=16.
TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_64x64x16) {

  //
  // Define the warp-level matrix multiply
  //

  using ElementOutput = int8_t;
  using ElementAccumulator = int;
  using ElementCompute = float;

  // 128-bit accesses of int8 output elements.
  int const kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  int const kPartitionsK = 1;

  // Threadblock tile fixed to 64x128 to match the test name (the original
  // read <128, 128, 16>, apparently copied from the preceding 128x128 test).
  using Shape = cutlass::gemm::GemmShape<64, 128, 16>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
  using Element = ElementOutput;

  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 64>;

  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type;

  //
  // Output operator
  //

  using OutputOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    kElementsPerAccess,
    ElementAccumulator,
    ElementCompute
  >;

  //
  // Define the epilogue
  //

  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputOp,
    kElementsPerAccess
  >::Epilogue;

  //
  // Instantiate and run the epilogue testbed
  //

  EpilogueTestbed<Epilogue> testbed;

  bool passed = testbed.run_all();

  EXPECT_TRUE(passed);
}
TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x64_64x32x16) {
//
@ -2516,6 +2575,249 @@ TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x64_64x32x8) {
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// f64 tensor-op epilogue: 64x64 threadblock tile, 32x32 warp tile,
/// 8x8x4 instruction shape (SM80 double-precision MMA).
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x64_32x32x4) {

  //
  // Warp-level matrix multiply configuration
  //

  using ElementOutput = double;
  using ElementAccumulator = double;
  using ElementCompute = double;

  // Scalar (one element per access) epilogue for f64.
  int const kElementsPerAccess = 1;
  int const kPartitionsK = 1;

  using Shape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  using Element = double;
  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
  using LayoutC = cutlass::layout::RowMajor;

  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB,
      ElementAccumulator, LayoutC>::Type;

  //
  // Output operator
  //

  using OutputOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, kElementsPerAccess, ElementAccumulator, ElementCompute>;

  //
  // Define and run the epilogue
  //

  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
      Shape, WarpMmaTensorOp, kPartitionsK, OutputOp, kElementsPerAccess>::Epilogue;

  EpilogueTestbed<Epilogue> testbed;

  EXPECT_TRUE(testbed.run_all());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// f64 tensor-op epilogue: 128x64 threadblock tile, 64x32 warp tile, 8x8x4 MMA.
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x64_64x32x4) {

  //
  // Define the warp-level matrix multiply
  //

  using ElementOutput = double;
  using ElementAccumulator = double;
  using ElementCompute = double;

  // Scalar (one element per access) epilogue for f64.
  int const kElementsPerAccess = 1;
  int const kPartitionsK = 1;

  // Tile shapes fixed to match the test name (the original body read
  // <64, 64, 16>/<32, 32, 16>, a verbatim copy of the 64x64_32x32x4 test).
  using Shape = cutlass::gemm::GemmShape<128, 64, 16>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
  using Element = double;
  using ElementC = ElementAccumulator;
  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
  using LayoutC = cutlass::layout::RowMajor;

  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
      LayoutC>::Type;

  //
  // Output operator
  //

  using OutputOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    kElementsPerAccess,
    ElementAccumulator,
    ElementCompute
  >;

  //
  // Define the epilogue
  //

  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputOp,
    kElementsPerAccess
  >::Epilogue;

  //
  // Instantiate and run the epilogue testbed
  //

  EpilogueTestbed<Epilogue> testbed;

  bool passed = testbed.run_all();

  EXPECT_TRUE(passed);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// f64 tensor-op epilogue: 64x128 threadblock tile, 32x64 warp tile, 8x8x4 MMA.
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x128_32x64x4) {

  //
  // Define the warp-level matrix multiply
  //

  using ElementOutput = double;
  using ElementAccumulator = double;
  using ElementCompute = double;

  // Scalar (one element per access) epilogue for f64.
  int const kElementsPerAccess = 1;
  int const kPartitionsK = 1;

  // Tile shapes fixed to match the test name (the original body read
  // <64, 64, 16>/<32, 32, 16>, a verbatim copy of the 64x64_32x32x4 test).
  using Shape = cutlass::gemm::GemmShape<64, 128, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
  using Element = double;
  using ElementC = ElementAccumulator;
  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
  using LayoutC = cutlass::layout::RowMajor;

  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
      LayoutC>::Type;

  //
  // Output operator
  //

  using OutputOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    kElementsPerAccess,
    ElementAccumulator,
    ElementCompute
  >;

  //
  // Define the epilogue
  //

  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputOp,
    kElementsPerAccess
  >::Epilogue;

  //
  // Instantiate and run the epilogue testbed
  //

  EpilogueTestbed<Epilogue> testbed;

  bool passed = testbed.run_all();

  EXPECT_TRUE(passed);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// f64 tensor-op epilogue: 128x128 threadblock tile, 32x64 warp tile,
/// 8x8x4 instruction shape (SM80 double-precision MMA).
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x128_32x64x4) {

  //
  // Warp-level matrix multiply configuration
  //

  using ElementOutput = double;
  using ElementAccumulator = double;
  using ElementCompute = double;

  // Scalar (one element per access) epilogue for f64.
  int const kElementsPerAccess = 1;
  int const kPartitionsK = 1;

  using Shape = cutlass::gemm::GemmShape<128, 128, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  using Element = double;
  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
  using LayoutC = cutlass::layout::RowMajor;

  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB,
      ElementAccumulator, LayoutC>::Type;

  //
  // Output operator
  //

  using OutputOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, kElementsPerAccess, ElementAccumulator, ElementCompute>;

  //
  // Define and run the epilogue
  //

  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
      Shape, WarpMmaTensorOp, kPartitionsK, OutputOp, kElementsPerAccess>::Epilogue;

  EpilogueTestbed<Epilogue> testbed;

  EXPECT_TRUE(testbed.run_all());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x128_64x64x8) {

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -26,6 +26,64 @@ cutlass_test_unit_add_executable(
BATCH_SOURCES ON
BATCH_SIZE 4
gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu
gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu
gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu
gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm80.cu
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu
gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu
gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu
gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu
gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu
gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu
gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu
gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu
gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu
gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu
gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu
gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu
gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu
gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu
gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu
gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu
gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu
gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu
gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu
gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu
gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu
gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu
simt_sgemm_nt_sm80.cu
simt_sgemm_tn_sm80.cu
gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu
gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu
gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu
gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu
gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu
gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu
gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu
gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu
gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu
gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu
gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu
gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu
gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu
gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu
gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu
gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu
@ -149,4 +207,5 @@ cutlass_test_unit_add_executable(
gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu
gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu
)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -62,7 +62,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -84,7 +84,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -106,7 +106,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -128,7 +128,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -150,7 +150,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -172,7 +172,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());

View File

@ -0,0 +1,373 @@
/**************************************************************************************************
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted
provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of
conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
to endorse or promote products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// SM80 device-level GEMM tests: 1-bit (cutlass::uint1b_t) operands, row-major
// A times column-major B producing a column-major int32 result, accumulated
// with the XOR-popc tensor op (cutlass::arch::OpXorPopc). Each test pins one
// tiling configuration and runs the shared TestAllGemm harness. The trailing
// template arguments after the identity threadblock swizzle (3 or 4, then
// 128, 128, false) mirror the sibling b1 tests in this suite -- presumably
// stage count, A/B alignments, and the split-K-serial flag; confirm against
// cutlass::gemm::device::Gemm's parameter list.

// Threadblock 128x256x1024, warp 64x64x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>,
cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 256x128x1024, warp 64x64x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 128x128x1024, warp 64x64x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>,
cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 256x64x1024, warp 64x64x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 64x256x1024, warp 64x64x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 64x128x1024, warp 32x64x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 1024>,
cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 128x64x1024, warp 64x32x1024, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 1024>,
cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 64x64x1024, warp 32x32x1024, instruction 16x8x256; note the
// post-swizzle argument is 4 here rather than 3.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 1024>,
cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 XOR-popc 1-bit GEMM tests with a 512-deep threadblock K dimension
// (uint1b_t row-major A x column-major B -> column-major int32 C). The
// integer following the threadblock swizzle varies per tile size (3, 4, or
// 6) -- presumably the pipeline stage count; confirm against
// cutlass::gemm::device::Gemm's parameter list.

// Threadblock 128x256x512, warp 64x64x512, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 256x128x512, warp 64x64x512, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 128x128x512, warp 64x64x512, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 256x64x512, warp 64x64x512, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 64x256x512, warp 64x64x512, instruction 16x8x256.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 64x128x512, warp 32x64x512, instruction 16x8x256; post-swizzle
// argument is 4.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 512>,
cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 128x64x512, warp 64x32x512, instruction 16x8x256; post-swizzle
// argument is 4.
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 512>,
cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

// Threadblock 64x64x512, warp 32x32x512, instruction 16x8x256; post-swizzle
// argument is 6 (smallest tile, deepest value in this suite).
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 512>,
cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc
>;
@ -104,7 +104,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -135,7 +135,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -166,7 +166,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -197,7 +197,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -228,7 +228,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -62,7 +62,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -84,7 +84,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -106,7 +106,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -128,7 +128,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -150,7 +150,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
@ -172,7 +172,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());

View File

@ -0,0 +1,374 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 128x256x1024, warp tile 64x64x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x1024_64x64x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>,
cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 256x128x1024, warp tile 64x64x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x1024_64x64x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 128x128x1024, warp tile 64x64x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x1024_64x64x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>,
cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 256x64x1024, warp tile 64x64x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x1024_64x64x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 64x256x1024, warp tile 64x64x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x1024_64x64x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 1024>,
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 64x128x1024, warp tile 32x64x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x1024_32x64x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 1024>,
cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 128x64x1024, warp tile 64x32x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x1024_64x32x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 1024>,
cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 64x64x1024, warp tile 32x32x1024, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x1024_32x32x1024, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 1024>,
cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 128x256x512, warp tile 64x64x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 256x128x512, warp tile 64x64x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 128x128x512, warp tile 64x64x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 256x64x512, warp tile 64x64x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 64x256x512, warp tile 64x64x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 512>,
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 64x128x512, warp tile 32x64x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 512>,
cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 128x64x512, warp tile 64x32x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 512>,
cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// 1-bit XOR-popc GEMM (SM80 Tensor Op): A = b1 RowMajor, B = b1 ColumnMajor, D = s32 RowMajor;
// threadblock tile 64x64x512, warp tile 32x32x512, MMA instruction 16x8x256, OpXorPopc.
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512, {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 512>,
cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128,
false, cutlass::arch::OpXorPopc>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc
>;
@ -104,7 +104,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -135,7 +135,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -166,7 +166,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -197,7 +197,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;
@ -228,7 +228,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2, 128, 128, false,
cutlass::arch::OpXorPopc>;

View File

@ -0,0 +1,353 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 128x256x64, warp tile 64x64x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x64_64x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 256x128x64, warp tile 64x64x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x64_64x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 128x128x64, warp tile 64x64x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 256x64x64, warp tile 64x64x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 64x256x64, warp tile 64x64x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x64_64x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 64x128x64, warp tile 32x64x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x64_32x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 128x64x64, warp tile 64x32x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x64_64x32x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 64x64x64, warp tile 32x32x64, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x64_32x32x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 128x256x32, warp tile 64x64x32, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 256x128x32, warp tile 64x64x32, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 128x128x32, warp tile 64x64x32, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 256x64x32, warp tile 64x64x32, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x32_64x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// bfloat16 GEMM (SM80 Tensor Op): A, B = bf16 ColumnMajor; D = f32 RowMajor, f32 accumulation;
// threadblock tile 64x256x32, warp tile 64x64x32, MMA instruction 16x8x16.
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x32_64x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
  // bf16 (col-major) x bf16 (col-major) -> f32 (row-major) tensor-op GEMM:
  // 64x128x32 threadblock tile, 32x64x32 warp tile, 16x8x16 MMA, 6 stages.
  using ElementC = float;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,
      cutlass::gemm::GemmShape<32, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
  // bf16 (col-major) x bf16 (col-major) -> f32 (row-major) tensor-op GEMM:
  // 128x64x32 threadblock tile, 64x32x32 warp tile, 16x8x16 MMA, 6 stages.
  using ElementC = float;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,
      cutlass::gemm::GemmShape<64, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
  // bf16 (col-major) x bf16 (col-major) -> f32 (row-major) tensor-op GEMM:
  // 64x64x32 threadblock tile, 32x32x32 warp tile, 16x8x16 MMA, 10 stages.
  using ElementC = float;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -0,0 +1,337 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 128x256x64 threadblock tile, 64x64x64 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 256x128x64 threadblock tile, 64x64x64 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 128x128x64 threadblock tile, 64x64x64 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 256x64x64 threadblock tile, 64x64x64 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 64x256x64 threadblock tile, 64x64x64 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 64x128x64 threadblock tile, 32x64x64 warp tile, 16x8x16 MMA, 4 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 64>,
      cutlass::gemm::GemmShape<32, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 128x64x64 threadblock tile, 64x32x64 warp tile, 16x8x16 MMA, 4 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 64>,
      cutlass::gemm::GemmShape<64, 32, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 64x64x64 threadblock tile, 32x32x64 warp tile, 16x8x16 MMA, 6 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<32, 32, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 128x256x32 threadblock tile, 64x64x32 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 256x128x32 threadblock tile, 64x64x32 warp tile, 16x8x16 MMA, 3 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 128x128x32 threadblock tile, 64x64x32 warp tile, 16x8x16 MMA, 4 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 256x64x32 threadblock tile, 64x64x32 warp tile, 16x8x16 MMA, 4 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 64x256x32 threadblock tile, 64x64x32 warp tile, 16x8x16 MMA, 4 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 64x128x32 threadblock tile, 32x64x32 warp tile, 16x8x16 MMA, 6 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,
      cutlass::gemm::GemmShape<32, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 128x64x32 threadblock tile, 64x32x32 warp tile, 16x8x16 MMA, 6 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,
      cutlass::gemm::GemmShape<64, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) {
  // bf16 (row-major) x bf16 (row-major) -> bf16 (row-major), f32 accumulation:
  // 64x64x32 threadblock tile, 32x32x32 warp tile, 16x8x16 MMA, 10 stages.
  using ElementC = cutlass::bfloat16_t;
  using ElementAcc = float;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      cutlass::bfloat16_t, cutlass::layout::RowMajor,
      ElementC, cutlass::layout::RowMajor,
      ElementAcc, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementAcc, ElementAcc>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -0,0 +1,253 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// Operands data type: complex<float>
// Rounding: float -> tfloat32_t (half_ulp_truncate)
// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part)
// Math instruction: MMA.1688.F32.TF32
// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part)
// Output data type: complex<float>
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) {
  // complex<float> GEMM on TF32 tensor ops: A col-major, B row-major,
  // C row-major; 32x32x16 threadblock tile, 16x16x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,
      cutlass::gemm::GemmShape<16, 16, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) {
  // complex<float> GEMM on TF32 tensor ops: A col-major, B row-major,
  // C row-major; 64x64x16 threadblock tile, 16x32x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,
      cutlass::gemm::GemmShape<16, 32, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) {
  // complex<float> GEMM on TF32 tensor ops: A col-major, B row-major,
  // C row-major; 64x64x16 threadblock tile, 32x32x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,
      cutlass::gemm::GemmShape<32, 32, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) {
  // complex<float> GEMM on TF32 tensor ops: A col-major, B row-major,
  // C row-major; 128x64x16 threadblock tile, 64x32x16 warp tile,
  // 16x8x8 instruction shape, 4 pipeline stages.
  using Element = cutlass::complex<float>;  // fixed stray double semicolon

  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::ColumnMajor,
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 64, 16>,
    cutlass::gemm::GemmShape<64, 32, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) {
  // complex<float> GEMM on TF32 tensor ops: A col-major, B row-major,
  // C row-major; 64x128x16 threadblock tile, 32x64x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using Element = cutlass::complex<float>;  // fixed stray double semicolon

  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::ColumnMajor,
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 128, 16>,
    cutlass::gemm::GemmShape<32, 64, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) {
  // complex<float> GEMM on TF32 tensor ops: A col-major, B row-major,
  // C row-major; 128x128x16 threadblock tile, 32x64x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 16>,
      cutlass::gemm::GemmShape<32, 64, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,252 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// Operands data type: complex<float>
// Rounding: float -> tfloat32_t (round to nearest)
// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part)
// Math instruction: MMA.1688.F32.TF32
// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part)
// Output data type: complex<float>
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) {
  // complex<float> GEMM on TF32 tensor ops: A row-major, B col-major,
  // C row-major; 32x32x16 threadblock tile, 16x16x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,
      cutlass::gemm::GemmShape<16, 16, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) {
  // complex<float> GEMM on TF32 tensor ops: A row-major, B col-major,
  // C row-major; 64x64x16 threadblock tile, 16x32x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,
      cutlass::gemm::GemmShape<16, 32, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) {
  // complex<float> GEMM on TF32 tensor ops: A row-major, B col-major,
  // C row-major; 64x64x16 threadblock tile, 32x32x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using ElementT = cutlass::complex<float>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      ElementT, cutlass::layout::RowMajor,
      ElementT, cutlass::layout::ColumnMajor,
      ElementT, cutlass::layout::RowMajor,
      ElementT,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,
      cutlass::gemm::GemmShape<32, 32, 16>,
      cutlass::gemm::GemmShape<16, 8, 8>,
      cutlass::epilogue::thread::LinearCombination<
          ElementT, 1, ElementT, ElementT>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) {
  // complex<float> GEMM on TF32 tensor ops: A row-major, B col-major,
  // C row-major; 128x64x16 threadblock tile, 64x32x16 warp tile,
  // 16x8x8 instruction shape, 4 pipeline stages.
  using Element = cutlass::complex<float>;  // fixed stray double semicolon

  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::layout::ColumnMajor,
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 64, 16>,
    cutlass::gemm::GemmShape<64, 32, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) {
  // complex<float> GEMM on TF32 tensor ops: A row-major, B col-major,
  // C row-major; 64x128x16 threadblock tile, 32x64x16 warp tile,
  // 16x8x8 instruction shape, 3 pipeline stages.
  using Element = cutlass::complex<float>;  // fixed stray double semicolon

  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::layout::ColumnMajor,
    Element,
    cutlass::layout::RowMajor,
    Element,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 128, 16>,
    cutlass::gemm::GemmShape<32, 64, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<float> GEMM: A row-major, B column-major, C/D row-major,
// 128x128x16 threadblock / 32x64x16 warp / 16x8x8 instruction tiles, 3 pipeline stages.
// Fix: removed a stray duplicate semicolon after the Element alias.
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) {
  using Element = cutlass::complex<float>;
  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::RowMajor,      // A
    Element,
    cutlass::layout::ColumnMajor,   // B
    Element,
    cutlass::layout::RowMajor,      // C/D
    Element,                        // accumulator element
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 16>, // threadblock tile
    cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
    cutlass::gemm::GemmShape<16, 8, 8>,     // instruction tile
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3                               // pipeline stages
  >;
  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,192 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A column-major, B row-major, C/D row-major, 32x32x16 / 16x16x16 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A column-major, B row-major, C/D row-major, 32x32x8 / 16x16x8 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A column-major, B row-major, C/D row-major, 64x64x16 / 16x32x16 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A column-major, B row-major, C/D row-major, 64x64x8 / 16x32x8 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,246 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A column-major, B row-major, C/D row-major,
// 32x32x16 threadblock / 16x16x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A column-major, B row-major, C/D row-major,
// 32x32x8 threadblock / 16x16x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A column-major, B row-major, C/D row-major,
// 64x64x16 threadblock / 16x32x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A column-major, B row-major, C/D row-major,
// 64x64x8 threadblock / 16x32x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A column-major, B row-major, C/D row-major,
// 64x64x16 threadblock / 32x32x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A column-major, B row-major, C/D row-major,
// 64x64x8 threadblock / 32x32x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,191 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A row-major, B column-major, C/D row-major, 32x32x8 / 16x16x8 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A row-major, B column-major, C/D row-major, 64x64x8 / 32x16x8 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 16, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A row-major, B column-major, C/D row-major, 32x32x16 / 16x16x16 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM using the Gaussian complex multiply-add
// operator: A row-major, B column-major, C/D row-major, 64x64x16 / 32x16x16 /
// 8x8x4 tiles, 3 pipeline stages, no conjugation of A or B.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 16, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // A: no conjugation
      cutlass::ComplexTransform::kNone,       // B: no conjugation
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,299 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 32x32x8 threadblock / 16x16x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 64x64x8 threadblock / 32x32x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 64x128x8 threadblock / 32x32x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 128x64x8 threadblock / 32x32x8 warp / 8x8x4 instruction tiles, 3 pipeline stages.
// Fix: the threadblock shape was GemmShape<64, 128, 8>, which contradicts the
// test name (128x64x8) and duplicated the 64x128x8 case above; corrected to
// GemmShape<128, 64, 8> so the intended configuration is actually exercised.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) {
  using Element = cutlass::complex<double>;
  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::RowMajor,      // A
    Element,
    cutlass::layout::ColumnMajor,   // B
    Element,
    cutlass::layout::RowMajor,      // C/D
    Element,                        // accumulator element
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 64, 8>,   // threadblock tile (matches test name)
    cutlass::gemm::GemmShape<32, 32, 8>,    // warp tile
    cutlass::gemm::GemmShape<8, 8, 4>,      // instruction tile
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3                               // pipeline stages
  >;
  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 32x32x16 threadblock / 16x16x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 16, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 64x64x16 threadblock / 32x32x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 64x128x16 threadblock / 32x32x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) {

  using Complex = cutlass::complex<double>;

  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Scalar (vector width 1) linear-combination epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator element
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
// SM80 tensor-op complex<double> GEMM: A row-major, B column-major, C/D row-major,
// 128x64x16 threadblock / 32x32x16 warp / 8x8x4 instruction tiles, 3 pipeline stages.
// Fix: the threadblock shape was GemmShape<64, 128, 16>, which contradicts the
// test name (128x64x16) and duplicated the 64x128x16 case above; corrected to
// GemmShape<128, 64, 16> so the intended configuration is actually exercised.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) {
  using Element = cutlass::complex<double>;
  using Gemm = cutlass::gemm::device::GemmComplex<
    Element,
    cutlass::layout::RowMajor,      // A
    Element,
    cutlass::layout::ColumnMajor,   // B
    Element,
    cutlass::layout::RowMajor,      // C/D
    Element,                        // accumulator element
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 64, 16>,  // threadblock tile (matches test name)
    cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
    cutlass::gemm::GemmShape<8, 8, 4>,      // instruction tile
    cutlass::epilogue::thread::LinearCombination<
      Element,
      1,
      Element,
      Element
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3                               // pipeline stages
  >;
  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,338 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// 128x256x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
// f16 column-major A and B, f16 row-major output, f32 accumulation (per test name).
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 256x128x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 128x128x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 256x64x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 64x256x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 64x128x64 threadblock / 32x64x64 warp / 16x8x16 MMA instruction, 4-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 128x64x64 threadblock / 64x32x64 warp / 16x8x16 MMA instruction, 4-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 64x64x64 threadblock / 32x32x64 warp / 16x8x16 MMA instruction, 6-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 64>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 128x256x32 threadblock / 64x64x32 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 256x128x32 threadblock / 64x64x32 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 128x128x32 threadblock / 64x64x32 warp / 16x8x16 MMA instruction, 4-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 256x64x32 threadblock / 64x64x32 warp / 16x8x16 MMA instruction, 4-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x32_64x64x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 64x256x32 threadblock / 64x64x32 warp / 16x8x16 MMA instruction, 4-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x32_64x64x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 64x128x32 threadblock / 32x64x32 warp / 16x8x16 MMA instruction, 6-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 128x64x32 threadblock / 64x32x32 warp / 16x8x16 MMA instruction, 6-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 64x64x32 threadblock / 32x32x32 warp / 16x8x16 MMA instruction, 10-stage pipeline
// (deepest pipeline in this file — small tile leaves shared memory for extra stages).
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) {
  using ElementC = cutlass::half_t;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::RowMajor,            // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      10>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -320,7 +320,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -354,7 +354,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -388,7 +388,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,337 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// 128x256x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
// f16 column-major A and B, f32 column-major output, f32 accumulation (per test name).
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x64_64x64x64) {
  using ElementC = float;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::ColumnMajor,         // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 256x128x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x64_64x64x64) {
  using ElementC = float;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::ColumnMajor,         // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 128x128x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x64_64x64x64) {
  using ElementC = float;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::ColumnMajor,         // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// 256x64x64 threadblock / 64x64x64 warp / 16x8x16 MMA instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x64_64x64x64) {
  using ElementC = float;
  using ElementCompute = float;

  using DeviceGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,  // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      ElementC, cutlass::layout::ColumnMajor,         // C/D
      ElementCompute,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
          ElementCompute, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<DeviceGemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 64x256x64 threadblock tile, 64x64x64 warp tile, 3 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x64_64x64x64) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 64x128x64 threadblock tile, 32x64x64 warp tile, 4 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x64_32x64x64) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 64>,
      cutlass::gemm::GemmShape<32, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 128x64x64 threadblock tile, 64x32x64 warp tile, 4 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x64_64x32x64) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 64>,
      cutlass::gemm::GemmShape<64, 32, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 64x64x64 threadblock tile, 32x32x64 warp tile, 6 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x64_32x32x64) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<32, 32, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 128x256x32 threadblock tile, 64x64x32 warp tile, 3 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 256x128x32 threadblock tile, 64x64x32 warp tile, 3 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 128x128x32 threadblock tile, 64x64x32 warp tile, 4 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 256x64x32 threadblock tile, 64x64x32 warp tile, 4 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x32_64x64x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 64x256x32 threadblock tile, 64x64x32 warp tile, 4 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x32_64x64x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 64x128x32 threadblock tile, 32x64x32 warp tile, 6 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,
      cutlass::gemm::GemmShape<32, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 128x64x32 threadblock tile, 64x32x32 warp tile, 6 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,
      cutlass::gemm::GemmShape<64, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 TensorOp GEMM: column-major f16 A/B, column-major f32 output, f32
// accumulation; 64x64x32 threadblock tile, 32x32x32 warp tile, 10 stages.
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::ColumnMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -108,7 +108,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -142,7 +142,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,340 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 128x256x64 threadblock tile, 64x64x64 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 256x128x64 threadblock tile, 64x64x64 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 128x128x64 threadblock tile, 64x64x64 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 256x64x64 threadblock tile, 64x64x64 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 64x256x64 threadblock tile, 64x64x64 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 64x128x64 threadblock tile, 32x64x64 warp tile, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 64>,
      cutlass::gemm::GemmShape<32, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 128x64x64 threadblock tile, 64x32x64 warp tile, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 64>,
      cutlass::gemm::GemmShape<64, 32, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 64x64x64 threadblock tile, 32x32x64 warp tile, 6 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<32, 32, 64>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 128x256x32 threadblock tile, 64x64x32 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 256x128x32 threadblock tile, 64x64x32 warp tile, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 128x128x32 threadblock tile, 64x64x32 warp tile, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 256x64x32 threadblock tile, 64x64x32 warp tile, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 64x256x32 threadblock tile, 64x64x32 warp tile, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 64x128x32 threadblock tile, 32x64x32 warp tile, 6 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,
      cutlass::gemm::GemmShape<32, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 128x64x32 threadblock tile, 64x32x32 warp tile, 6 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,
      cutlass::gemm::GemmShape<64, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 TensorOp GEMM: column-major f16 A/B, row-major f32 output, f32
// accumulation; 64x64x32 threadblock tile, 32x32x32 warp tile, 10 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32, {
  using OutputType = float;
  using AccumType = float;
  // Linear-combination epilogue with 128-bit vectorized output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      OutputType, 128 / cutlass::sizeof_bits<OutputType>::value, AccumType,
      AccumType>;
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      OutputType, cutlass::layout::RowMajor,
      AccumType, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<32, 32, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32)
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,82 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "../../common/cutlass_unit_test.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// SM80 Tensor Core GEMM, half_t for A/B/D and the accumulator:
// column-major A x row-major B -> row-major D.
// Threadblock tile 128x64x64, warp tile 64x64x32, MMA instruction 16x8x16.
// Epilogue vector width here is 64 bits (= 4 half elements), narrower than the
// 128-bit width used by the other SM80 tests in this commit.
// The "_sliced_k" suite name together with warp-K (32) < threadblock-K (64)
// suggests the K dimension is split across warps -- NOTE(review): presumed
// from the shapes; confirm against the kernel-level documentation.
// The trailing template argument (3), in the slot after the threadblock
// swizzle, is presumably the pipeline stage count -- confirm against
// cutlass::gemm::device::Gemm's parameter list.
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::ColumnMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3
>;
// Runs the shared device-level GEMM test harness over this configuration.
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,338 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 128x256x64, warp 64x64x64, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 256x128x64, warp 64x64x64, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 128x128x64, warp 64x64x64, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 256x64x64, warp 64x64x64, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 64x256x64, warp 64x64x64, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
// (Stray space before the comma after GemmShape<64, 256, 64> is in the original.)
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 64> ,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 64x128x64, warp 32x64x64, MMA 16x8x16; trailing 4 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 128x64x64, warp 64x32x64, MMA 16x8x16; trailing 4 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 64x64x64, warp 32x32x64, MMA 16x8x16; trailing 6 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 128x256x32, warp 64x64x32, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 256x128x32, warp 64x64x32, MMA 16x8x16; trailing 3 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 128x128x32, warp 64x64x32, MMA 16x8x16; trailing 4 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 256x64x32, warp 64x64x32, MMA 16x8x16; trailing 4 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x32_64x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 64x256x32, warp 64x64x32, MMA 16x8x16; trailing 4 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x32_64x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 64x128x32, warp 32x64x32, MMA 16x8x16; trailing 6 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 128x64x32, warp 64x32x32, MMA 16x8x16; trailing 6 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm).
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
// SM80 f16 TensorOp GEMM (col-major A x row-major B -> row-major D, half accum):
// threadblock 64x64x32, warp 32x32x32, MMA 16x8x16; trailing 10 is presumably
// the pipeline stage count (slot after the swizzle -- confirm in device::Gemm);
// this is the deepest pipeline in the file.
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -0,0 +1,77 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// NOTE(review): placeholder test -- the entire body is commented out, so this
// test currently runs and passes while asserting nothing.
// Before enabling, two inconsistencies in the commented code need confirming:
//   1. It instantiates cutlass::gemm::device::Gemm, not GemmUniversal, despite
//      the test-suite name and the gemm_universal.h include in this file.
//   2. It spells GemmIdentityThreadblockSwizzle without the <> used by every
//      other SM80 test in this commit.
TEST(SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) {
/*
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 10>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
*/
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -63,7 +63,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -94,7 +94,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -125,7 +125,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -156,7 +156,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -187,7 +187,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -218,7 +218,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -249,7 +249,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -108,7 +108,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -142,7 +142,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,339 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
// SM80 (Ampere) tensor-op GEMM: f16 column-major A x f16 row-major B -> f32 row-major output,
// f32 accumulation. Threadblock tile 128x256x64, warp tile 64x64x64, 16x8x16 MMA instruction,
// stage count 3 (final template argument). Runs the shared TestAllGemm harness from testbed.h.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 64>,  // threadblock tile
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,  // warp tile, instruction shape
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,  // elements per 128-bit epilogue access
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;  // swizzle, stages
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 256x128x64, warp 64x64x64, 16x8x16 MMA, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 128x128x64, warp 64x64x64, 16x8x16 MMA, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 256x64x64, warp 64x64x64, 16x8x16 MMA, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 64x256x64, warp 64x64x64, 16x8x16 MMA, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 64x128x64, warp 32x64x64, 16x8x16 MMA, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 128x64x64, warp 64x32x64, 16x8x16 MMA, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 64x64x64, warp 32x32x64, 16x8x16 MMA, 6 stages.
// Smallest x64 tile in this file; deeper pipeline (6) compensates for the smaller tile.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 128x256x32, warp 64x64x32, 16x8x16 MMA, 3 stages.
// First of the K=32 threadblock-tile variants in this file.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 256x128x32, warp 64x64x32, 16x8x16 MMA, 3 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 128x128x32, warp 64x64x32, 16x8x16 MMA, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 256x64x32, warp 64x64x32, 16x8x16 MMA, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 64x256x32, warp 64x64x32, 16x8x16 MMA, 4 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 64x128x32, warp 32x64x32, 16x8x16 MMA, 6 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 128x64x32, warp 64x32x32, 16x8x16 MMA, 6 stages.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
// SM80 f16n*f16t->f32t tensor-op GEMM: threadblock 64x64x32, warp 32x32x32, 16x8x16 MMA, 10 stages.
// Smallest tile / deepest pipeline configuration in this file.
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32, {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -63,7 +63,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -94,7 +94,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -125,7 +125,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -156,7 +156,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -187,7 +187,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -218,7 +218,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -249,7 +249,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x256x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x64x32_6
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -137,7 +137,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x128x32_6
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -170,7 +170,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x32_32
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -202,7 +202,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x64_32
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -234,7 +234,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x64_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -270,7 +270,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -305,7 +305,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x16_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_32x8x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -140,7 +140,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_8x32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x256x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x64x32_6
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -137,7 +137,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x128x32_6
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -170,7 +170,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x32_32
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -202,7 +202,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x64_32
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -234,7 +234,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x64_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -270,7 +270,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -305,7 +305,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32)
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,83 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "../../common/cutlass_unit_test.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) {
  // Sliced-K FP16 tensor-op GEMM on SM80 (requires SM80): warp tile K (32) is
  // half the threadblock tile K (64), so warps split the K dimension.
  // A is row-major, B is column-major, output is row-major; accumulation in f16.
  using Element = cutlass::half_t;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor, Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor, Element,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      cutlass::epilogue::thread::LinearCombination<
          Element, 64 / cutlass::sizeof_bits<Element>::value,  // 64-bit epilogue accesses
          Element, Element>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  // Exercise the configuration across the testbed's problem sizes.
  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,339 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 64>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) {
  // FP16 Tensor Core GEMM on SM80: row-major A x column-major B -> row-major
  // output, accumulating in half precision. Small tile, deep pipeline.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 64>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 256, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x32_64x64x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x32_64x64x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) {
  // FP16 Tensor Core GEMM on SM80 with a K=32 threadblock tile: row-major A x
  // column-major B -> row-major output, accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                                      // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) {
  // FP16 Tensor Core GEMM on SM80 with the smallest tile in this suite and the
  // deepest pipeline: row-major A x column-major B -> row-major output,
  // accumulating in half precision.
  using Element = cutlass::half_t;

  // Elements moved per vectorized (128-bit) epilogue access.
  static int const kEpilogueAccess = 128 / cutlass::sizeof_bits<Element>::value;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Element, kEpilogueAccess, Element, Element>;

  using Gemm = cutlass::gemm::device::Gemm<
      Element, cutlass::layout::RowMajor,      // A
      Element, cutlass::layout::ColumnMajor,   // B
      Element, cutlass::layout::RowMajor,      // C / D
      Element,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 32>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      10>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -319,7 +319,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -353,7 +353,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -387,7 +387,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -74,7 +74,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x64x32_6
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x128x32_6
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -138,7 +138,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x64x32_32
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -174,7 +174,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;
@ -209,7 +209,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
kStages
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -0,0 +1,338 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 128x256x64 threadblock / 64x64x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 256x128x64 threadblock / 64x64x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 128x128x64 threadblock / 64x64x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 256x64x64 threadblock / 64x64x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 64x256x64 threadblock / 64x64x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 64x128x64 threadblock / 32x64x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 4-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 128x64x64 threadblock / 64x32x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 4-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 64x64x64 threadblock / 32x32x64 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 6-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 128x256x32 threadblock / 64x64x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 256x128x32 threadblock / 64x64x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 3-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 128x128x32 threadblock / 64x64x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 4-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 256x64x32 threadblock / 64x64x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 4-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 64x256x32 threadblock / 64x64x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 4-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 64x128x32 threadblock / 32x64x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 6-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 128x64x32 threadblock / 64x32x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 6-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy: 64x64x32 threadblock / 32x32x32 warp / 16x8x16 mma.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; vector width chosen for 128-bit accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  // Row-major f16 A, column-major f16 B, row-major f32 output; 10-stage mainloop.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) {
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
2
>;

Some files were not shown because too many files have changed in this diff Show More