CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@ -64,3 +64,10 @@ add_executable(
cpp11.cu
)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
target_compile_options(
cutlass_test_unit_core_cpp11
PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler -Werror>
)
endif()

View File

@ -42,6 +42,7 @@
#include <cutlass/cutlass.h>
#include <cutlass/complex.h>
#include <cutlass/coord.h>
#include <cutlass/core_io.h>
#include <cutlass/array.h>
@ -51,12 +52,12 @@
#include <cutlass/half.h>
#include <cutlass/integer_subbyte.h>
#include <cutlass/kernel_hardware_info.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_size.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tfloat32.h>
#include <cutlass/workspace.h>
#include <cutlass/subbyte_reference.h>
#include <cutlass/conv/convolution.h>
#include <cutlass/conv/conv2d_problem_size.h>

View File

@ -147,6 +147,20 @@ TEST(FastNumericConversion, s32_to_f32) {
test::core::kernel::run_test_integer_range_limited<Destination, Source, kN>();
}
TEST(FastNumericConversion, s8_to_f32_array) {
int const kN = 256;
using Source = int8_t;
using Destination = float;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
TEST(FastNumericConversion, u8_to_f32_array) {
int const kN = 256;
using Source = uint8_t;
using Destination = float;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
TEST(FastNumericConversion, s8_to_f16_array) {
int const kN = 256;
using Source = int8_t;

View File

@ -60,8 +60,8 @@ __global__ void convert(
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Destination, typename Source, int Count, int Range = 4>
void run_test(const char dest_name[], const char source_name[]) {
template <typename Destination, typename Source, int Count>
void run_test(const char dest_name[], const char source_name[], const int range = 4, const int offset = 0) {
const int kN = Count;
dim3 grid(1, 1);
@ -73,7 +73,7 @@ void run_test(const char dest_name[], const char source_name[]) {
auto destination_ref = destination.host_ref();
for (int i = 0; i < kN; ++i) {
source_ref.at({0, i}) = Source(i % Range);
source_ref.at({0, i}) = Source(i % range + offset);
}
source.sync_device();
@ -509,4 +509,160 @@ TEST(NumericConversion, int_to_fe4m3_t_array_32) {
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct GetName {
static constexpr char name[] = "UNSUPPORTED";
};
template <>
struct GetName<cutlass::int4b_t> {
static constexpr char name[] = "int4b_t";
};
template <>
struct GetName<uint8_t> {
static constexpr char name[] = "uint8_t";
};
template <>
struct GetName<int8_t> {
static constexpr char name[] = "int8_t";
};
template <>
struct GetName<cutlass::float_e4m3_t> {
static constexpr char name[] = "float_e4m3_t";
};
template <>
struct GetName<cutlass::half_t> {
static constexpr char name[] = "half_t";
};
template <>
struct GetName<cutlass::bfloat16_t> {
static constexpr char name[] = "bfloat16_t";
};
template <>
struct GetName<float> {
static constexpr char name[] = "float";
};
template <typename Result_, typename Source_>
struct ResultSourcePair {
using Result = Result_;
using Source = Source_;
};
template <typename ResultSourcePair>
class VectorArrayConverterTest : public testing::Test {
public:
using Result = typename ResultSourcePair::Result;
using Source = typename ResultSourcePair::Source;
template <int N>
static void emit_test() {
const int range = 1 << cutlass::sizeof_bits<Source>::value;
const int offset = cutlass::platform::numeric_limits<Source>::lowest();
test::core::kernel::run_test<Result, Source, N>(GetName<Result>::name, GetName<Source>::name, range, offset);
}
};
using VectorConvertTypes = ::testing::Types<
ResultSourcePair<float, int8_t>,
ResultSourcePair<float, uint8_t>,
ResultSourcePair<cutlass::half_t, int8_t>,
ResultSourcePair<cutlass::half_t, uint8_t>,
ResultSourcePair<cutlass::bfloat16_t, uint8_t>,
ResultSourcePair<cutlass::bfloat16_t, int8_t>,
ResultSourcePair<cutlass::float_e4m3_t, cutlass::int4b_t>,
ResultSourcePair<cutlass::half_t, cutlass::int4b_t>,
ResultSourcePair<cutlass::bfloat16_t, cutlass::int4b_t>,
ResultSourcePair<float, cutlass::int4b_t>
>;
TYPED_TEST_SUITE(VectorArrayConverterTest, VectorConvertTypes);
TYPED_TEST(VectorArrayConverterTest, array_1) {
TestFixture::template emit_test<1>();
}
TYPED_TEST(VectorArrayConverterTest, array_2) {
TestFixture::template emit_test<2>();
}
TYPED_TEST(VectorArrayConverterTest, array_3) {
TestFixture::template emit_test<3>();
}
TYPED_TEST(VectorArrayConverterTest, array_4) {
TestFixture::template emit_test<4>();
}
TYPED_TEST(VectorArrayConverterTest, array_5) {
TestFixture::template emit_test<5>();
}
TYPED_TEST(VectorArrayConverterTest, array_8) {
TestFixture::template emit_test<8>();
}
TYPED_TEST(VectorArrayConverterTest, array_10) {
// N > 8 and N is not a multiple of 4
TestFixture::template emit_test<10>();
}
TYPED_TEST(VectorArrayConverterTest, array_12) {
// N > 8 and N is a multiple of 4
TestFixture::template emit_test<12>();
}
TYPED_TEST(VectorArrayConverterTest, array_16) {
// N > 8 and N is a multiple of 8
TestFixture::template emit_test<16>();
}
TYPED_TEST(VectorArrayConverterTest, array_17) {
// N > 8 and N is not a multiple of 8
TestFixture::template emit_test<17>();
}
TYPED_TEST(VectorArrayConverterTest, array_27) {
// Test entire conversion range with residue (for int4)
TestFixture::template emit_test<27>();
}
TYPED_TEST(VectorArrayConverterTest, array_31) {
// Force use of converters for 16, 8, 4, 2 and scalar
// if max width is 16
TestFixture::template emit_test<31>();
}
TYPED_TEST(VectorArrayConverterTest, array_63) {
// Force use of converters for 32, 16, 8, 4, 2 and scalar
// if max width is 32
TestFixture::template emit_test<63>();
}
TYPED_TEST(VectorArrayConverterTest, array_256) {
// Test entire conversion range (for int8)
TestFixture::template emit_test<256>();
}
TYPED_TEST(VectorArrayConverterTest, array_259) {
// Force use of 4, 2 and scalar converter (if max width is 4)
TestFixture::template emit_test<259>();
}
TYPED_TEST(VectorArrayConverterTest, array_263) {
// Force use of 8, 4, 2 and scalar converter (if max width is 8)
TestFixture::template emit_test<263>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////