Updates for 3.2 release (#1065)
This commit is contained in:
@ -61,7 +61,7 @@ __global__ void convert(
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Destination, typename Source, int Count>
|
||||
void run_test() {
|
||||
void run_test(const char dest_name[], const char source_name[]) {
|
||||
const int kN = Count;
|
||||
|
||||
dim3 grid(1, 1);
|
||||
@ -84,7 +84,10 @@ void run_test() {
|
||||
destination.sync_host();
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
|
||||
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
|
||||
<< "Destination type: " << dest_name
|
||||
<< ", Source type: " << source_name
|
||||
<< ", Count: " << Count;
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,15 +100,19 @@ void run_test() {
|
||||
// float -> cutlass::half_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f32_to_f16_rn) {
  int const kN = 1;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::half_t;
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// float -> cutlass::half_t, eight elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f32x8_to_f16x8_rn) {
  int const kN = 8;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::half_t;
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -113,15 +120,19 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
|
||||
// cutlass::half_t -> float, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f16_to_f32_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = float;
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::half_t -> float, eight elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f16x8_to_f32x8_rn) {
  int const kN = 8;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = float;
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -129,86 +140,109 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
|
||||
// float -> cutlass::float_e4m3_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f32_to_fe4m3_rn) {
  int const kN = 1;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// float -> cutlass::float_e4m3_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f32_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e4m3_t;

  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// float -> cutlass::float_e5m2_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f32_to_fe5m2_rn) {
  int const kN = 1;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// float -> cutlass::float_e5m2_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f32_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::half_t -> cutlass::float_e4m3_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f16_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::half_t -> cutlass::float_e4m3_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f16_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::half_t -> cutlass::float_e5m2_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f16_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::half_t -> cutlass::float_e5m2_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, f16_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::bfloat16_t -> cutlass::float_e4m3_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, bf16_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::bfloat16_t -> cutlass::float_e4m3_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, bf16_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::bfloat16_t -> cutlass::float_e5m2_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, bf16_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::bfloat16_t -> cutlass::float_e5m2_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -216,36 +250,46 @@ TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
|
||||
// cutlass::float_e4m3_t -> cutlass::float_e5m2_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e4m3_t -> cutlass::float_e5m2_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_fe5m2_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::float_e5m2_t;
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> cutlass::float_e4m3_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> cutlass::float_e4m3_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_fe4m3_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e4m3_t -> float, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_f32_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = float;
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -254,78 +298,100 @@ TEST(NumericConversion, f32x8_to_s8x8_rn) {
|
||||
|
||||
int const kN = 8;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = int8_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "int8_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
// cutlass::float_e4m3_t -> float, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_f32_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = float;
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> float, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_f32_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = float;
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e4m3_t -> cutlass::half_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_f16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::half_t;
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e4m3_t -> cutlass::half_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_f16_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::half_t;
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> cutlass::half_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_f16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::half_t;
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> cutlass::half_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_f16_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::half_t;
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e4m3_t -> cutlass::bfloat16_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_bf16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::bfloat16_t;
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e4m3_t -> cutlass::bfloat16_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe4m3_to_bf16_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::bfloat16_t;
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> cutlass::bfloat16_t, one element.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_bf16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::bfloat16_t;
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
// cutlass::float_e5m2_t -> cutlass::bfloat16_t, 27 elements.
// The type names are forwarded so run_test can name the conversion in its
// failure message.
TEST(NumericConversion, fe5m2_to_bf16_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::bfloat16_t;
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -36,6 +36,7 @@ cutlass_test_unit_add_executable(
|
||||
compare.cpp
|
||||
complement.cpp
|
||||
composition.cpp
|
||||
constant_arithmetic.cpp
|
||||
core_unit.cpp
|
||||
inverse_left.cpp
|
||||
inverse_right.cpp
|
||||
|
||||
106
test/unit/cute/core/constant_arithmetic.cpp
Normal file
106
test/unit/cute/core/constant_arithmetic.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include <cutlass/trace.h>
|
||||
#include <cute/swizzle.hpp>
|
||||
|
||||
// Compile-time check of cute::integral_constant arithmetic: for every
// combination of four values drawn from make_integer_sequence<uint32_t, 8>,
// the `&`, `!=` and `||` operators applied to the constants must yield types
// whose ::value equals the same operation applied to the underlying values.
// All checks are static_asserts; the test body has no run-time expectations.
TEST(CuTe_core, ConstantArithmetic) {
  using namespace cute;

  constexpr cute::integral_constant<uint32_t, 0> uzero{};

  // This extra test exists historically as part of the diagnosis
  // of a possible Clang 14 bug.  However, it's a nice test for
  // cute::integral_constant's arithmetic operators, so it's saved here.
  // It also demonstrates how to work with cute::integral_constant
  // and lambda captures.  Microsoft Visual Studio ("MSVC") tends to
  // disagree with other compilers about the meaning of decltype
  // for variables captured by reference.  MSVC and GCC 8.3.0
  // also tend to disagree with other compilers (and other GCC versions)
  // about whether expressions involving such variables
  // are constant expressions.
  //
  // A typical CuTe idiom is to do lambda captures by reference [&].
  // This test changes them to capture by value, except for
  // the innermost lambda's capture of S1, which is by reference.
  // The point is to show that MSVC and GCC 8 have issues with this
  // that other compilers do not.  For example,
  //
  // 1. MSVC needs remove_cvref_t around decltype(S1)
  //    in order to access decltype(S1)::value, and
  // 2. MSVC and GCC 8.3.0 both report a build error with S1()
  //    (that is, calling operator() on S1, which returns the
  //    same thing as S1.value).
  //
  // The reason for (2) is that neither compiler thinks
  // that S1() is a constant expression.
  //
  // This leaves S1.value as the most concise portable expression
  // for the "value" member of a cute::integral_constant.
  for_each(make_integer_sequence<uint32_t, 8>{}, [uzero](auto S0) {
    for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0](auto F0) {
      for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0](auto S1) {
        for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0,&S1](auto F1) {
          static_assert((decltype(S0)::value & decltype(F0)::value) == decltype(S0 & F0)::value);

          // Using S1.value means you don't have to use remove_cvref_t
          // with a captured-by-reference variable.
          static_assert((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) == decltype(S1 & F1)::value);
          static_assert((S1.value & decltype(F1)::value) == decltype(S1 & F1)::value);
          // S1() _should_ work, but does not with Visual Studio 2022,
          // which emits C2131 ("expression did not evaluate to a constant").
          // It also does not with GCC 8.3.0, which emits an error with messages
          // "non-constant condition for static assertion" and
          // "'this' is not a constant expression."
          //
          //static_assert((S1() & decltype(F1)::value) == decltype(S1 & F1)::value);
          static_assert(decltype((S0 & F0) != uzero)::value == ((decltype(S0)::value & decltype(F0)::value) != 0));

          static_assert(decltype((S1 & F1) != uzero)::value == ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0));
          static_assert(decltype((S1 & F1) != uzero)::value == ((S1.value & decltype(F1)::value) != 0));

          // `left` evaluates the whole disjunction on the constant types;
          // `right` / `right2` evaluate it on the plain values, spelled with
          // remove_cvref_t and with S1.value respectively.
          constexpr bool left = decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value;
          constexpr bool right =
            ((decltype(S0)::value & decltype(F0)::value) != 0) ||
            ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0);
          constexpr bool right2 =
            ((decltype(S0)::value & decltype(F0)::value) != 0) ||
            ((S1.value & decltype(F1)::value) != 0);
          static_assert(right == right2);
          static_assert(left == right);
          // `left2` applies || to the two already-extracted bool values.
          constexpr bool left2 = decltype((S0 & F0) != uzero)::value || decltype((S1 & F1) != uzero)::value;
          static_assert(left == left2);
        });
      });
    });
  });
}
|
||||
@ -31,9 +31,58 @@
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
|
||||
// C<uint32_t(something)>::value_type is not uint32_t for GCC 7.5.0.
// This test is thus disabled for GCC < 8.
//
// The guard must therefore be true everywhere EXCEPT on GCC older than 8.
// Clang also defines __GNUC__ (as a GCC-compatibility version number), so it
// is excluded from the "old GCC" test explicitly; otherwise the tests below,
// which include a Clang-specific work-around, would not be built with Clang.
#if !(defined(__GNUC__) && !defined(__clang__) && (__GNUC__ < 8))
|
||||
|
||||
#include <cutlass/trace.h>
|
||||
#include <cute/swizzle.hpp>
|
||||
|
||||
namespace { // (anonymous)
|
||||
|
||||
// This function exists to work around a Clang 14 issue, in which
|
||||
// the compiler tries to instantiate code that lives inside the
|
||||
// "else" branch of an "if constexpr," even when the "else" branch
|
||||
// is false. That triggers a spurious static_assert in MixedBits.
|
||||
// The work-around is to make the body of the "else" branch a
|
||||
// function, rather than leaving it in line.
|
||||
//
|
||||
// Some compilers strangely deduce the first two terms of
|
||||
// make_integer_sequence<uint32_t, 8> as C<false> and C<true>, and
|
||||
// the remaining terms as C<2>, C<3>, etc. Making this function take
|
||||
// cute::integral_constant<uint32_t, S0_value>, etc. doesn't work
|
||||
// with those compilers.
|
||||
template<class S0_type, S0_type S0_value,
|
||||
class F0_type, F0_type F0_value,
|
||||
class S1_type, S1_type S1_value,
|
||||
class F1_type, F1_type F1_value>
|
||||
void clang14_workaround(cute::integral_constant<S0_type, S0_value>,
|
||||
cute::integral_constant<F0_type, F0_value>,
|
||||
cute::integral_constant<S1_type, S1_value>,
|
||||
cute::integral_constant<F1_type, F1_value>)
|
||||
{
|
||||
constexpr cute::C<static_cast<uint32_t>(S0_value)> S0{};
|
||||
constexpr cute::C<static_cast<uint32_t>(F0_value)> F0{};
|
||||
constexpr cute::C<static_cast<uint32_t>(S1_value)> S1{};
|
||||
constexpr cute::C<static_cast<uint32_t>(F1_value)> F1{};
|
||||
|
||||
for (uint32_t d0 = 0; d0 < 8; ++d0) {
|
||||
if ((d0 & F0) != d0) { continue; } // Skip repeats
|
||||
for (uint32_t d1 = 0; d1 < 8; ++d1) {
|
||||
if ((d1 & F1) != d1) { continue; } // Skip repeats
|
||||
auto m0 = make_mixed_bits(S0, d0, F0);
|
||||
auto m1 = make_mixed_bits(S1, d1, F1);
|
||||
//print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
|
||||
EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
|
||||
//print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
|
||||
EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
|
||||
//print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
|
||||
EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace (anonymous)
|
||||
|
||||
TEST(CuTe_core, MixedBits) {
|
||||
using namespace cute;
|
||||
|
||||
@ -48,23 +97,21 @@ TEST(CuTe_core, MixedBits) {
|
||||
} else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) {
|
||||
return;
|
||||
} else {
|
||||
for (uint32_t d0 = 0; d0 < 8; ++d0) {
|
||||
if ((d0 & F0) != d0) { continue; } // Skip repeats
|
||||
for (uint32_t d1 = 0; d1 < 8; ++d1) {
|
||||
if ((d1 & F1) != d1) { continue; } // Skip repeats
|
||||
auto m0 = make_mixed_bits(S0, d0, F0);
|
||||
auto m1 = make_mixed_bits(S1, d1, F1);
|
||||
//print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
|
||||
EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
|
||||
//print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
|
||||
EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
|
||||
//print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
|
||||
EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
|
||||
}
|
||||
}
|
||||
clang14_workaround(S0, F0, S1, F1);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Compile-time check that every element produced by
// cute::make_integer_sequence<uint32_t, 8> has exactly the type
// cute::integral_constant<uint32_t, value>, with no other value type.
TEST(CuTe_core, MakeIntegerSequence) {
  cute::for_each(cute::make_integer_sequence<uint32_t, 8>{}, [](auto c) {
    using c_type = decltype(c);
    constexpr auto c_value = c_type::value;
    using expected_type = cute::integral_constant<uint32_t, c_value>;
    static_assert(cute::is_same_v<c_type, expected_type>);
  });
}
|
||||
|
||||
#endif // defined(__GNUC__) && (__GNUC__ < 8)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -32,6 +32,7 @@
|
||||
\brief Tests that the stream-K scheduler covers the entire problem space.
|
||||
*/
|
||||
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
@ -39,6 +40,10 @@
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
// Grids are launched with clusters enabled in these tests,
|
||||
// so the CTK version must support cluster launching.
|
||||
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape_MNKL = Shape<int, int, int, int>;
|
||||
|
||||
@ -60,7 +65,7 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
|
||||
while (work_tile_info.is_valid_tile) {
|
||||
// Increment counters to indicate coverage
|
||||
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
|
||||
auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx;
|
||||
auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
|
||||
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
|
||||
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
|
||||
// While having more than one block increment the same counter indicates failure,
|
||||
@ -103,7 +108,7 @@ test_scheduler(
|
||||
|
||||
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
|
||||
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
|
||||
auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_;
|
||||
auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
|
||||
cutlass::DeviceAllocation<int> visit_counters(total_counters);
|
||||
|
||||
// Initialize counters to zero
|
||||
@ -118,12 +123,55 @@ test_scheduler(
|
||||
// Set up the grid for the problem
|
||||
dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args);
|
||||
|
||||
// Set up cluster and cluster launch. This is needed even for this simple kernel because
|
||||
// the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires
|
||||
// explicitly launching with clusters.
|
||||
dim3 cluster{
|
||||
static_cast<uint32_t>(cute::get<0>(ClusterShape{})),
|
||||
static_cast<uint32_t>(cute::get<1>(ClusterShape{})),
|
||||
static_cast<uint32_t>(cute::get<2>(ClusterShape{}))
|
||||
};
|
||||
|
||||
cudaLaunchConfig_t launch_config;
|
||||
launch_config.gridDim = grid;
|
||||
launch_config.blockDim = {1, 1, 1};
|
||||
launch_config.dynamicSmemBytes = 0;
|
||||
launch_config.stream = NULL;
|
||||
|
||||
cudaLaunchAttribute launch_attribute[1];
|
||||
launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
|
||||
launch_attribute[0].val.clusterDim.x = cluster.x;
|
||||
launch_attribute[0].val.clusterDim.y = cluster.y;
|
||||
launch_attribute[0].val.clusterDim.z = cluster.z;
|
||||
|
||||
launch_config.attrs = launch_attribute;
|
||||
launch_config.numAttrs = 1;
|
||||
|
||||
void const* kernel = (void const*) run_scheduler<Scheduler, TileShape, ClusterShape>;
|
||||
int* counters_ptr = visit_counters.get();
|
||||
void* kernel_params[] = {
|
||||
&counters_ptr,
|
||||
¶ms,
|
||||
&tile_shape,
|
||||
&cluster_shape,
|
||||
&problem_shape_mnkl
|
||||
};
|
||||
|
||||
// Run the scheduler to completion and log visits to each k iteration
|
||||
run_scheduler<Scheduler, TileShape, ClusterShape><<<grid, 1>>>(
|
||||
visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl);
|
||||
err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params);
|
||||
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << __FILE__ << ":" << __LINE__
|
||||
<< " cudaLaunchKernelExC failed with error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl;
|
||||
std::cerr << __FILE__ << ":" << __LINE__
|
||||
<< " scheduler kernel failed with error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -143,11 +191,11 @@ test_scheduler(
|
||||
<< " and grid size " << grid.x << "x"
|
||||
<< grid.y << "x" << grid.z
|
||||
<< " splits=" << params.splits_
|
||||
<< " k_iter=" << params.k_iter_per_tile_
|
||||
<< " k_iter=" << params.k_tiles_per_output_tile_
|
||||
<< " big_units=" << params.big_units_
|
||||
<< " sk_tiles=" << params.sk_tiles_
|
||||
<< " sk_units=" << params.sk_units_
|
||||
<< " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl;
|
||||
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
|
||||
std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
|
||||
return false;
|
||||
}
|
||||
@ -274,4 +322,6 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) {
|
||||
EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114));
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user