@ -112,12 +112,47 @@ transform(Tensor<EngineIn,LayoutIn>& tensor_in, Tensor<EngineOut,LayoutOut>& ten
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class UnaryOp>
|
||||
template <class EngineIn, class LayoutIn,
|
||||
class EngineOut, class LayoutOut, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& tensor_out, UnaryOp&& op)
|
||||
{
|
||||
return transform(tensor_in, tensor_out, std::forward<UnaryOp>(op));
|
||||
return transform(tensor_in, tensor_out, op);
|
||||
}
|
||||
|
||||
// Similar to std::transform with a binary operation
|
||||
// Takes two tensors as input and one tensor as output.
|
||||
// Applies the binary_op to tensor_in1 and and tensor_in2 and
|
||||
// assigns it to tensor_out
|
||||
template <class EngineIn1, class LayoutIn1,
|
||||
class EngineIn2, class LayoutIn2,
|
||||
class EngineOut, class LayoutOut, class BinaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn1,LayoutIn1>& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2>& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut>& tensor_out,
|
||||
BinaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor_in1); ++i) {
|
||||
tensor_out(i) = static_cast<BinaryOp&&>(op)(tensor_in1(i), tensor_in2(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class EngineIn1, class LayoutIn1,
|
||||
class EngineIn2, class LayoutIn2,
|
||||
class EngineOut, class LayoutOut, class BinaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn1,LayoutIn1>&& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2>&& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut>&& tensor_out,
|
||||
BinaryOp&& op)
|
||||
{
|
||||
return transform(tensor_in1, tensor_in2, tensor_out, op);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -38,11 +38,38 @@
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
/** Common algorithms on (hierarchical) tuples */
|
||||
/** Style choice:
|
||||
* Forward params [using static_cast<T&&>(.)] for const/non-const/ref/non-ref args
|
||||
* but don't bother forwarding functions as ref-qualified member fns are extremely rare
|
||||
*/
|
||||
/// @file tuple_algorithms.hpp
|
||||
/// @brief Common algorithms on (hierarchical) tuples
|
||||
///
|
||||
/// Code guidelines and style preferences:
|
||||
///
|
||||
/// For perfect forwarding, don't use std::forward, because it may not
|
||||
/// be defined in device code when compiling with NVRTC. Instead, use
|
||||
/// `static_cast<ParameterType&&>(parameter_name)`.
|
||||
///
|
||||
/// CuTe generally does not bother forwarding functions, as
|
||||
/// reference-qualified member functions are rare in this code base.
|
||||
///
|
||||
/// Throughout CUTLASS, cute::make_tuple always needs to be called
|
||||
/// namespace-qualified, EVEN IF inside the cute namespace and/or in
|
||||
/// scope of a "using namespace cute" declaration. Otherwise, the
|
||||
/// compiler may select std::make_tuple instead of cute::make_tuple,
|
||||
/// due to argument-dependent lookup. Two problems may result from
|
||||
/// that.
|
||||
///
|
||||
/// 1. Functions have an unexpected return type (std::tuple instead of
|
||||
/// cute::tuple), so functions that take cute::tuple parameters
|
||||
/// fail to compile (generally inside functions that have template
|
||||
/// parameters expected to be cute::tuple).
|
||||
///
|
||||
/// 2. std::tuple does not have the required __host__ __device__
|
||||
/// markings, so the CUDA compiler complains if you use it in
|
||||
/// device code.
|
||||
///
|
||||
/// cute::make_tuple will occur more often than std::make_tuple would
|
||||
/// in modern C++ code, because cute::tuple's design deprioritizes
|
||||
/// correct operation of CTAD (constructor template argument
|
||||
/// deduction) in favor of implementation simplicity.
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -142,7 +169,13 @@ CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(T&& t, F&& f)
|
||||
{
|
||||
detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<T&&>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
@ -159,6 +192,36 @@ for_each_leaf(T&& t, F&& f)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// For Sequence
|
||||
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int... I, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(seq<I...> const&, F&& f) {
|
||||
(f(Int<I>{}), ...);
|
||||
}
|
||||
|
||||
}; // end namespace detail
|
||||
|
||||
template <int... I, class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(seq<I...> const& s, T&& t, F&& f) {
|
||||
detail::for_sequence(s, [&](auto&& i){ f(get<remove_cvref_t<decltype(i)>::value>(static_cast<T&&>(t))); });
|
||||
}
|
||||
|
||||
template <int I, class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_sequence(T&& t, F&& f) {
|
||||
for_sequence(make_seq<I>{}, static_cast<T&&>(t), static_cast<F&&>(f));
|
||||
}
|
||||
|
||||
//
|
||||
// Transform
|
||||
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
|
||||
@ -169,7 +232,13 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform(T const& t, F&& f)
|
||||
{
|
||||
return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F>
|
||||
@ -177,8 +246,14 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform(T0 const& t0, T1 const& t1, F&& f)
|
||||
{
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
|
||||
return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
|
||||
if constexpr (is_tuple<T0>::value) {
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
|
||||
return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
|
||||
} else {
|
||||
return f(t0, t1);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F>
|
||||
@ -186,9 +261,15 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
|
||||
{
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
|
||||
return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
|
||||
if constexpr (is_tuple<T0>::value) {
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
|
||||
return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
|
||||
} else {
|
||||
return f(t0, t1, t2);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
@ -399,7 +480,7 @@ fold_first(T&& t, F&& f)
|
||||
}
|
||||
|
||||
//
|
||||
// front, back, take, unwrap
|
||||
// front, back, take, select, unwrap
|
||||
//
|
||||
|
||||
// Get the first non-tuple element in a hierarchical tuple
|
||||
@ -425,7 +506,16 @@ back(T&& t)
|
||||
{
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
constexpr int N = tuple_size<remove_cvref_t<T>>::value;
|
||||
return back(get<N-1>(static_cast<T&&>(t)));
|
||||
|
||||
// MSVC needs a bit of extra help here deducing return types.
|
||||
// We help it by peeling off the nonrecursive case a level "early."
|
||||
if constexpr (! is_tuple<remove_cvref_t<decltype(get<N - 1>(static_cast<T&&>(t)))>>::value) {
|
||||
return get<N - 1>(static_cast<T&&>(t));
|
||||
}
|
||||
else {
|
||||
return back(get<N - 1>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
@ -442,6 +532,47 @@ take(T const& t)
|
||||
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
|
||||
}
|
||||
|
||||
//
|
||||
// Select tuple elements with given indices.
|
||||
//
|
||||
|
||||
template <int... I, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const & t)
|
||||
{
|
||||
return cute::make_tuple(get<I>(t)...);
|
||||
}
|
||||
|
||||
template <class T, typename Indices>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
select(T const & t, Indices const & indices)
|
||||
{
|
||||
if constexpr (is_tuple<Indices>::value) {
|
||||
return cute::transform(indices, [&t](auto i) { return select(t, i); });
|
||||
}
|
||||
else {
|
||||
static_assert(is_static<Indices>::value, "Order must be static");
|
||||
return get<Indices::value>(t);
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap non-tuples into rank-1 tuples or forward
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
wrap(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return t;
|
||||
} else {
|
||||
return cute::make_tuple(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -576,7 +707,11 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
repeat(X const& x)
|
||||
{
|
||||
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
|
||||
if constexpr (N == 1) {
|
||||
return x;
|
||||
} else {
|
||||
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -605,7 +740,23 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
group(T const& t)
|
||||
{
|
||||
return detail::construct(t, take<B,E>(t), make_seq<B>{}, seq<0>{}, make_range<E,tuple_size<T>::value>{});
|
||||
if constexpr (not is_tuple<T>::value) {
|
||||
if constexpr (E == -1) {
|
||||
return group<B,1>(t);
|
||||
} else {
|
||||
return detail::construct(t, take<B,E>(t), make_seq<B>{}, make_seq<(B < E)>{}, make_range<E,1>{});
|
||||
}
|
||||
} else
|
||||
if constexpr (E == -1) {
|
||||
return group<B,tuple_size<T>::value>(t);
|
||||
} else
|
||||
if constexpr (B <= E) {
|
||||
return detail::construct(t, take<B,E>(t), make_seq<B>{}, make_seq<(B < E)>{}, make_range<E,tuple_size<T>::value>{});
|
||||
} else {
|
||||
static_assert(B <= E);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
@ -685,6 +836,48 @@ prepend(T const& a, X const& x)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Unflatten a flat tuple into a hierarchical one
|
||||
// unflatten(flatten(x), x) == x
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
{
|
||||
if constexpr (is_tuple<TargetProfile>::value) {
|
||||
return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
|
||||
auto [result, remaining_tuple] = v;
|
||||
auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
|
||||
return cute::make_tuple(append(result, sub_result), sub_tuple);
|
||||
});
|
||||
} else {
|
||||
return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// @pre flatten(@a flat_tuple) == @a flat_tuple
|
||||
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
|
||||
// @post congruent(@a result, @a target_profile)
|
||||
// @post flatten(@a result) == @a flat_tuple
|
||||
template<class FlatTuple, class TargetProfile>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
|
||||
{
|
||||
auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
|
||||
CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
|
||||
return unflatten_tuple;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Inclusive scan (prefix sum)
|
||||
//
|
||||
@ -872,4 +1065,18 @@ zip2_by(T const& t, TG const& guide)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
/// @return A tuple of the elements of @c t in reverse order.
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
reverse(T const& t) {
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(t, [] (auto const&... a) {
|
||||
return cute::make_tuple(a...);
|
||||
}, tuple_rseq<T>{});
|
||||
}
|
||||
else {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
Reference in New Issue
Block a user