@ -77,6 +77,20 @@ as_arithmetic_tuple(tuple<T...> const& t) {
|
||||
return ArithmeticTuple<T...>(t);
|
||||
}
|
||||
|
||||
// Overload constrained by __CUTE_REQUIRES(is_integral<T>::value):
// the argument is handed back unchanged, by const reference.
template <class T, __CUTE_REQUIRES(is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr auto
as_arithmetic_tuple(T const& t) -> T const&
{
  return t;
}
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
|
||||
return t;
|
||||
}
|
||||
|
||||
//
|
||||
// Numeric operators
|
||||
//
|
||||
@ -110,18 +124,26 @@ operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
||||
// Special cases
|
||||
//
|
||||
|
||||
template <class T, class... U>
|
||||
template <auto t, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(constant<T,0>, ArithmeticTuple<U...> const& u) {
|
||||
return u;
|
||||
operator+(C<t>, ArithmeticTuple<U...> const& u) {
|
||||
if constexpr (t == 0) {
|
||||
return u;
|
||||
} else {
|
||||
static_assert(t == 0, "Artihmetic tuple op+ error!");
|
||||
}
|
||||
}
|
||||
|
||||
template <class... T, class U>
|
||||
template <class... T, auto u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, constant<U,0>) {
|
||||
return t;
|
||||
operator+(ArithmeticTuple<T...> const& t, C<u>) {
|
||||
if constexpr (u == 0) {
|
||||
return t;
|
||||
} else {
|
||||
static_assert(u == 0, "Artihmetic tuple op+ error!");
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -159,11 +181,9 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter) {
|
||||
|
||||
//
|
||||
// ArithmeticTuple "basis" elements
|
||||
//
|
||||
|
||||
// Abstract value:
|
||||
// A ScaledBasis<T,N> is a (at least) rank-N0 ArithmeticTuple:
|
||||
// A ScaledBasis<T,N> is a (at least) rank-N+1 ArithmeticTuple:
|
||||
// (_0,_0,...,T,_0,...)
|
||||
// with value T in the Nth mode
|
||||
|
||||
template <class T, int N>
|
||||
struct ScaledBasis : private tuple<T>
|
||||
@ -188,16 +208,30 @@ struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
|
||||
// Specialization marking every ScaledBasis<T,N> as satisfying is_integral
// (it derives from true_type).
template <class T, int N>
|
||||
struct is_integral<ScaledBasis<T,N>> : true_type {};
|
||||
|
||||
template <class T>
|
||||
// Get the scalar T out of a ScaledBasis
|
||||
template <class SB>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
basis_value(T const& e) {
|
||||
return e;
|
||||
basis_value(SB const& e)
|
||||
{
|
||||
if constexpr (is_scaled_basis<SB>::value) {
|
||||
return basis_value(e.value());
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, int N>
|
||||
// Apply the N... pack to another Tuple
|
||||
template <class SB, class Tuple>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
basis_value(ScaledBasis<T,N> const& e) {
|
||||
return basis_value(e.value());
|
||||
basis_get(SB const& e, Tuple const& t)
|
||||
{
|
||||
if constexpr (is_scaled_basis<SB>::value) {
|
||||
return basis_get(e.value(), get<SB::mode()>(t));
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
@ -217,6 +251,14 @@ struct Basis<N,Ns...> {
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Shortcut for writing ScaledBasis<ScaledBasis<ScaledBasis<Int<1>, N0>, N1>, ...>
|
||||
// E<> := _1
|
||||
// E<0> := (_1,_0,_0,...)
|
||||
// E<1> := (_0,_1,_0,...)
|
||||
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
|
||||
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
|
||||
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
|
||||
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
|
||||
// E<N...> is the nested `type` of detail::Basis<N...>.
template <int... N>
|
||||
using E = typename detail::Basis<N...>::type;
|
||||
|
||||
@ -248,6 +290,15 @@ as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
return detail::as_arithmetic_tuple(t.value(), make_seq<N>{}, make_seq<M-N-1>{});
|
||||
}
|
||||
|
||||
// Turn a ScaledBases<T,N> into a rank-N ArithmeticTuple
|
||||
// with N prefix 0s: (_0,_0,...N...,_0,T)
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
return as_arithmetic_tuple<N+1>(t);
|
||||
}
|
||||
|
||||
// Turn an ArithmeticTuple into a rank-M ArithmeticTuple
|
||||
// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0)
|
||||
template <int M, class... T>
|
||||
@ -258,7 +309,24 @@ as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
|
||||
return detail::as_arithmetic_tuple(t, make_seq<int(sizeof...(T))>{}, make_seq<M-int(sizeof...(T))>{});
|
||||
}
|
||||
|
||||
// Return...
|
||||
template <class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
{
|
||||
auto t = safe_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
}
|
||||
|
||||
template <class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
shape_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
{
|
||||
auto t = shape_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -266,12 +334,21 @@ make_basis_like(Shape const& shape)
|
||||
{
|
||||
if constexpr (is_integral<Shape>::value) {
|
||||
return Int<1>{};
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
// Generate bases for each rank of shape
|
||||
return transform(tuple_seq<Shape>{}, [&](auto I) {
|
||||
// Generate bases for each rank of shape_i and add an i on front
|
||||
constexpr int i = decltype(I)::value; // NOTE: nvcc workaround
|
||||
return transform_leaf(make_basis_like(get<i>(shape)), [&](auto e) { return ScaledBasis<decltype(e),i>{}; });
|
||||
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
|
||||
// Generate bases for each rank of si and add an i on front
|
||||
using I_type = decltype(I);
|
||||
return transform_leaf(make_basis_like(si), [](auto e) {
|
||||
// MSVC has trouble capturing variables as constexpr,
|
||||
// so that they can be used as template arguments.
|
||||
// This is exactly what the code needs to do with i, unfortunately.
|
||||
// The work-around is to define i inside the inner lambda,
|
||||
// by using just the type from the enclosing scope.
|
||||
constexpr int i = I_type::value;
|
||||
return ScaledBasis<decltype(e), i>{};
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -279,20 +356,6 @@ make_basis_like(Shape const& shape)
|
||||
}
|
||||
|
||||
// Equality
|
||||
template <class T, int N, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(ScaledBasis<T,N>, Int<M>) {
|
||||
return false_type{};
|
||||
}
|
||||
|
||||
template <int N, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(Int<N>, ScaledBasis<U,M>) {
|
||||
return false_type{};
|
||||
}
|
||||
|
||||
template <class T, int N, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -300,13 +363,37 @@ operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
return bool_constant<M == N>{} && t.value() == u.value();
|
||||
}
|
||||
|
||||
// Not equal to anything else
|
||||
template <class T, int N, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
false_type
|
||||
operator==(ScaledBasis<T,N> const&, U const&) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
false_type
|
||||
operator==(T const&, ScaledBasis<U,M> const&) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Abs
|
||||
template <int N, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
abs(ScaledBasis<T,N> const& e) {
|
||||
return ScaledBasis<decltype(abs(e.value())),N>{abs(e.value())};
|
||||
}
|
||||
|
||||
// Multiplication
|
||||
template <class A, int N, class T,
|
||||
__CUTE_REQUIRES(cute::is_integral<A>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(A const& a, ScaledBasis<T,N> const& e) {
|
||||
return ScaledBasis<decltype(a*e.value()),N>{a*e.value()};
|
||||
auto r = a * e.value();
|
||||
return ScaledBasis<decltype(r),N>{r};
|
||||
}
|
||||
|
||||
template <int N, class T, class B,
|
||||
@ -314,7 +401,8 @@ template <int N, class T, class B,
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(ScaledBasis<T,N> const& e, B const& b) {
|
||||
return ScaledBasis<decltype(e.value()*b),N>{e.value()*b};
|
||||
auto r = e.value() * b;
|
||||
return ScaledBasis<decltype(r),N>{r};
|
||||
}
|
||||
|
||||
// Addition
|
||||
@ -334,6 +422,22 @@ operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <int N, class T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, tuple<U...> const& u) {
|
||||
constexpr int R = cute::max(N+1, int(sizeof...(U)));
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple(u);
|
||||
}
|
||||
|
||||
template <class... T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(tuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), M+1);
|
||||
return as_arithmetic_tuple(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <int N, class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -342,18 +446,26 @@ operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <class T, class U, int M>
|
||||
template <auto t, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(constant<T,0>, ScaledBasis<U,M> const& u) {
|
||||
return u;
|
||||
operator+(C<t>, ScaledBasis<U,M> const& u) {
|
||||
if constexpr (t == 0) {
|
||||
return u;
|
||||
} else {
|
||||
static_assert(t == 0, "ScaledBasis op+ error!");
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, int N, class U>
|
||||
template <class T, int N, auto u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, constant<U,0>) {
|
||||
return t;
|
||||
operator+(ScaledBasis<T,N> const& t, C<u>) {
|
||||
if constexpr (u == 0) {
|
||||
return t;
|
||||
} else {
|
||||
static_assert(u == 0, "ScaledBasis op+ error!");
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -380,7 +492,7 @@ namespace CUTE_STL_NAMESPACE
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
@ -390,7 +502,7 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
@ -414,7 +526,7 @@ struct tuple_element;
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
@ -424,7 +536,7 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
|
||||
Reference in New Issue
Block a user