CUTLASS 3.2.1 (#1113)

* Updates for 3.2.1 release.

* Minor fix in gemm op profiler for raster order.

* Add scheduler mapping for raster order in the kernels.
This commit is contained in:
ANIKET SHIVAM
2023-09-26 14:24:26 -07:00
committed by GitHub
parent e0aaa3c3b3
commit 90d3b0fb18
428 changed files with 22253 additions and 21762 deletions

View File

@ -140,7 +140,11 @@ CUTE_HOST_DEVICE constexpr
auto
transform_apply(T&& t, F&& f, G&& g)
{
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
} else {
return g(f(static_cast<T&&>(t)));
}
}
template <class T0, class T1, class F, class G>
@ -148,7 +152,11 @@ CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
{
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
} else {
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1)));
}
}
template <class T0, class T1, class T2, class F, class G>
@ -156,7 +164,11 @@ CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
{
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
} else {
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2)));
}
}
//
@ -306,21 +318,16 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f)
namespace detail {
// Overload chosen when the index sequence argument is empty (seq<>).
// Neither t nor f is used in the body; the result is tuple_size<T>::value
// wrapped in a cute::integral_constant<int, ...>, i.e. a compile-time integer
// equal to the number of elements in T.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<>)
{
return cute::integral_constant<int, tuple_size<T>::value>{};
}
template <class T, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<I,Is...>)
{
if constexpr (decltype(f(get<I>(t)))::value) {
return cute::integral_constant<int, I>{};
return cute::C<I>{};
} else
if constexpr (sizeof...(Is) == 0) {
return cute::C<I+1>{};
} else {
return find_if(t, f, seq<Is...>{});
}
@ -338,7 +345,7 @@ find_if(T const& t, F&& f)
if constexpr (is_tuple<T>::value) {
return detail::find_if(t, f, tuple_seq<T>{});
} else {
return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
return cute::C<decltype(f(t))::value ? 0 : 1>{};
}
CUTE_GCC_UNREACHABLE;
@ -355,12 +362,12 @@ find(T const& t, X const& x)
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
none_of(T const& t, F&& f)
any_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return cute::integral_constant<bool, decltype(find_if(t, f))::value == tuple_size<T>::value>{};
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
} else {
return not f(t);
return f(t);
}
CUTE_GCC_UNREACHABLE;
@ -372,8 +379,7 @@ auto
all_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
auto not_f = [&](auto const& a) { return not f(a); };
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == tuple_size<T>::value>{};
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
} else {
return f(t);
}
@ -384,9 +390,9 @@ all_of(T const& t, F&& f)
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
any_of(T const& t, F&& f)
none_of(T const& t, F&& f)
{
return not none_of(t, f);
return not any_of(t, f);
}
//
@ -410,6 +416,14 @@ filter_tuple(T0 const& t0, T1 const& t1, F&& f)
return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}
// Three-input overload of filter_tuple.
// Hands t0, t1, t2 and f to transform_apply and supplies cute::tuple_cat as
// the combining step, so whatever f produces is joined into a single tuple.
// Each value returned by f must therefore be acceptable to cute::tuple_cat.
template <class T0, class T1, class T2, class F>
CUTE_HOST_DEVICE constexpr
auto
filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
{
  // Combiner: concatenate every piece produced by f.
  auto concat = [](auto const&... pieces) { return cute::tuple_cat(pieces...); };
  return transform_apply(t0, t1, t2, f, concat);
}
//
// Fold (Reduce, Accumulate)
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
@ -595,6 +609,13 @@ unwrap(T const& t)
//
// Flatten a hierarchical tuple to a tuple of depth one.
//
//
// is_flat<T>: compile-time predicate.
//   - Primary template: true for every T.
//   - tuple<Ts...> specialization: true only when none of the Ts satisfies
//     is_tuple (an empty tuple is therefore flat).
template <class T>
struct is_flat : true_type {};

template <class... Ts>
struct is_flat<tuple<Ts...>> : bool_constant<not (false || ... || is_tuple<Ts>::value)> {};
template <class T>
CUTE_HOST_DEVICE constexpr
@ -602,7 +623,12 @@ auto
flatten_to_tuple(T const& t)
{
if constexpr (is_tuple<T>::value) {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
if constexpr (is_flat<T>::value) {
return t;
} else
{
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
}
} else {
return cute::make_tuple(t);
}
@ -616,7 +642,12 @@ auto
flatten(T const& t)
{
if constexpr (is_tuple<T>::value) {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
if constexpr (is_flat<T>::value) {
return t;
} else
{
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
}
} else {
return t;
}