CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
This commit is contained in:
Vijay Thakkar
2024-07-29 08:46:24 -04:00
committed by GitHub
parent 56b46e2d13
commit be60a0b272
312 changed files with 19793 additions and 6775 deletions

View File

@ -404,29 +404,54 @@ namespace detail {
// This impl compiles much faster than cute::apply and variadic args
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold(T&& t, V&& v, F&& f, seq<>)
auto
fold(T&&, V&& v, F&&, seq<>)
{
return static_cast<V&&>(v);
return v;
}
template <class T, class V, class F, int I, int... Is>
template <class T, class V, class F, int I0>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold(T&& t, V&& v, F&& f, seq<I,Is...>)
auto
fold(T&& t, V&& v, F&& f, seq<I0>)
{
if constexpr (sizeof...(Is) == 0) {
return f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t)));
} else {
return fold(static_cast<T&&>(t),
f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t))),
f,
seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
return f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t)));
}
template <class T, class V, class F, int I0, int I1>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1>)
{
return f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t)));
}
template <class T, class V, class F, int I0, int I1, int I2>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2>)
{
return f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t)));
}
template <class T, class V, class F, int I0, int I1, int I2, int I3>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3>)
{
return f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t)));
}
template <class T, class V, class F, int I0, int I1, int I2, int I3, int... Is>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3,Is...>)
{
return fold(static_cast<T&&>(t),
f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t))),
f,
seq<Is...>{});
}
} // end namespace detail
template <class T, class V, class F>
@ -448,7 +473,7 @@ fold(T&& t, V&& v, F&& f)
template <class T, class F>
CUTE_HOST_DEVICE constexpr
decltype(auto)
auto
fold_first(T&& t, F&& f)
{
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
@ -457,7 +482,7 @@ fold_first(T&& t, F&& f)
f,
make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
} else {
return static_cast<T&&>(t);
return t;
}
CUTE_GCC_UNREACHABLE;
@ -701,7 +726,14 @@ CUTE_HOST_DEVICE constexpr
auto
replace(T const& t, X const& x)
{
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
if constexpr (is_tuple<T>::value) {
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
} else {
static_assert(N == 0);
return x;
}
CUTE_GCC_UNREACHABLE;
}
// Replace the first element of the tuple with x
@ -1077,9 +1109,9 @@ zip2_by(T const& t, TG const& guide)
/// @return A tuple of the elements of @c t in reverse order.
template <class T>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE constexpr
auto
reverse(T const& t)
reverse(T const& t)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});