Updates for 3.4 release. (#1305)

This commit is contained in:
ANIKET SHIVAM
2024-01-16 10:42:51 -08:00
committed by GitHub
parent acba5beee5
commit 2f589ffa76
166 changed files with 5996 additions and 4702 deletions

View File

@ -604,8 +604,7 @@ unwrap(T const& t)
}
//
// Flatten a hierarchical tuple to a tuple of depth one.
//
// Flatten and Unflatten
//
template <class T>
@ -614,13 +613,15 @@ struct is_flat : true_type {};
// A tuple is flat exactly when none of its element types is itself a tuple.
template <class... Ts>
struct is_flat<tuple<Ts...>> : bool_constant<(not (false || ... || is_tuple<Ts>::value))> {};
// Flatten a hierarchical tuple to a tuple of depth one
// and wrap non-tuples into a rank-1 tuple.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten_to_tuple(T const& t)
{
if constexpr (is_tuple<T>::value) {
if constexpr (is_flat<T>::value) {
if constexpr (is_flat<T>::value) { // Shortcut for perf
return t;
} else {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
@ -632,13 +633,15 @@ flatten_to_tuple(T const& t)
CUTE_GCC_UNREACHABLE;
}
// Flatten a hierarchical tuple to a tuple of depth one
// and leave non-tuples untouched.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten(T const& t)
{
if constexpr (is_tuple<T>::value) {
if constexpr (is_flat<T>::value) {
if constexpr (is_flat<T>::value) { // Shortcut for perf
return t;
} else {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
@ -650,6 +653,43 @@ flatten(T const& t)
CUTE_GCC_UNREACHABLE;
}
namespace detail {

// Recursive worker behind unflatten.
// Walks target_profile and pulls elements off the front of flat_tuple as it goes.
// Returns a 2-tuple: (piece rebuilt for target_profile, elements of flat_tuple not yet used).
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
  if constexpr (not is_tuple<TargetProfile>::value) {
    // Profile leaf: the front element is the result, everything after it is handed back.
    constexpr int flat_rank = decltype(rank(flat_tuple))::value;
    return cute::make_tuple(get<0>(flat_tuple), take<1, flat_rank>(flat_tuple));
  } else {
    // Profile tuple: the accumulator is (pieces built so far, elements still pending).
    // Each step rebuilds one profile element from the pending elements and appends it.
    auto consume_one = [](auto const& state, auto const& sub_profile) {
      auto [built, pending]  = state;
      auto [piece, leftover] = unflatten_impl(pending, sub_profile);
      return cute::make_tuple(append(built, piece), leftover);
    };
    // Seed: nothing built yet, the whole flat tuple pending.
    return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), consume_one);
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
// Rebuild a hierarchical tuple from a flat tuple, using target_profile as the shape to follow.
// It is a compile-time error if elements of flat_tuple are left over afterwards.
// @pre flatten(@a flat_tuple) == @a flat_tuple
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
  // unpacked = (rebuilt hierarchical tuple, unconsumed remainder of flat_tuple)
  auto unpacked = detail::unflatten_impl(flat_tuple, target_profile);
  CUTE_STATIC_ASSERT_V(rank(get<1>(unpacked)) == Int<0>{});
  return get<0>(unpacked);
}
//
// insert and remove and replace
//
@ -728,6 +768,18 @@ replace_back(T const& t, X const& x)
// Make a tuple of Xs of tuple_size N
//
template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
tuple_repeat(X const& x)
{
  // Only the middle index sequence is populated (N entries); the two outer ones are empty.
  // NOTE(review): presumably detail::construct emits one copy of x per middle index and
  // takes nothing from its first argument when the outer sequences are empty -- confirm.
  constexpr auto none_before = seq<>{};
  constexpr auto none_after  = seq<>{};
  return detail::construct(0, x, none_before, make_seq<N>{}, none_after);
}
//
// Make repeated Xs of rank N
//
template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
@ -743,7 +795,7 @@ repeat(X const& x)
}
//
// Make a tuple of Xs the same profile as tuple
// Make a tuple of Xs the same profile as tuple T
//
template <class T, class X>
@ -864,48 +916,6 @@ prepend(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
//
// Unflatten a flat tuple into a hierarchical one
// unflatten(flatten(x), x) == x
//
namespace detail {

// Recursive worker behind unflatten.
// Walks target_profile and pulls elements off the front of flat_tuple as it goes.
// Returns a 2-tuple: (piece rebuilt for target_profile, elements of flat_tuple not yet used).
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
  if constexpr (not is_tuple<TargetProfile>::value) {
    // Profile leaf: the front element is the result, everything after it is handed back.
    constexpr int flat_rank = decltype(rank(flat_tuple))::value;
    return cute::make_tuple(get<0>(flat_tuple), take<1, flat_rank>(flat_tuple));
  } else {
    // Profile tuple: the accumulator is (pieces built so far, elements still pending).
    // Each step rebuilds one profile element from the pending elements and appends it.
    auto consume_one = [](auto const& state, auto const& sub_profile) {
      auto [built, pending]  = state;
      auto [piece, leftover] = unflatten_impl(pending, sub_profile);
      return cute::make_tuple(append(built, piece), leftover);
    };
    // Seed: nothing built yet, the whole flat tuple pending.
    return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), consume_one);
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
// Rebuild a hierarchical tuple from a flat tuple, using target_profile as the shape to follow.
// It is a compile-time error if elements of flat_tuple are left over afterwards.
// @pre flatten(@a flat_tuple) == @a flat_tuple
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
  // unpacked = (rebuilt hierarchical tuple, unconsumed remainder of flat_tuple)
  auto unpacked = detail::unflatten_impl(flat_tuple, target_profile);
  CUTE_STATIC_ASSERT_V(rank(get<1>(unpacked)) == Int<0>{});
  return get<0>(unpacked);
}
//
// Inclusive scan (prefix sum)
//