Updates for 3.4 release. (#1305)
This commit is contained in:
@ -604,8 +604,7 @@ unwrap(T const& t)
|
||||
}
|
||||
|
||||
//
|
||||
// Flatten a hierarchical tuple to a tuple of depth one.
|
||||
//
|
||||
// Flatten and Unflatten
|
||||
//
|
||||
|
||||
template <class T>
|
||||
@ -614,13 +613,15 @@ struct is_flat : true_type {};
|
||||
template <class... Ts>
|
||||
struct is_flat<tuple<Ts...>> : bool_constant<(true && ... && (not is_tuple<Ts>::value))> {};
|
||||
|
||||
// Flatten a hierarchical tuple to a tuple of depth one
|
||||
// and wrap non-tuples into a rank-1 tuple.
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flatten_to_tuple(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (is_flat<T>::value) {
|
||||
if constexpr (is_flat<T>::value) { // Shortcut for perf
|
||||
return t;
|
||||
} else {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
@ -632,13 +633,15 @@ flatten_to_tuple(T const& t)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Flatten a hierarchical tuple to a tuple of depth one
|
||||
// and leave non-tuples untouched.
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flatten(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (is_flat<T>::value) {
|
||||
if constexpr (is_flat<T>::value) { // Shortcut for perf
|
||||
return t;
|
||||
} else {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
@ -650,6 +653,43 @@ flatten(T const& t)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
{
|
||||
if constexpr (is_tuple<TargetProfile>::value) {
|
||||
return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
|
||||
auto [result, remaining_tuple] = v;
|
||||
auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
|
||||
return cute::make_tuple(append(result, sub_result), sub_tuple);
|
||||
});
|
||||
} else {
|
||||
return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Unflatten a flat tuple into a hierarchical tuple
|
||||
// @pre flatten(@a flat_tuple) == @a flat_tuple
|
||||
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
|
||||
// @post congruent(@a result, @a target_profile)
|
||||
// @post flatten(@a result) == @a flat_tuple
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
{
|
||||
auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
|
||||
CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
|
||||
return unflatten_tuple;
|
||||
}
|
||||
|
||||
//
|
||||
// insert and remove and replace
|
||||
//
|
||||
@ -728,6 +768,18 @@ replace_back(T const& t, X const& x)
|
||||
// Make a tuple of Xs of tuple_size N
|
||||
//
|
||||
|
||||
template <int N, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_repeat(X const& x)
|
||||
{
|
||||
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
|
||||
}
|
||||
|
||||
//
|
||||
// Make repeated Xs of rank N
|
||||
//
|
||||
|
||||
template <int N, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -743,7 +795,7 @@ repeat(X const& x)
|
||||
}
|
||||
|
||||
//
|
||||
// Make a tuple of Xs the same profile as tuple
|
||||
// Make a tuple of Xs the same profile as tuple T
|
||||
//
|
||||
|
||||
template <class T, class X>
|
||||
@ -864,48 +916,6 @@ prepend(T const& a, X const& x)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Unflatten a flat tuple into a hierarchical one
|
||||
// unflatten(flatten(x), x) == x
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
{
|
||||
if constexpr (is_tuple<TargetProfile>::value) {
|
||||
return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
|
||||
auto [result, remaining_tuple] = v;
|
||||
auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
|
||||
return cute::make_tuple(append(result, sub_result), sub_tuple);
|
||||
});
|
||||
} else {
|
||||
return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// @pre flatten(@a flat_tuple) == @a flat_tuple
|
||||
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
|
||||
// @post congruent(@a result, @a target_profile)
|
||||
// @post flatten(@a result) == @a flat_tuple
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
{
|
||||
auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
|
||||
CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
|
||||
return unflatten_tuple;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Inclusive scan (prefix sum)
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user