CUTLASS 2.2 (#96)
Adds support for NVIDIA Ampere Architecture features. CUDA 11 Toolkit recommended.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
@ -50,6 +50,7 @@ target_link_libraries(
|
||||
)
|
||||
|
||||
set(CUTLASS_INSTALL_TESTS ON CACHE BOOL "Install test executables")
|
||||
set(CUTLASS_TEST_EXECUTION_ENVIRONMENT "" CACHE BOOL "Environment in which to invoke unit test executables")
|
||||
|
||||
function(cutlass_test_unit_add_executable)
|
||||
|
||||
@ -76,7 +77,7 @@ function(cutlass_test_unit_add_executable)
|
||||
add_custom_target(
|
||||
${NAME_STEM}
|
||||
COMMAND
|
||||
$<TARGET_FILE:${NAME}>
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${NAME}>
|
||||
DEPENDS
|
||||
${NAME}
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,6 +71,7 @@ void FilterArchitecture() {
|
||||
{ "SM61*", 61, kMaxDevice},
|
||||
{ "SM70*", 70, 75},
|
||||
{ "SM75*", 75, kMaxDevice},
|
||||
{ "SM80*", 80, kMaxDevice},
|
||||
{ 0, 0, false }
|
||||
};
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
@ -24,6 +24,8 @@ cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_core
|
||||
array.cu
|
||||
half.cu
|
||||
bfloat16.cu
|
||||
tfloat32.cu
|
||||
complex.cu
|
||||
predicate_vector.cu
|
||||
tensor_ref.cu
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -228,6 +228,14 @@ TEST(Array, Float16x8) {
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(Array, FloatBF16x8) {
|
||||
TestArray<cutlass::bfloat16_t, 8>().run();
|
||||
}
|
||||
|
||||
TEST(Array, FloatTF32x4) {
|
||||
TestArray<cutlass::tfloat32_t, 4>().run();
|
||||
}
|
||||
|
||||
TEST(Array, Float32x4) {
|
||||
TestArray<float, 4>().run();
|
||||
}
|
||||
|
||||
209
test/unit/core/bfloat16.cu
Normal file
209
test/unit/core/bfloat16.cu
Normal file
@ -0,0 +1,209 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
|
||||
and is safe to use in a union.
|
||||
*/
|
||||
|
||||
#include "../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/core_io.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__global__ void convert_bf16_f32(cutlass::bfloat16_t *output, float const *input, int N) {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (tid < N) {
|
||||
output[tid] = static_cast<cutlass::bfloat16_t>(input[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void convert_and_pack_bf16(cutlass::bfloat16_t *output, float const *input, int N) {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (tid * 2 < N) {
|
||||
|
||||
cutlass::NumericArrayConverter<cutlass::bfloat16_t, float, 2> convert;
|
||||
|
||||
cutlass::Array<cutlass::bfloat16_t, 2> *dst_ptr =
|
||||
reinterpret_cast<cutlass::Array<cutlass::bfloat16_t, 2> *>(output + tid * 2);
|
||||
|
||||
cutlass::Array<float, 2> const *src_ptr =
|
||||
reinterpret_cast<cutlass::Array<float, 2> const *>(input + tid * 2);
|
||||
|
||||
*dst_ptr = convert(*src_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(bfloat16_t, device_conversion) {
|
||||
using T = cutlass::bfloat16_t;
|
||||
using S = float;
|
||||
|
||||
int const N = 256;
|
||||
|
||||
cutlass::HostTensor<T, cutlass::layout::RowMajor> destination({N, 1});
|
||||
cutlass::HostTensor<S, cutlass::layout::RowMajor> source({N, 1});
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
source.at({i, 0}) = float(i - 128);
|
||||
destination.at({i, 0}) = T(0);
|
||||
}
|
||||
|
||||
source.sync_device();
|
||||
destination.sync_device();
|
||||
|
||||
convert_bf16_f32<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N);
|
||||
|
||||
ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error.";
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
int errors = 0;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
T got = destination.at({i, 0});
|
||||
S expected = source.at({i, 0});
|
||||
|
||||
if (S(got) != expected) {
|
||||
++errors;
|
||||
if (errors < 10) {
|
||||
std::cerr << "Basic conversion error - [" << i << "] - got " << got << ", expected " << expected << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
destination.at({i, 0}) = T(0);
|
||||
}
|
||||
|
||||
destination.sync_device();
|
||||
|
||||
convert_and_pack_bf16<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N);
|
||||
|
||||
ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error.";
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
T got = destination.at({i, 0});
|
||||
S expected = source.at({i, 0});
|
||||
|
||||
if (S(got) != expected) {
|
||||
++errors;
|
||||
if (errors < 10) {
|
||||
std::cerr << "Convert and pack error - [" << i << "] - got " << got << ", expected " << expected << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(errors, 0);
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Host
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(bfloat16_t, host_conversion) {
|
||||
for (int i = -128; i < 128; ++i) {
|
||||
float f = static_cast<float>(i);
|
||||
|
||||
cutlass::bfloat16_t x = static_cast<cutlass::bfloat16_t>(i);
|
||||
cutlass::bfloat16_t y = static_cast<cutlass::bfloat16_t>(f);
|
||||
|
||||
EXPECT_TRUE(static_cast<int>(x) == i);
|
||||
EXPECT_TRUE(static_cast<float>(y) == f);
|
||||
}
|
||||
|
||||
// Try out user-defined literals
|
||||
EXPECT_TRUE(cutlass::bfloat16_t(7) == 7_bf16);
|
||||
EXPECT_TRUE(7 == static_cast<int>(7_bf16));
|
||||
}
|
||||
|
||||
TEST(bfloat16_t, host_arithmetic) {
|
||||
|
||||
for (int i = -100; i < 100; ++i) {
|
||||
for (int j = -100; j < 100; ++j) {
|
||||
|
||||
cutlass::bfloat16_t x = static_cast<cutlass::bfloat16_t>(i);
|
||||
cutlass::bfloat16_t y = static_cast<cutlass::bfloat16_t>(j);
|
||||
|
||||
EXPECT_TRUE(static_cast<int>(x + y) == (i + j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(bfloat16_t, host_round) {
|
||||
|
||||
struct {
|
||||
uint32_t f32_bits;
|
||||
uint16_t expected;
|
||||
} tests[] = {
|
||||
{0x40040000, 0x4004}, // M=0, R=0, S=0 => rtz
|
||||
{0x40048000, 0x4004}, // M=0, R=1, S=0 => rtz
|
||||
{0x40040001, 0x4004}, // M=0, R=1, S=1 => +inf
|
||||
{0x4004c000, 0x4005}, // M=0, R=1, S=1 => +inf
|
||||
{0x4004a000, 0x4005}, // M=0, R=1, S=1 => +inf
|
||||
{0x40050000, 0x4005}, // M=1, R=0, S=0 => rtz
|
||||
{0x40054000, 0x4005}, // M=1, R=0, S=1 => rtz
|
||||
{0x40058000, 0x4006}, // M=1, R=1, S=0 => +inf
|
||||
{0x40058001, 0x4006}, // M=1, R=1, S=1 => +inf
|
||||
{0x7f800000, 0x7f80}, // +inf
|
||||
{0xff800000, 0xff80}, // -inf
|
||||
{0x7fffffff, 0x7fff}, // canonical NaN
|
||||
{0x7ff00001, 0x7fff}, // NaN -> canonical NaN
|
||||
{0xfff00010, 0x7fff}, // Nan -> canonical NaN
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
bool running = true;
|
||||
for (int i = 0; running; ++i) {
|
||||
|
||||
float f32 = reinterpret_cast<float const &>(tests[i].f32_bits);
|
||||
|
||||
cutlass::bfloat16_t bf16 = cutlass::bfloat16_t(f32);
|
||||
|
||||
bool passed = (tests[i].expected == bf16.raw());
|
||||
|
||||
EXPECT_TRUE(passed)
|
||||
<< "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits
|
||||
<< ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << bf16.raw();
|
||||
|
||||
if (!tests[i].f32_bits) {
|
||||
running = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Device
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -411,3 +411,13 @@ TEST(Functional, multiply_add_f16x17) {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Functional, multiply_add_bf16x16) {
|
||||
Functional_multiply_add_TxN<cutlass::bfloat16_t, 16>();
|
||||
}
|
||||
|
||||
TEST(Functional, multiply_add_bf16x17) {
|
||||
Functional_multiply_add_TxN<cutlass::bfloat16_t, 17>();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
197
test/unit/core/tfloat32.cu
Normal file
197
test/unit/core/tfloat32.cu
Normal file
@ -0,0 +1,197 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
|
||||
and is safe to use in a union.
|
||||
*/
|
||||
|
||||
#include "../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Host
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(tfloat32_t, host_conversion) {
|
||||
for (int i = -1024; i < 1024; ++i) {
|
||||
float f = static_cast<float>(i);
|
||||
|
||||
cutlass::tfloat32_t x = static_cast<cutlass::tfloat32_t>(i);
|
||||
cutlass::tfloat32_t y = static_cast<cutlass::tfloat32_t>(f);
|
||||
|
||||
EXPECT_TRUE(static_cast<int>(x) == i);
|
||||
EXPECT_TRUE(static_cast<float>(y) == f);
|
||||
}
|
||||
|
||||
// Try out user-defined literals
|
||||
EXPECT_TRUE(cutlass::tfloat32_t(7) == 7_tf32);
|
||||
EXPECT_TRUE(7 == static_cast<int>(7_tf32));
|
||||
}
|
||||
|
||||
TEST(tfloat32_t, host_arithmetic) {
|
||||
|
||||
for (int i = -100; i < 100; ++i) {
|
||||
for (int j = -100; j < 100; ++j) {
|
||||
|
||||
cutlass::tfloat32_t x = static_cast<cutlass::tfloat32_t>(i);
|
||||
cutlass::tfloat32_t y = static_cast<cutlass::tfloat32_t>(j);
|
||||
|
||||
EXPECT_TRUE(static_cast<int>(x + y) == (i + j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(tfloat32_t, host_round_nearest) {
|
||||
|
||||
struct {
|
||||
uint32_t f32_bits;
|
||||
uint32_t expected;
|
||||
} tests[] = {
|
||||
{0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz
|
||||
{0x40001000, 0x40000000}, // M=0, R=1, S=0 => rtz
|
||||
{0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz
|
||||
{0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf
|
||||
{0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz
|
||||
{0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz
|
||||
{0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf
|
||||
{0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf
|
||||
{0x7f800000, 0x7f800000}, // +inf
|
||||
{0xff800000, 0xff800000}, // -inf
|
||||
{0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN
|
||||
{0x7f800001, 0x7fffffff}, // NaN to canonical NaN
|
||||
{0xff800001, 0x7fffffff}, // NaN to canonical NaN
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
bool running = true;
|
||||
for (int i = 0; running; ++i) {
|
||||
|
||||
float f32 = reinterpret_cast<float const &>(tests[i].f32_bits);
|
||||
|
||||
cutlass::NumericConverter<
|
||||
cutlass::tfloat32_t,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest> converter;
|
||||
|
||||
cutlass::tfloat32_t tf32 = converter(f32);
|
||||
|
||||
// note, we must explicitly truncate the low-order bits since they are not defined in TF32.
|
||||
if (cutlass::isfinite(tf32)) {
|
||||
tf32.storage &= 0xffffe000;
|
||||
}
|
||||
|
||||
bool passed = (tests[i].expected == tf32.raw());
|
||||
|
||||
EXPECT_TRUE(passed)
|
||||
<< "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits
|
||||
<< ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw();
|
||||
|
||||
if (!tests[i].f32_bits) {
|
||||
running = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace test {
|
||||
namespace core {
|
||||
|
||||
__global__ void convert_tf32_half_ulp(cutlass::tfloat32_t *out, float const *in) {
|
||||
|
||||
cutlass::NumericConverter<
|
||||
cutlass::tfloat32_t,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_half_ulp_truncate> convert;
|
||||
|
||||
*out = convert(*in);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(tfloat32_t, host_round_half_ulp) {
|
||||
|
||||
struct {
|
||||
uint32_t f32_bits;
|
||||
uint32_t expected;
|
||||
} tests[] = {
|
||||
{0x40001fff, 0x40002000},
|
||||
{0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz
|
||||
{0x40001000, 0x40002000}, // M=0, R=1, S=0 => rtz - this difers from RNE
|
||||
{0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz
|
||||
{0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf
|
||||
{0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz
|
||||
{0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz
|
||||
{0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf
|
||||
{0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf
|
||||
{0x7f800000, 0x7f800000}, // +inf
|
||||
{0xff800000, 0xff800000}, // -inf
|
||||
{0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN
|
||||
{0x7f800001, 0x7f800001}, // NaN to NaN
|
||||
{0xff800001, 0xff800001}, // NaN to NaN
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
cutlass::NumericConverter<
|
||||
cutlass::tfloat32_t,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_half_ulp_truncate> convert;
|
||||
|
||||
bool running = true;
|
||||
for (int i = 0; running; ++i) {
|
||||
|
||||
float f32 = reinterpret_cast<float const &>(tests[i].f32_bits);
|
||||
|
||||
cutlass::tfloat32_t tf32 = convert(f32);
|
||||
|
||||
// note, for this test, we must explicitly truncate the low-order bits since they are not
|
||||
// defined in TF32.
|
||||
if (cutlass::isfinite(tf32)) {
|
||||
tf32.storage &= 0xffffe000;
|
||||
}
|
||||
|
||||
bool passed = (tests[i].expected == tf32.raw());
|
||||
|
||||
EXPECT_TRUE(passed)
|
||||
<< "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits
|
||||
<< ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw();
|
||||
|
||||
if (!tests[i].f32_bits) {
|
||||
running = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Device
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -758,6 +758,65 @@ TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x16) {
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_64x64x16) {
|
||||
|
||||
//
|
||||
// Define the warp-level matrix multiply
|
||||
//
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int;
|
||||
using ElementCompute = float;
|
||||
int const kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
int const kPartitionsK = 1;
|
||||
|
||||
using Shape = cutlass::gemm::GemmShape<128, 128, 16>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Element = ElementOutput;
|
||||
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
|
||||
cutlass::sizeof_bits<Element>::value, 64>;
|
||||
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
|
||||
cutlass::sizeof_bits<Element>::value, 64>;
|
||||
|
||||
using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator,
|
||||
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type;
|
||||
|
||||
//
|
||||
// Output operator
|
||||
//
|
||||
|
||||
using OutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
kElementsPerAccess,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
OutputOp,
|
||||
kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
//
|
||||
// Instantiate epilogue
|
||||
//
|
||||
|
||||
EpilogueTestbed<Epilogue> testbed;
|
||||
|
||||
bool passed = testbed.run_all();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x64_64x32x16) {
|
||||
|
||||
//
|
||||
@ -2516,6 +2575,249 @@ TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x64_64x32x8) {
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x64_32x32x4) {
|
||||
|
||||
//
|
||||
// Define the warp-level matrix multiply
|
||||
//
|
||||
|
||||
using ElementOutput = double;
|
||||
using ElementAccumulator = double;
|
||||
using ElementCompute = double;
|
||||
int const kElementsPerAccess = 1;
|
||||
int const kPartitionsK = 1;
|
||||
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 16>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
|
||||
using Element = double;
|
||||
using ElementC = ElementAccumulator;
|
||||
using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
|
||||
LayoutC>::Type;
|
||||
|
||||
//
|
||||
// Output operator
|
||||
//
|
||||
|
||||
using OutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
kElementsPerAccess,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
OutputOp,
|
||||
kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
//
|
||||
// Instantiate epilogue
|
||||
//
|
||||
|
||||
EpilogueTestbed<Epilogue> testbed;
|
||||
|
||||
bool passed = testbed.run_all();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x64_64x32x4) {
|
||||
|
||||
//
|
||||
// Define the warp-level matrix multiply
|
||||
//
|
||||
|
||||
using ElementOutput = double;
|
||||
using ElementAccumulator = double;
|
||||
using ElementCompute = double;
|
||||
int const kElementsPerAccess = 1;
|
||||
int const kPartitionsK = 1;
|
||||
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 16>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
|
||||
using Element = double;
|
||||
using ElementC = ElementAccumulator;
|
||||
using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
|
||||
LayoutC>::Type;
|
||||
|
||||
//
|
||||
// Output operator
|
||||
//
|
||||
|
||||
using OutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
kElementsPerAccess,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
OutputOp,
|
||||
kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
//
|
||||
// Instantiate epilogue
|
||||
//
|
||||
|
||||
EpilogueTestbed<Epilogue> testbed;
|
||||
|
||||
bool passed = testbed.run_all();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x128_32x64x4) {
|
||||
|
||||
//
|
||||
// Define the warp-level matrix multiply
|
||||
//
|
||||
|
||||
using ElementOutput = double;
|
||||
using ElementAccumulator = double;
|
||||
using ElementCompute = double;
|
||||
int const kElementsPerAccess = 1;
|
||||
int const kPartitionsK = 1;
|
||||
|
||||
using Shape = cutlass::gemm::GemmShape<64, 64, 16>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
|
||||
using Element = double;
|
||||
using ElementC = ElementAccumulator;
|
||||
using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
|
||||
LayoutC>::Type;
|
||||
|
||||
//
|
||||
// Output operator
|
||||
//
|
||||
|
||||
using OutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
kElementsPerAccess,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
OutputOp,
|
||||
kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
//
|
||||
// Instantiate epilogue
|
||||
//
|
||||
|
||||
EpilogueTestbed<Epilogue> testbed;
|
||||
|
||||
bool passed = testbed.run_all();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x128_32x64x4) {
|
||||
|
||||
//
|
||||
// Define the warp-level matrix multiply
|
||||
//
|
||||
|
||||
using ElementOutput = double;
|
||||
using ElementAccumulator = double;
|
||||
using ElementCompute = double;
|
||||
int const kElementsPerAccess = 1;
|
||||
int const kPartitionsK = 1;
|
||||
|
||||
using Shape = cutlass::gemm::GemmShape<128, 128, 16>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
|
||||
using Element = double;
|
||||
using ElementC = ElementAccumulator;
|
||||
using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
|
||||
LayoutC>::Type;
|
||||
|
||||
//
|
||||
// Output operator
|
||||
//
|
||||
|
||||
using OutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
kElementsPerAccess,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
OutputOp,
|
||||
kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
//
|
||||
// Instantiate epilogue
|
||||
//
|
||||
|
||||
EpilogueTestbed<Epilogue> testbed;
|
||||
|
||||
bool passed = testbed.run_all();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x128_64x64x8) {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
@ -26,6 +26,64 @@ cutlass_test_unit_add_executable(
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 4
|
||||
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu
|
||||
gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu
|
||||
|
||||
gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
|
||||
gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
|
||||
gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu
|
||||
|
||||
gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu
|
||||
gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu
|
||||
|
||||
gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
|
||||
gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu
|
||||
|
||||
gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu
|
||||
gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu
|
||||
|
||||
gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu
|
||||
gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu
|
||||
gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu
|
||||
gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu
|
||||
gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu
|
||||
gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu
|
||||
|
||||
gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu
|
||||
gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu
|
||||
|
||||
simt_sgemm_nt_sm80.cu
|
||||
simt_sgemm_tn_sm80.cu
|
||||
|
||||
gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu
|
||||
gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu
|
||||
gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu
|
||||
gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
|
||||
gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu
|
||||
gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu
|
||||
|
||||
gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu
|
||||
gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu
|
||||
|
||||
gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu
|
||||
gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu
|
||||
|
||||
gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu
|
||||
gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu
|
||||
|
||||
gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu
|
||||
gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu
|
||||
gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu
|
||||
gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu
|
||||
@ -149,4 +207,5 @@ cutlass_test_unit_add_executable(
|
||||
gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu
|
||||
|
||||
gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu
|
||||
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -62,7 +62,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -84,7 +84,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -106,7 +106,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -128,7 +128,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -150,7 +150,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -172,7 +172,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
|
||||
373
test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu
Normal file
373
test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu
Normal file
@ -0,0 +1,373 @@
|
||||
/**************************************************************************************************
|
||||
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 1024>,
|
||||
cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 512>,
|
||||
cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 512>,
|
||||
cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>,
|
||||
cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc
|
||||
>;
|
||||
@ -104,7 +104,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -135,7 +135,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -166,7 +166,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -197,7 +197,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -228,7 +228,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -62,7 +62,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -84,7 +84,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -106,7 +106,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -128,7 +128,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -150,7 +150,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
@ -172,7 +172,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) {
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2, 128, 128,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
|
||||
374
test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu
Normal file
374
test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu
Normal file
@ -0,0 +1,374 @@
|
||||
/**************************************************************************************************
|
||||
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x1024_64x64x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x1024_64x64x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x1024_64x64x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x1024_64x64x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x1024_64x64x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x1024_32x64x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 1024>,
|
||||
cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x1024_64x32x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x1024_32x32x1024, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 1024>,
|
||||
cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 512>,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 512>,
|
||||
cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 512>,
|
||||
cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512, {
|
||||
using ElementOutput = int32_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = int32_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 512>,
|
||||
cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128,
|
||||
false, cutlass::arch::OpXorPopc>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc
|
||||
>;
|
||||
@ -104,7 +104,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -135,7 +135,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -166,7 +166,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -197,7 +197,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
@ -228,7 +228,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2, 128, 128, false,
|
||||
cutlass::arch::OpXorPopc>;
|
||||
|
||||
|
||||
@ -0,0 +1,353 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x64_32x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x64_64x32x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x64_32x32x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
|
||||
cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
|
||||
cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -0,0 +1,337 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 32>, // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>, // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 64, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 256, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 32>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) {
  // Row-major bf16 A/B/C with f32 accumulation on SM80 tensor cores.
  using ElementOutput = cutlass::bfloat16_t;
  using ElementAccumulator = float;
  using RowMajor = cutlass::layout::RowMajor;

  // Alpha/beta linear-combination epilogue, 128-bit vector accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::bfloat16_t, RowMajor,
      cutlass::bfloat16_t, RowMajor,
      ElementOutput, RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 32>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 32>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,    // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      10>;                                    // pipeline stages (deep pipeline)

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -0,0 +1,253 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_complex.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Operands data type: complex<float>
|
||||
// Rounding: float -> tfloat32_t (half_ulp_truncate)
|
||||
// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part)
|
||||
// Math instruction: MMA.1688.F32.TF32
|
||||
// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part)
|
||||
// Output data type: complex<float>
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A column-major,
  // B row-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A column-major,
  // B row-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A column-major,
  // B row-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A column-major,
  // B row-major, C/D row-major; scalar (one element) epilogue access.
  // Fix: removed stray double semicolon after the Element alias.
  using Element = cutlass::complex<float>;

  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A column-major,
  // B row-major, C/D row-major; scalar (one element) epilogue access.
  // Fix: removed stray double semicolon after the Element alias.
  using Element = cutlass::complex<float>;

  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A column-major,
  // B row-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 16>, // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,252 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_complex.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Operands data type: complex<float>
|
||||
// Rounding: float -> tfloat32_t (round to nearest)
|
||||
// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part)
|
||||
// Math instruction: MMA.1688.F32.TF32
|
||||
// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part)
|
||||
// Output data type: complex<float>
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A row-major,
  // B column-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A row-major,
  // B column-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A row-major,
  // B column-major, C/D row-major; scalar (one element) epilogue access.
  using Element = cutlass::complex<float>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A row-major,
  // B column-major, C/D row-major; scalar (one element) epilogue access.
  // Fix: removed stray double semicolon after the Element alias.
  using Element = cutlass::complex<float>;

  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<64, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A row-major,
  // B column-major, C/D row-major; scalar (one element) epilogue access.
  // Fix: removed stray double semicolon after the Element alias.
  using Element = cutlass::complex<float>;

  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) {
  // complex<float> GEMM on SM80 TF32 tensor cores: A row-major,
  // B column-major, C/D row-major; scalar (one element) epilogue access.
  // Fix: removed stray double semicolon after the Element alias.
  using Element = cutlass::complex<float>;

  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 16>, // threadblock tile
      cutlass::gemm::GemmShape<32, 64, 16>,   // warp tile
      cutlass::gemm::GemmShape<16, 8, 8>,     // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,192 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_complex.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
  // complex<double> GEMM on SM80 f64 tensor cores using the
  // OpMultiplyAddGaussianComplex operator; A column-major, B/C row-major,
  // no conjugation applied to either operand.
  using Element = cutlass::complex<double>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform on A
      cutlass::ComplexTransform::kNone,       // transform on B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) {
  // complex<double> GEMM on SM80 f64 tensor cores using the
  // OpMultiplyAddGaussianComplex operator; A column-major, B/C row-major,
  // no conjugation applied to either operand.
  using Element = cutlass::complex<double>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform on A
      cutlass::ComplexTransform::kNone,       // transform on B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) {
  // complex<double> GEMM on SM80 f64 tensor cores using the
  // OpMultiplyAddGaussianComplex operator; A column-major, B/C row-major,
  // no conjugation applied to either operand.
  using Element = cutlass::complex<double>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform on A
      cutlass::ComplexTransform::kNone,       // transform on B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) {
  // complex<double> GEMM on SM80 f64 tensor cores using the
  // OpMultiplyAddGaussianComplex operator; A column-major, B/C row-major,
  // no conjugation applied to either operand.
  using Element = cutlass::complex<double>;
  using EpilogueOp =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Element, cutlass::layout::ColumnMajor,
      Element, cutlass::layout::RowMajor,
      Element, cutlass::layout::RowMajor,
      Element,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 32, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform on A
      cutlass::ComplexTransform::kNone,       // transform on B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,246 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_complex.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A column-major, B row-major, C row-major) on SM80
// tensor cores: 32x32x16 threadblock tile, 16x16x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A column-major, B row-major, C row-major) on SM80
// tensor cores: 32x32x8 threadblock tile, 16x16x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A column-major, B row-major, C row-major) on SM80
// tensor cores: 64x64x16 threadblock tile, 16x32x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A column-major, B row-major, C row-major) on SM80
// tensor cores: 64x64x8 threadblock tile, 16x32x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 32, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A column-major, B row-major, C row-major) on SM80
// tensor cores: 64x64x16 threadblock tile, 32x32x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A column-major, B row-major, C row-major) on SM80
// tensor cores: 64x64x8 threadblock tile, 32x32x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::ColumnMajor,  // A
      Complex, cutlass::layout::RowMajor,     // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,191 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_complex.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores using the Gaussian complex multiply-add: 32x32x8 threadblock
// tile, 16x16x8 warp tile, 8x8x4 instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform applied to A
      cutlass::ComplexTransform::kNone,       // transform applied to B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores using the Gaussian complex multiply-add: 64x64x8 threadblock
// tile, 32x16x8 warp tile, 8x8x4 instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 16, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform applied to A
      cutlass::ComplexTransform::kNone,       // transform applied to B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores using the Gaussian complex multiply-add: 32x32x16 threadblock
// tile, 16x16x16 warp tile, 8x8x4 instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform applied to A
      cutlass::ComplexTransform::kNone,       // transform applied to B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores using the Gaussian complex multiply-add: 64x64x16 threadblock
// tile, 32x16x16 warp tile, 8x8x4 instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                      // pipeline stages
      cutlass::ComplexTransform::kNone,       // transform applied to A
      cutlass::ComplexTransform::kNone,       // transform applied to B
      cutlass::arch::OpMultiplyAddGaussianComplex>;

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,299 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed_complex.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 32x32x8 threadblock tile, 16x16x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 64x64x8 threadblock tile, 32x32x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 8>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 64x128x8 threadblock tile, 32x32x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 8>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 128x64x8 threadblock tile, 32x32x8 warp tile, 8x8x4
// instruction, 3-stage pipeline.
//
// Fix: this test is named 128x64x8 but previously instantiated
// GemmShape<64, 128, 8>, duplicating the preceding 64x128x8 test and leaving
// the 128x64 threadblock tile untested. The threadblock shape now matches
// the test name.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 8>,   // threadblock tile (matches test name)
      cutlass::gemm::GemmShape<32, 32, 8>,    // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 32x32x16 threadblock tile, 16x16x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<32, 32, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<16, 16, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 64x64x16 threadblock tile, 32x32x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 16>,   // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 64x128x16 threadblock tile, 32x32x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 16>,  // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
// complex<double> GEMM (A row-major, B column-major, C row-major) on SM80
// tensor cores: 128x64x16 threadblock tile, 32x32x16 warp tile, 8x8x4
// instruction, 3-stage pipeline.
//
// Fix: this test is named 128x64x16 but previously instantiated
// GemmShape<64, 128, 16>, duplicating the preceding 64x128x16 test and
// leaving the 128x64 threadblock tile untested. The threadblock shape now
// matches the test name.
TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) {

  // Operand, accumulator, and epilogue compute type are all complex<double>.
  using Complex = cutlass::complex<double>;

  // alpha * AB + beta * C epilogue, one element per vectorized access.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      Complex, 1, Complex, Complex>;

  using Gemm = cutlass::gemm::device::GemmComplex<
      Complex, cutlass::layout::RowMajor,     // A
      Complex, cutlass::layout::ColumnMajor,  // B
      Complex, cutlass::layout::RowMajor,     // C/D
      Complex,                                // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 64, 16>,  // threadblock tile (matches test name)
      cutlass::gemm::GemmShape<32, 32, 16>,   // warp tile
      cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                     // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
338
test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu
Normal file
338
test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu
Normal file
@ -0,0 +1,338 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -320,7 +320,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -354,7 +354,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -388,7 +388,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
337
test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu
Normal file
337
test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu
Normal file
@ -0,0 +1,337 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x64_64x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x64_32x64x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x64_64x32x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x64_32x32x64) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x32_64x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -108,7 +108,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -142,7 +142,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -122,7 +122,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -153,7 +153,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -205,7 +205,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -289,7 +289,7 @@ TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
340
test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu
Normal file
340
test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu
Normal file
@ -0,0 +1,340 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -139,7 +139,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32)
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -0,0 +1,82 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
338
test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu
Normal file
338
test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu
Normal file
@ -0,0 +1,338 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64> ,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -0,0 +1,77 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
|
||||
/*
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
*/
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -63,7 +63,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -94,7 +94,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -125,7 +125,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -156,7 +156,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -187,7 +187,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -218,7 +218,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -249,7 +249,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -108,7 +108,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -142,7 +142,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
339
test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu
Normal file
339
test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu
Normal file
@ -0,0 +1,339 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32, {
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
} )
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -63,7 +63,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -94,7 +94,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -125,7 +125,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -156,7 +156,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -187,7 +187,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -218,7 +218,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -249,7 +249,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x256x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x64x32_6
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -137,7 +137,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x128x32_6
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -170,7 +170,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x32_32
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -202,7 +202,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x64_32
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -234,7 +234,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x64_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -270,7 +270,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -305,7 +305,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x16_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_32x8x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -140,7 +140,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_8x32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -73,7 +73,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x256x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -105,7 +105,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x64x32_6
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -137,7 +137,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x128x32_6
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -170,7 +170,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x32_32
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -202,7 +202,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x64_32
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -234,7 +234,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x64_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -270,7 +270,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -305,7 +305,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32)
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -0,0 +1,83 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3
|
||||
>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
339
test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu
Normal file
339
test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu
Normal file
@ -0,0 +1,339 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x32_64x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 256, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 128, 32>,
|
||||
cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -289,7 +289,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -321,7 +321,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -355,7 +355,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -389,7 +389,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -71,7 +71,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -102,7 +102,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -133,7 +133,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -164,7 +164,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -195,7 +195,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -226,7 +226,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -257,7 +257,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -288,7 +288,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -319,7 +319,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -353,7 +353,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -387,7 +387,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -74,7 +74,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x64x32_6
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -106,7 +106,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x128x32_6
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -138,7 +138,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x64x32_32
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -174,7 +174,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
@ -209,7 +209,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
kStages
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
338
test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu
Normal file
338
test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu
Normal file
@ -0,0 +1,338 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "testbed.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 32, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
  // Row-major f16 A times column-major f16 B into row-major f32 C,
  // accumulating in f32 on SM80 tensor cores.
  using ElementOutput      = float;
  using ElementAccumulator = float;

  // Tile decomposition: threadblock / warp / mma-instruction shapes.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape        = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue computes alpha * AB + beta * C with 128-bit wide accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      10>;  // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -70,7 +70,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -101,7 +101,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -132,7 +132,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -163,7 +163,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -194,7 +194,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -225,7 +225,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -256,7 +256,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) {
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -103,7 +103,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -134,7 +134,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -165,7 +165,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -196,7 +196,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -227,7 +227,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x1
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -258,7 +258,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -293,7 +293,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -327,7 +327,7 @@ TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -72,7 +72,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -107,7 +107,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
@ -141,7 +141,7 @@ TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x3
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
2
|
||||
>;
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user