CUTLASS 3.2 (#1024)

* CUTLASS 3.2
This commit is contained in:
ANIKET SHIVAM
2023-08-07 14:50:32 -10:00
committed by GitHub
parent a0d787b746
commit 4575443d44
392 changed files with 47559 additions and 7940 deletions

View File

@ -112,12 +112,47 @@ transform(Tensor<EngineIn,LayoutIn>& tensor_in, Tensor<EngineOut,LayoutOut>& ten
}
// Accept mutable temporaries
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class UnaryOp>
template <class EngineIn, class LayoutIn,
class EngineOut, class LayoutOut, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& tensor_out, UnaryOp&& op)
{
return transform(tensor_in, tensor_out, std::forward<UnaryOp>(op));
return transform(tensor_in, tensor_out, op);
}
// Similar to std::transform with a binary operation.
// Takes two tensors as input and one tensor as output.
// For every i in [0, size(tensor_in1)), applies op to tensor_in1(i) and
// tensor_in2(i) and assigns the result to tensor_out(i).
//
// The iteration count is taken from tensor_in1 only.
// NOTE(review): presumably tensor_in2 and tensor_out must have at least
// size(tensor_in1) elements -- nothing here checks it; confirm with callers.
//
// op is invoked as an lvalue on purpose: it is called once per element, and
// re-forwarding an rvalue callable on every iteration would call it again
// after it may already have been consumed.
template <class EngineIn1, class LayoutIn1,
          class EngineIn2, class LayoutIn2,
          class EngineOut, class LayoutOut, class BinaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn1,LayoutIn1>& tensor_in1,
          Tensor<EngineIn2,LayoutIn2>& tensor_in2,
          Tensor<EngineOut,LayoutOut>& tensor_out,
          BinaryOp&& op)
{
  CUTE_UNROLL
  for (int i = 0; i < size(tensor_in1); ++i) {
    tensor_out(i) = op(tensor_in1(i), tensor_in2(i));
  }
}
// Overload for mutable temporaries: lets callers pass rvalue tensors
// (e.g. freshly created views) for both inputs and the output.
// Inside this function the parameters are named lvalues, so the call
// below dispatches to the lvalue-reference overload; op is passed along
// as an lvalue as well.
template <class EngineIn1, class LayoutIn1,
          class EngineIn2, class LayoutIn2,
          class EngineOut, class LayoutOut, class BinaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn1,LayoutIn1>&& tensor_in1,
          Tensor<EngineIn2,LayoutIn2>&& tensor_in2,
          Tensor<EngineOut,LayoutOut>&& tensor_out,
          BinaryOp&& op)
{
  transform(tensor_in1, tensor_in2, tensor_out, op);
}
} // end namespace cute

View File

@ -38,11 +38,38 @@
#include <cute/numeric/integer_sequence.hpp>
#include <cute/numeric/integral_constant.hpp>
/** Common algorithms on (hierarchical) tuples */
/** Style choice:
* Forward params [using static_cast<T&&>(.)] for const/non-const/ref/non-ref args
* but don't bother forwarding functions as ref-qualified member fns are extremely rare
*/
/// @file tuple_algorithms.hpp
/// @brief Common algorithms on (hierarchical) tuples
///
/// Code guidelines and style preferences:
///
/// For perfect forwarding, don't use std::forward, because it may not
/// be defined in device code when compiling with NVRTC. Instead, use
/// `static_cast<ParameterType&&>(parameter_name)`.
///
/// CuTe generally does not bother forwarding functions, as
/// reference-qualified member functions are rare in this code base.
///
/// Throughout CUTLASS, cute::make_tuple always needs to be called
/// namespace-qualified, even if inside the cute namespace and/or in
/// scope of a "using namespace cute" declaration. Otherwise, the
/// compiler may select std::make_tuple instead of cute::make_tuple,
/// due to argument-dependent lookup. Two problems may result from
/// that.
///
/// 1. Functions have an unexpected return type (std::tuple instead of
/// cute::tuple), so functions that take cute::tuple parameters
/// fail to compile (generally inside functions that have template
/// parameters expected to be cute::tuple).
///
/// 2. std::tuple does not have the required __host__ __device__
/// markings, so the CUDA compiler complains if you use it in
/// device code.
///
/// cute::make_tuple will occur more often than std::make_tuple would
/// in modern C++ code, because cute::tuple's design deprioritizes
/// correct operation of CTAD (constructor template argument
/// deduction) in favor of implementation simplicity.
namespace cute
{
@ -142,7 +169,13 @@ CUTE_HOST_DEVICE constexpr
void
for_each(T&& t, F&& f)
{
detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
} else {
return f(static_cast<T&&>(t));
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class F>
@ -159,6 +192,36 @@ for_each_leaf(T&& t, F&& f)
CUTE_GCC_UNREACHABLE;
}
//
// For Sequence
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
//
namespace detail {

// Invokes f(Int<I>{}) once for every index I of the given sequence,
// in pack order (left to right). The sequence argument only carries the
// indices in its type; its value is unused.
template <int... I, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const&, F&& f) {
  (f(Int<I>{}), ...);
}

} // end namespace detail
// Calls f(get<K>(t)) for every index K listed in the sequence s,
// in the order the indices appear in s.
template <int... I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const& s, T&& t, F&& f) {
  detail::for_sequence(s, [&](auto&& idx) {
    // idx is a static integer; recover its compile-time value to index t.
    constexpr int K = remove_cvref_t<decltype(idx)>::value;
    f(get<K>(static_cast<T&&>(t)));
  });
}
// Convenience overload: visits the elements of t selected by make_seq<I>
// by delegating to the sequence-taking overload, forwarding both t and f.
template <int I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(T&& t, F&& f)
{
  for_sequence(make_seq<I>{},
               static_cast<T&&>(t),
               static_cast<F&&>(f));
}
//
// Transform
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
@ -169,7 +232,13 @@ CUTE_HOST_DEVICE constexpr
auto
transform(T const& t, F&& f)
{
return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
if constexpr (is_tuple<T>::value) {
return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
template <class T0, class T1, class F>
@ -177,8 +246,14 @@ CUTE_HOST_DEVICE constexpr
auto
transform(T0 const& t0, T1 const& t1, F&& f)
{
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
if constexpr (is_tuple<T0>::value) {
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
} else {
return f(t0, t1);
}
CUTE_GCC_UNREACHABLE;
}
template <class T0, class T1, class T2, class F>
@ -186,9 +261,15 @@ CUTE_HOST_DEVICE constexpr
auto
transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
{
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
if constexpr (is_tuple<T0>::value) {
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
} else {
return f(t0, t1, t2);
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class F>
@ -399,7 +480,7 @@ fold_first(T&& t, F&& f)
}
//
// front, back, take, unwrap
// front, back, take, select, unwrap
//
// Get the first non-tuple element in a hierarchical tuple
@ -425,7 +506,16 @@ back(T&& t)
{
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
constexpr int N = tuple_size<remove_cvref_t<T>>::value;
return back(get<N-1>(static_cast<T&&>(t)));
// MSVC needs a bit of extra help here deducing return types.
// We help it by peeling off the nonrecursive case a level "early."
if constexpr (! is_tuple<remove_cvref_t<decltype(get<N - 1>(static_cast<T&&>(t)))>>::value) {
return get<N - 1>(static_cast<T&&>(t));
}
else {
return back(get<N - 1>(static_cast<T&&>(t)));
}
} else {
return static_cast<T&&>(t);
}
@ -442,6 +532,47 @@ take(T const& t)
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
}
//
// Select tuple elements with given indices.
//
// select<I0,I1,...>(t) => (get<I0>(t), get<I1>(t), ...)
// The result is always built with cute::make_tuple, in the order the
// indices are listed.
template <int... I, class T>
CUTE_HOST_DEVICE constexpr
auto
select(T const& t)
{
  return cute::make_tuple(get<I>(t)...);
}
// Select elements of t using indices given as a value.
//  - If indices is a tuple, select is applied to each of its elements in
//    turn (recursively), so the result mirrors the structure of indices.
//  - Otherwise indices must be a static integer, and the result is
//    get<Indices::value>(t).
template <class T, typename Indices>
CUTE_HOST_DEVICE constexpr
auto
select(T const& t, Indices const& indices)
{
  if constexpr (!is_tuple<Indices>::value) {
    // Leaf case: a single compile-time index.
    static_assert(is_static<Indices>::value, "Order must be static");
    return get<Indices::value>(t);
  }
  else {
    // Tuple of indices: recurse into every entry.
    return cute::transform(indices, [&t](auto idx) { return select(t, idx); });
  }
}
// Ensure the result is a tuple: a non-tuple t is wrapped into a
// one-element tuple via cute::make_tuple; a tuple is returned as-is (by value).
template <class T>
CUTE_HOST_DEVICE constexpr
auto
wrap(T const& t)
{
  if constexpr (!is_tuple<T>::value) {
    return cute::make_tuple(t);
  } else {
    return t;
  }

  CUTE_GCC_UNREACHABLE;
}
// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple
template <class T>
CUTE_HOST_DEVICE constexpr
@ -576,7 +707,11 @@ CUTE_HOST_DEVICE constexpr
auto
repeat(X const& x)
{
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
if constexpr (N == 1) {
return x;
} else {
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
}
}
//
@ -605,7 +740,23 @@ CUTE_HOST_DEVICE constexpr
auto
group(T const& t)
{
return detail::construct(t, take<B,E>(t), make_seq<B>{}, seq<0>{}, make_range<E,tuple_size<T>::value>{});
if constexpr (not is_tuple<T>::value) {
if constexpr (E == -1) {
return group<B,1>(t);
} else {
return detail::construct(t, take<B,E>(t), make_seq<B>{}, make_seq<(B < E)>{}, make_range<E,1>{});
}
} else
if constexpr (E == -1) {
return group<B,tuple_size<T>::value>(t);
} else
if constexpr (B <= E) {
return detail::construct(t, take<B,E>(t), make_seq<B>{}, make_seq<(B < E)>{}, make_range<E,tuple_size<T>::value>{});
} else {
static_assert(B <= E);
}
CUTE_GCC_UNREACHABLE;
}
//
@ -685,6 +836,48 @@ prepend(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
//
// Unflatten a flat tuple into a hierarchical one
// unflatten(x, flatten(x)) == x
//
namespace detail {

// Recursive worker for unflatten.
//
// Returns a 2-tuple (piece, leftover):
//   piece    -- the part of flat_tuple consumed so far, arranged like target_profile
//   leftover -- the elements of flat_tuple that were not consumed
//
// A non-tuple profile consumes exactly the first element of flat_tuple.
// A tuple profile is processed entry by entry with fold, threading the
// (accumulated result, remaining elements) pair through each step.
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
  if constexpr (is_tuple<TargetProfile>::value) {
    return fold(target_profile,
                cute::make_tuple(cute::make_tuple(), flat_tuple),
                [](auto const& state, auto const& sub_profile) {
                  auto [built, unconsumed] = state;
                  // Unflatten one profile entry from the front of what is left.
                  auto [piece, leftover] = unflatten_impl(unconsumed, sub_profile);
                  return cute::make_tuple(append(built, piece), leftover);
                });
  } else {
    // Leaf: take element 0, hand back elements [1, rank) as the remainder.
    return cute::make_tuple(get<0>(flat_tuple),
                            take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
// Rebuild a hierarchical tuple from @a flat_tuple, following the nesting of
// @a target_profile. Every element of @a flat_tuple must be consumed; this
// is enforced at compile time.
//
// @pre flatten(@a flat_tuple) == @a flat_tuple
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
  auto [nested, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
  // Nothing may be left over once the whole profile has been filled.
  CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
  return nested;
}
//
// Inclusive scan (prefix sum)
//
@ -872,4 +1065,18 @@ zip2_by(T const& t, TG const& guide)
CUTE_GCC_UNREACHABLE;
}
/// @brief Reverse the order of the top-level elements of @c t.
/// @return A tuple of the elements of @c t in reverse order. If @c t is
///         not a tuple, a copy of @c t is returned unchanged.
template <class T>
CUTE_HOST_DEVICE constexpr auto
reverse(T const& t) {
  if constexpr (!is_tuple<T>::value) {
    return t;
  }
  else {
    // tuple_rseq<T> supplies the indices for the reversed visit; the
    // selected elements are repacked into a new tuple.
    return detail::apply(t,
                         [] (auto const&... elems) { return cute::make_tuple(elems...); },
                         tuple_rseq<T>{});
  }
}
} // end namespace cute