CUTLASS 3.5.0 (#1411)
This commit is contained in:
@ -204,36 +204,6 @@ for_each_leaf(T&& t, F&& f)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// For Sequence
|
||||
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int... I, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(seq<I...> const&, F&& f) {
|
||||
(f(Int<I>{}), ...);
|
||||
}
|
||||
|
||||
}; // end namespace detail
|
||||
|
||||
template <int... I, class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(seq<I...> const& s, T&& t, F&& f) {
|
||||
detail::for_sequence(s, [&](auto&& i){ f(get<remove_cvref_t<decltype(i)>::value>(static_cast<T&&>(t))); });
|
||||
}
|
||||
|
||||
template <int I, class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(T&& t, F&& f) {
|
||||
for_sequence(make_seq<I>{}, static_cast<T&&>(t), static_cast<F&&>(f));
|
||||
}
|
||||
|
||||
//
|
||||
// Transform
|
||||
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
|
||||
@ -551,15 +521,15 @@ take(T const& t)
|
||||
template <int... I, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const & t)
|
||||
select(T const& t)
|
||||
{
|
||||
return cute::make_tuple(get<I>(t)...);
|
||||
}
|
||||
|
||||
template <class T, typename Indices>
|
||||
template <class T, class Indices>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const & t, Indices const & indices)
|
||||
select(T const& t, Indices const& indices)
|
||||
{
|
||||
if constexpr (is_tuple<Indices>::value) {
|
||||
return cute::transform(indices, [&t](auto i) { return select(t, i); });
|
||||
@ -655,7 +625,7 @@ flatten(T const& t)
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
template <class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
@ -680,7 +650,7 @@ unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
|
||||
// @post congruent(@a result, @a target_profile)
|
||||
// @post flatten(@a result) == @a flat_tuple
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
template <class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
@ -865,6 +835,7 @@ append(T const& a, X const& x)
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -902,6 +873,7 @@ prepend(T const& a, X const& x)
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -1105,14 +1077,13 @@ zip2_by(T const& t, TG const& guide)
|
||||
|
||||
/// @return A tuple of the elements of @c t in reverse order.
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
reverse(T const& t) {
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
reverse(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(t, [] (auto const&... a) {
|
||||
return cute::make_tuple(a...);
|
||||
}, tuple_rseq<T>{});
|
||||
}
|
||||
else {
|
||||
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user