@ -32,11 +32,11 @@
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
/** Common algorithms on (hierarchical) tuples */
|
||||
/** Style choice:
|
||||
@ -150,7 +150,7 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
for_each_leaf(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::apply(static_cast<T&&>(t), [&](auto&&... a){ return (for_each_leaf(static_cast<decltype(a)&&>(a), f), ...); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<T&&>(t));
|
||||
@ -205,6 +205,20 @@ transform_leaf(T const& t, F&& f)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_leaf(T0 const& t0, T1 const& t1, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T0>::value) {
|
||||
return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); });
|
||||
} else {
|
||||
return f(t0, t1);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// find and find_if
|
||||
//
|
||||
@ -258,25 +272,40 @@ find(T const& t, X const& x)
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
none_of(T const& t, F&& f)
|
||||
{
|
||||
return cute::integral_constant<bool, decltype(find_if(t, f))::value == std::tuple_size<T>::value>{};
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return cute::integral_constant<bool, decltype(find_if(t, f))::value == tuple_size<T>::value>{};
|
||||
} else {
|
||||
return not f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
auto not_f = [&](auto const& a) { return !f(a); };
|
||||
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == std::tuple_size<T>::value>{};
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
auto not_f = [&](auto const& a) { return not f(a); };
|
||||
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == tuple_size<T>::value>{};
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
any_of(T const& t, F&& f)
|
||||
{
|
||||
return cute::integral_constant<bool, !decltype(none_of(t, f))::value>{};
|
||||
return not none_of(t, f);
|
||||
}
|
||||
|
||||
//
|
||||
@ -340,7 +369,7 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
static_cast<V&&>(v),
|
||||
f,
|
||||
@ -357,11 +386,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold_first(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
get<0>(static_cast<T&&>(t)),
|
||||
f,
|
||||
make_range<1,std::tuple_size<std::remove_reference_t<T>>::value>{});
|
||||
make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
@ -753,12 +782,12 @@ escan(T const& t, V const& v, F&& f)
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int J, class T, int... Is>
|
||||
template <int J, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip_(T const& t, seq<Is...>)
|
||||
zip_(Ts const&... ts)
|
||||
{
|
||||
return cute::make_tuple(get<J>(get<Is>(t))...);
|
||||
return cute::make_tuple(get<J>(ts)...);
|
||||
}
|
||||
|
||||
template <class T, int... Is, int... Js>
|
||||
@ -767,7 +796,7 @@ auto
|
||||
zip(T const& t, seq<Is...>, seq<Js...>)
|
||||
{
|
||||
static_assert(conjunction<bool_constant<tuple_size<tuple_element_t<0,T>>::value == tuple_size<tuple_element_t<Is,T>>::value>...>::value, "Mismatched Ranks");
|
||||
return cute::make_tuple(detail::zip_<Js>(t, seq<Is...>{})...);
|
||||
return cute::make_tuple(zip_<Js>(get<Is>(t)...)...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
@ -817,8 +846,8 @@ zip2_by(T const& t, TG const& guide, seq<Is...>, seq<Js...>)
|
||||
auto split = cute::make_tuple(zip2_by(get<Is>(t), get<Is>(guide))...);
|
||||
|
||||
// Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y))
|
||||
return cute::make_tuple(cute::make_tuple(get<Is,0>(split)...),
|
||||
cute::make_tuple(get<Is,1>(split)..., get<Js>(t)...));
|
||||
return cute::make_tuple(cute::make_tuple(get<0>(get<Is>(split))...),
|
||||
cute::make_tuple(get<1>(get<Is>(split))..., get<Js>(t)...));
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
Reference in New Issue
Block a user