Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@ -35,22 +35,22 @@
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
template <class Layout, class CoSizeHi>
|
||||
template <class Layout, class CoTarget>
|
||||
void
|
||||
test_complement(Layout const& layout, CoSizeHi const& cosize_hi)
|
||||
test_complement(Layout const& layout, CoTarget const& cotarget)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
auto result = complement(layout, cosize_hi);
|
||||
auto result = complement(layout, cotarget);
|
||||
|
||||
CUTLASS_TRACE_HOST("complement(" << layout << ", " << cosize_hi << ") => " << result);
|
||||
CUTLASS_TRACE_HOST("complement(" << layout << ", " << cotarget << ") => " << result);
|
||||
|
||||
auto completed = make_layout(layout, result);
|
||||
|
||||
// Lower-bound on the codomain size of the layout ++ complement (1)
|
||||
EXPECT_GE(cosize(completed), cosize_hi);
|
||||
EXPECT_GE(cosize(completed), size(cotarget));
|
||||
// Upper-bound on the codomain size of the complement (2)
|
||||
EXPECT_LE(cosize(result), cute::round_up(cosize_hi, cosize(layout)));
|
||||
EXPECT_LE(cosize(result), cute::round_up(size(cotarget), cosize(layout)));
|
||||
|
||||
// Post-condition on the codomain of the complement
|
||||
for (int i = 1; i < size(result); ++i) {
|
||||
@ -62,9 +62,9 @@ test_complement(Layout const& layout, CoSizeHi const& cosize_hi)
|
||||
|
||||
// Other observations
|
||||
EXPECT_LE(size(result), cosize(result)); // As a result of the ordered condition (3)
|
||||
EXPECT_GE(size(result), cosize_hi / size(filter(layout)));
|
||||
EXPECT_GE(size(result), size(cotarget) / size(filter(layout)));
|
||||
EXPECT_LE(cosize(completed), cosize(result) + cosize(layout));
|
||||
EXPECT_GE(cosize(result), cosize_hi / size(filter(layout)));
|
||||
EXPECT_GE(cosize(result), size(cotarget) / size(filter(layout)));
|
||||
if constexpr (is_static<decltype(stride(completed))>::value) { // If we can apply complement again
|
||||
EXPECT_EQ(size(complement(completed)), 1); // There's no more codomain left over
|
||||
}
|
||||
@ -90,6 +90,8 @@ TEST(CuTe_core, Complement)
|
||||
|
||||
test_complement(layout);
|
||||
test_complement(layout, Int<2>{});
|
||||
test_complement(layout, Int<5>{});
|
||||
test_complement(layout, make_shape(Int<2>{}, 2));
|
||||
}
|
||||
|
||||
{
|
||||
@ -97,6 +99,8 @@ TEST(CuTe_core, Complement)
|
||||
|
||||
test_complement(layout);
|
||||
test_complement(layout, Int<2>{});
|
||||
test_complement(layout, Int<5>{});
|
||||
test_complement(layout, make_shape(Int<2>{}, 2));
|
||||
}
|
||||
|
||||
{
|
||||
@ -105,6 +109,8 @@ TEST(CuTe_core, Complement)
|
||||
test_complement(layout, Int<1>{});
|
||||
test_complement(layout, Int<2>{});
|
||||
test_complement(layout, Int<8>{});
|
||||
test_complement(layout, Int<5>{});
|
||||
test_complement(layout, make_shape(Int<2>{}, 2));
|
||||
}
|
||||
|
||||
{
|
||||
@ -130,6 +136,7 @@ TEST(CuTe_core, Complement)
|
||||
test_complement(layout);
|
||||
test_complement(layout, Int<16>{});
|
||||
test_complement(layout, Int<19>{});
|
||||
test_complement(layout, make_shape(Int<2>{}, 2));
|
||||
}
|
||||
|
||||
{
|
||||
@ -138,6 +145,7 @@ TEST(CuTe_core, Complement)
|
||||
test_complement(layout, Int<1>{});
|
||||
test_complement(layout);
|
||||
test_complement(layout, Int<17>{});
|
||||
test_complement(layout, make_shape(Int<2>{}, 2));
|
||||
}
|
||||
|
||||
{
|
||||
@ -193,8 +201,8 @@ TEST(CuTe_core, Complement)
|
||||
|
||||
// Fails due to non-injective layout
|
||||
// {
|
||||
// auto layout = make_layout(Shape<Shape<_2,_2>,Shape<_2, _2>>{},
|
||||
// Stride<Stride<_1,_8>,Stride<_8,_4>>{});
|
||||
// auto layout = make_layout(Shape <Shape <_2,_2>,Shape <_2,_2>>{},
|
||||
// Stride<Stride<_1,_8>,Stride<_8,_4>>{});
|
||||
|
||||
// test_complement(layout);
|
||||
// }
|
||||
@ -289,4 +297,11 @@ TEST(CuTe_core, Complement)
|
||||
|
||||
test_complement(layout);
|
||||
}
|
||||
|
||||
{
|
||||
auto layout = make_layout(Int<64>{});
|
||||
|
||||
test_complement(layout, make_shape(Int<32>{}, Int<4>{}, Int<4>{}));
|
||||
test_complement(layout, make_shape(Int<32>{}, Int<4>{}, 4));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user