CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release. * Minor fix in gemm op profiler for raster order. * Add scheduler mapping for raster order in the kernels.
This commit is contained in:
@ -140,7 +140,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T&& t, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
|
||||
} else {
|
||||
return g(f(static_cast<T&&>(t)));
|
||||
}
|
||||
}
|
||||
|
||||
// Two-input transform_apply: f is invoked on the inputs and g consumes what f produces.
// t0 and t1 are forwarding references and are passed on with their original value
// category via static_cast<T&&>; f and g are passed on as lvalues.
// This rendered hunk shows both the pre-commit and the post-commit body.
template <class T0, class T1, class F, class G>
|
||||
@ -148,7 +152,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
||||
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
|
||||
{
|
||||
// Pre-commit body (the line this hunk removes): always dispatches to detail::tapply,
// whatever T0 is.
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
|
||||
// Post-commit body (the lines this hunk adds): dispatch on whether T0, with cv/ref
// qualifiers stripped, is a tuple. Only T0 is inspected; T1 is not checked.
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
|
||||
// Tuple case: indices come from tuple_seq<T0>.
// NOTE(review): detail::tapply is not shown here; presumably it evaluates
// g(f(get<I>(t0), get<I>(t1))...) over those indices -- confirm.
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
|
||||
} else {
|
||||
// Non-tuple case: call f once on the two arguments and hand the single result to g.
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1)));
|
||||
}
|
||||
}
|
||||
|
||||
// Three-input transform_apply: f is invoked on the inputs and g consumes what f produces.
// t0, t1 and t2 are forwarding references and are passed on with their original value
// category via static_cast<T&&>; f and g are passed on as lvalues.
// This rendered hunk shows both the pre-commit and the post-commit body.
template <class T0, class T1, class T2, class F, class G>
|
||||
@ -156,7 +164,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
||||
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
|
||||
{
|
||||
// Pre-commit body (the line this hunk removes): always dispatches to detail::tapply,
// whatever T0 is.
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
|
||||
// Post-commit body (the lines this hunk adds): dispatch on whether T0, with cv/ref
// qualifiers stripped, is a tuple. Only T0 is inspected; T1 and T2 are not checked.
if constexpr (is_tuple<remove_cvref_t<T0>>::value) {
|
||||
// Tuple case: indices come from tuple_seq<T0>.
// NOTE(review): detail::tapply is not shown here; presumably it evaluates
// g(f(get<I>(t0), get<I>(t1), get<I>(t2))...) over those indices -- confirm.
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
|
||||
} else {
|
||||
// Non-tuple case: call f once on the three arguments and hand the single result to g.
return g(f(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2)));
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -306,21 +318,16 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f)
|
||||
|
||||
namespace detail {
|
||||
|
||||
// detail::find_if overload for an empty index sequence (seq<>): no indices are left to
// test, so it returns tuple_size<T>::value as a compile-time integral_constant<int, ...>.
// Neither t nor f is used by this overload.
// NOTE(review): going by the hunk's line counts, this overload is removed by the commit;
// the seq<I,Is...> overload below gains a `sizeof...(Is) == 0` branch returning
// cute::C<I+1>, which appears to take over the termination case.
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
||||
find_if(T const& t, F&& f, seq<>)
|
||||
{
|
||||
return cute::integral_constant<int, tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f, seq<I,Is...>)
|
||||
{
|
||||
if constexpr (decltype(f(get<I>(t)))::value) {
|
||||
return cute::integral_constant<int, I>{};
|
||||
return cute::C<I>{};
|
||||
} else
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return cute::C<I+1>{};
|
||||
} else {
|
||||
return find_if(t, f, seq<Is...>{});
|
||||
}
|
||||
@ -338,7 +345,7 @@ find_if(T const& t, F&& f)
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::find_if(t, f, tuple_seq<T>{});
|
||||
} else {
|
||||
return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
|
||||
return cute::C<decltype(f(t))::value ? 0 : 1>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@ -355,12 +362,12 @@ find(T const& t, X const& x)
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
none_of(T const& t, F&& f)
|
||||
any_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return cute::integral_constant<bool, decltype(find_if(t, f))::value == tuple_size<T>::value>{};
|
||||
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return not f(t);
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@ -372,8 +379,7 @@ auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
auto not_f = [&](auto const& a) { return not f(a); };
|
||||
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == tuple_size<T>::value>{};
|
||||
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
@ -384,9 +390,9 @@ all_of(T const& t, F&& f)
|
||||
// Thin negation wrapper. This rendered hunk shows both sides of the change:
// the function name and its return statement each appear twice, old line first.
//   pre-commit : any_of(t, f)  is defined as  not none_of(t, f)
//   post-commit: none_of(t, f) is defined as  not any_of(t, f)
// In both versions t and f are passed through unchanged to the other predicate.
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
||||
// Pre-commit signature.
any_of(T const& t, F&& f)
|
||||
// Post-commit signature.
none_of(T const& t, F&& f)
|
||||
{
|
||||
// Pre-commit body.
return not none_of(t, f);
|
||||
// Post-commit body.
return not any_of(t, f);
|
||||
}
|
||||
|
||||
//
|
||||
@ -410,6 +416,14 @@ filter_tuple(T0 const& t0, T1 const& t1, F&& f)
|
||||
return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
|
||||
}
|
||||
|
||||
// Three-input overload of filter_tuple (added by this commit).
// Forwards t0, t1, t2 and f to transform_apply and supplies a combiner that
// concatenates everything it receives with cute::tuple_cat.
// NOTE(review): presumably f must return a tuple (possibly empty) for each triple of
// elements so that tuple_cat can join the results -- confirm against callers.
template <class T0, class T1, class T2, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
||||
filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
|
||||
{
|
||||
// The generic lambda receives the results produced via f and joins them into one tuple.
return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); });
|
||||
}
|
||||
|
||||
//
|
||||
// Fold (Reduce, Accumulate)
|
||||
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
|
||||
@ -595,6 +609,13 @@ unwrap(T const& t)
|
||||
//
|
||||
// Flatten a hierarchical tuple to a tuple of depth one.
|
||||
//
|
||||
//
|
||||
|
||||
// is_flat<T>: compile-time trait. Primary template: any type not matched by the
// specialisation below is reported as flat (derives from true_type).
template <class T>
|
||||
struct is_flat : true_type {};
|
||||
|
||||
// Specialisation for tuple<Ts...>: flat only when none of the element types Ts is itself
// a tuple according to is_tuple. The fold starts from `true`, so an empty tuple is flat.
// Only the direct element types are inspected.
template <class... Ts>
|
||||
struct is_flat<tuple<Ts...>> : bool_constant<(true && ... && (not is_tuple<Ts>::value))> {};
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -602,7 +623,12 @@ auto
|
||||
flatten_to_tuple(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
if constexpr (is_flat<T>::value) {
|
||||
return t;
|
||||
} else
|
||||
{
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
}
|
||||
} else {
|
||||
return cute::make_tuple(t);
|
||||
}
|
||||
@ -616,7 +642,12 @@ auto
|
||||
flatten(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
if constexpr (is_flat<T>::value) {
|
||||
return t;
|
||||
} else
|
||||
{
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
}
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user