Updates for CUTLASS 3.5.0 (#1468)

This commit is contained in:
Vijay Thakkar
2024-04-11 21:33:40 -04:00
committed by GitHub
parent a40e08e9d5
commit 7d49e6c7e2
171 changed files with 7526 additions and 1888 deletions

View File

@ -35,22 +35,22 @@
#include <cute/tensor.hpp>
template <class Layout, class CoSizeHi>
template <class Layout, class CoTarget>
void
test_complement(Layout const& layout, CoSizeHi const& cosize_hi)
test_complement(Layout const& layout, CoTarget const& cotarget)
{
using namespace cute;
auto result = complement(layout, cosize_hi);
auto result = complement(layout, cotarget);
CUTLASS_TRACE_HOST("complement(" << layout << ", " << cosize_hi << ") => " << result);
CUTLASS_TRACE_HOST("complement(" << layout << ", " << cotarget << ") => " << result);
auto completed = make_layout(layout, result);
// Lower-bound on the codomain size of the layout ++ complement (1)
EXPECT_GE(cosize(completed), cosize_hi);
EXPECT_GE(cosize(completed), size(cotarget));
// Upper-bound on the codomain size of the complement (2)
EXPECT_LE(cosize(result), cute::round_up(cosize_hi, cosize(layout)));
EXPECT_LE(cosize(result), cute::round_up(size(cotarget), cosize(layout)));
// Post-condition on the codomain of the complement
for (int i = 1; i < size(result); ++i) {
@ -62,9 +62,9 @@ test_complement(Layout const& layout, CoSizeHi const& cosize_hi)
// Other observations
EXPECT_LE(size(result), cosize(result)); // As a result of the ordered condition (3)
EXPECT_GE(size(result), cosize_hi / size(filter(layout)));
EXPECT_GE(size(result), size(cotarget) / size(filter(layout)));
EXPECT_LE(cosize(completed), cosize(result) + cosize(layout));
EXPECT_GE(cosize(result), cosize_hi / size(filter(layout)));
EXPECT_GE(cosize(result), size(cotarget) / size(filter(layout)));
if constexpr (is_static<decltype(stride(completed))>::value) { // If we can apply complement again
EXPECT_EQ(size(complement(completed)), 1); // There's no more codomain left over
}
@ -90,6 +90,8 @@ TEST(CuTe_core, Complement)
test_complement(layout);
test_complement(layout, Int<2>{});
test_complement(layout, Int<5>{});
test_complement(layout, make_shape(Int<2>{}, 2));
}
{
@ -97,6 +99,8 @@ TEST(CuTe_core, Complement)
test_complement(layout);
test_complement(layout, Int<2>{});
test_complement(layout, Int<5>{});
test_complement(layout, make_shape(Int<2>{}, 2));
}
{
@ -105,6 +109,8 @@ TEST(CuTe_core, Complement)
test_complement(layout, Int<1>{});
test_complement(layout, Int<2>{});
test_complement(layout, Int<8>{});
test_complement(layout, Int<5>{});
test_complement(layout, make_shape(Int<2>{}, 2));
}
{
@ -130,6 +136,7 @@ TEST(CuTe_core, Complement)
test_complement(layout);
test_complement(layout, Int<16>{});
test_complement(layout, Int<19>{});
test_complement(layout, make_shape(Int<2>{}, 2));
}
{
@ -138,6 +145,7 @@ TEST(CuTe_core, Complement)
test_complement(layout, Int<1>{});
test_complement(layout);
test_complement(layout, Int<17>{});
test_complement(layout, make_shape(Int<2>{}, 2));
}
{
@ -193,8 +201,8 @@ TEST(CuTe_core, Complement)
// Fails due to non-injective layout
// {
// auto layout = make_layout(Shape<Shape<_2,_2>,Shape<_2, _2>>{},
// Stride<Stride<_1,_8>,Stride<_8,_4>>{});
// auto layout = make_layout(Shape <Shape <_2,_2>,Shape <_2,_2>>{},
// Stride<Stride<_1,_8>,Stride<_8,_4>>{});
// test_complement(layout);
// }
@ -289,4 +297,11 @@ TEST(CuTe_core, Complement)
test_complement(layout);
}
{
// Exercise complement() with multi-mode cotarget shapes on a layout built from
// Int<64>{}: first with every extent a static Int<>, then with the last extent
// given as a plain (runtime) int.
auto layout = make_layout(Int<64>{});
test_complement(layout, make_shape(Int<32>{}, Int<4>{}, Int<4>{}));
test_complement(layout, make_shape(Int<32>{}, Int<4>{}, 4));
}
}