CUTLASS 3.1 (#915)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2023-04-14 20:19:34 -07:00
committed by GitHub
parent 9b8166e3f0
commit d572cc1aab
482 changed files with 37184 additions and 16419 deletions

View File

@ -32,11 +32,11 @@
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/container/tuple.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/numeric/integer_sequence.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/util/type_traits.hpp>
/** Common algorithms on (hierarchical) tuples */
/** Style choice:
@ -150,7 +150,7 @@ CUTE_HOST_DEVICE constexpr
auto
for_each_leaf(T&& t, F&& f)
{
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::apply(static_cast<T&&>(t), [&](auto&&... a){ return (for_each_leaf(static_cast<decltype(a)&&>(a), f), ...); }, tuple_seq<T>{});
} else {
return f(static_cast<T&&>(t));
@ -205,6 +205,20 @@ transform_leaf(T const& t, F&& f)
CUTE_GCC_UNREACHABLE;
}
/** Two-input variant of transform_leaf.
 *  Only T0 is inspected to pick the branch:
 *   - T0 is not a tuple: the result is f(t0, t1).
 *   - T0 is a tuple:     t0 and t1 are handed to transform() together with a
 *                        callback that calls transform_leaf again on each pair
 *                        of arguments it receives, reusing the same f.
 */
template <class T0, class T1, class F>
CUTE_HOST_DEVICE constexpr
auto
transform_leaf(T0 const& t0, T1 const& t1, F&& f)
{
  if constexpr (not is_tuple<T0>::value) {
    // Leaf case: apply the callable directly to the two inputs.
    return f(t0, t1);
  } else {
    // Recursive case: f is captured by reference so every level uses the caller's callable.
    return transform(t0, t1,
                     [&](auto const& lhs, auto const& rhs) {
                       return transform_leaf(lhs, rhs, f);
                     });
  }

  CUTE_GCC_UNREACHABLE;
}
//
// find and find_if
//
@ -258,25 +272,40 @@ find(T const& t, X const& x)
}
// none_of(t, f)
//   Tuple branch:     yields a cute::integral_constant<bool, B>, where B is true when the
//                     ::value of find_if(t, f)'s result type equals tuple_size<T>::value
//                     (presumably find_if's "no match" index -- confirm against find_if).
//   Non-tuple branch: yields `not f(t)`.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
none_of(T const& t, F&& f)
{
// NOTE(review): this listing looks like a rendered diff with its +/- markers stripped.
// The return directly below (std::tuple_size, no is_tuple guard) appears to be the
// removed pre-change line; read literally it would return before the if constexpr
// that follows. Confirm against the actual header before relying on either form.
return cute::integral_constant<bool, decltype(find_if(t, f))::value == std::tuple_size<T>::value>{};
// Apparent post-change body: guards the tuple_size query behind is_tuple<T>.
if constexpr (is_tuple<T>::value) {
return cute::integral_constant<bool, decltype(find_if(t, f))::value == tuple_size<T>::value>{};
} else {
// Non-tuple input: the callable is applied to t itself and the result is negated.
return not f(t);
}
CUTE_GCC_UNREACHABLE;
}
// all_of(t, f)
//   Tuple branch:     wraps f in a negating lambda, runs find_if with it, and yields a
//                     cute::integral_constant<bool, B>, where B is true when the ::value of the
//                     find_if result type equals tuple_size<T>::value (presumably meaning no
//                     element made the negated callable true -- confirm against find_if).
//   Non-tuple branch: yields f(t).
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
all_of(T const& t, F&& f)
{
// NOTE(review): this listing looks like a rendered diff with its +/- markers stripped.
// The next two lines (`!f(a)` and std::tuple_size, no is_tuple guard) appear to be the
// removed pre-change lines; read literally they would return before the if constexpr
// that follows. Confirm against the actual header.
auto not_f = [&](auto const& a) { return !f(a); };
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == std::tuple_size<T>::value>{};
// Apparent post-change body: guards the tuple_size query behind is_tuple<T>.
if constexpr (is_tuple<T>::value) {
// Negated callable; f is captured by reference.
auto not_f = [&](auto const& a) { return not f(a); };
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == tuple_size<T>::value>{};
} else {
// Non-tuple input: the callable is applied to t itself.
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
// any_of(t, f)
//   Expressed as the negation of none_of(t, f).
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
any_of(T const& t, F&& f)
{
// NOTE(review): this listing looks like a rendered diff with its +/- markers stripped.
// The first return (which reads ::value from the type of none_of's result) appears to be
// the removed pre-change line, and the second return its replacement. Confirm against
// the actual header.
return cute::integral_constant<bool, !decltype(none_of(t, f))::value>{};
// Apparent post-change form: applies `not` to the value returned by none_of(t, f).
return not none_of(t, f);
}
//
@ -340,7 +369,7 @@ CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f)
{
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::fold(static_cast<T&&>(t),
static_cast<V&&>(v),
f,
@ -357,11 +386,11 @@ CUTE_HOST_DEVICE constexpr
decltype(auto)
fold_first(T&& t, F&& f)
{
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::fold(static_cast<T&&>(t),
get<0>(static_cast<T&&>(t)),
f,
make_range<1,std::tuple_size<std::remove_reference_t<T>>::value>{});
make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
} else {
return static_cast<T&&>(t);
}
@ -753,12 +782,12 @@ escan(T const& t, V const& v, F&& f)
namespace detail {
// zip_<J>
//   Builds a tuple with cute::make_tuple from get<J>(...) applied to each input.
// NOTE(review): this listing looks like a rendered diff with its +/- markers stripped, so
// two versions are superimposed below. In each pair of lines the first appears to be
// the removed pre-change form and the second its replacement:
//   - first form:  takes one object t plus seq<Is...> and uses get<J>(get<Is>(t))...
//   - second form: takes a parameter pack ts... and uses get<J>(ts)...
// Confirm against the actual header before relying on either signature.
template <int J, class T, int... Is>
template <int J, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
zip_(T const& t, seq<Is...>)
zip_(Ts const&... ts)
{
return cute::make_tuple(get<J>(get<Is>(t))...);
return cute::make_tuple(get<J>(ts)...);
}
template <class T, int... Is, int... Js>
@ -767,7 +796,7 @@ auto
zip(T const& t, seq<Is...>, seq<Js...>)
{
static_assert(conjunction<bool_constant<tuple_size<tuple_element_t<0,T>>::value == tuple_size<tuple_element_t<Is,T>>::value>...>::value, "Mismatched Ranks");
return cute::make_tuple(detail::zip_<Js>(t, seq<Is...>{})...);
return cute::make_tuple(zip_<Js>(get<Is>(t)...)...);
}
} // end namespace detail
@ -817,8 +846,8 @@ zip2_by(T const& t, TG const& guide, seq<Is...>, seq<Js...>)
auto split = cute::make_tuple(zip2_by(get<Is>(t), get<Is>(guide))...);
// Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y))
return cute::make_tuple(cute::make_tuple(get<Is,0>(split)...),
cute::make_tuple(get<Is,1>(split)..., get<Js>(t)...));
return cute::make_tuple(cute::make_tuple(get<0>(get<Is>(split))...),
cute::make_tuple(get<1>(get<Is>(split))..., get<Js>(t)...));
}
} // end namespace detail