CUTLASS 3.5.0 (#1411)

This commit is contained in:
Vijay Thakkar
2024-03-19 17:51:04 -04:00
committed by GitHub
parent ffa34e7075
commit 629f4653c3
468 changed files with 48730 additions and 7253 deletions

View File

@ -204,36 +204,6 @@ for_each_leaf(T&& t, F&& f)
CUTE_GCC_UNREACHABLE;
}
//
// For Sequence
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
//
namespace detail {

// Invoke f once for every index I of the sequence, in pack order,
// passing the index as a static integer object Int<I>{}.
//
// The seq argument is left unnamed: it is used only to deduce the index pack I...
// The result of each call to f is discarded.
template <int... I, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const&, F&& f) {
  // Cast each call to void so the fold always uses the built-in comma operator,
  // even when f returns a type that overloads operator,.
  (static_cast<void>(f(Int<I>{})), ...);
}

} // end namespace detail
// Visit the elements of t selected by the index sequence s, in sequence order:
// f is called on get<I>(t) for each I in s and its result is discarded.
template <int... I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const& s, T&& t, F&& f) {
  // Receives a static index object, recovers its compile-time value, and
  // applies f to the corresponding element of t (t keeps its value category).
  auto visit_element = [&](auto&& index) {
    constexpr auto position = remove_cvref_t<decltype(index)>::value;
    f(get<position>(static_cast<T&&>(t)));
  };
  detail::for_sequence(s, visit_element);
}
// Convenience overload taking the sequence length I as an explicit template argument:
// builds make_seq<I> and forwards t and f to the sequence-based overload.
template <int I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(T&& t, F&& f) {
  using index_sequence = make_seq<I>;
  for_sequence(index_sequence{}, static_cast<T&&>(t), static_cast<F&&>(f));
}
//
// Transform
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
@ -551,15 +521,15 @@ take(T const& t)
template <int... I, class T>
CUTE_HOST_DEVICE constexpr
auto
select(T const & t)
select(T const& t)
{
return cute::make_tuple(get<I>(t)...);
}
template <class T, typename Indices>
template <class T, class Indices>
CUTE_HOST_DEVICE constexpr
auto
select(T const & t, Indices const & indices)
select(T const& t, Indices const& indices)
{
if constexpr (is_tuple<Indices>::value) {
return cute::transform(indices, [&t](auto i) { return select(t, i); });
@ -655,7 +625,7 @@ flatten(T const& t)
namespace detail {
template<class FlatTuple, class TargetProfile>
template <class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
@ -680,7 +650,7 @@ unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
template <class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
@ -865,6 +835,7 @@ append(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
@ -902,6 +873,7 @@ prepend(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
@ -1105,14 +1077,13 @@ zip2_by(T const& t, TG const& guide)
/// @return A tuple of the elements of @c t in reverse order.
template <class T>
CUTE_HOST_DEVICE constexpr auto
reverse(T const& t) {
CUTE_HOST_DEVICE constexpr
auto
reverse(T const& t)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [] (auto const&... a) {
return cute::make_tuple(a...);
}, tuple_rseq<T>{});
}
else {
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});
} else {
return t;
}
}