Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@ -73,25 +73,17 @@ make_arithmetic_tuple(T const&... t) {
|
||||
return ArithmeticTuple<T...>(t...);
|
||||
}
|
||||
|
||||
template <class... T>
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(tuple<T...> const& t) {
|
||||
return ArithmeticTuple<T...>(t);
|
||||
}
|
||||
|
||||
template <class T, __CUTE_REQUIRES(is_integral<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const&
|
||||
as_arithmetic_tuple(T const& t) {
|
||||
return t;
|
||||
}
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
|
||||
return t;
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); },
|
||||
[](auto const&... a){ return make_arithmetic_tuple(a...); },
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -289,6 +281,26 @@ basis_get(SB const& e, Tuple const& t)
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, int... I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_atuple_i(T const& t, seq<I...>) {
|
||||
return make_arithmetic_tuple((void(I),Int<0>{})..., t);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Turn a ScaledBases<T,N> into a rank-N+1 ArithmeticTuple
|
||||
// with N prefix 0s: (_0,_0,...N...,_0,T)
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq<N>{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int... Ns>
|
||||
struct Basis;
|
||||
|
||||
@ -315,71 +327,6 @@ struct Basis<N,Ns...> {
|
||||
template <int... N>
|
||||
using E = typename detail::Basis<N...>::type;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, int... I, int... J>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(T const& t, seq<I...>, seq<J...>) {
|
||||
return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...);
|
||||
}
|
||||
|
||||
template <class... T, int... I, int... J>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ArithmeticTuple<T...> const& t, seq<I...>, seq<J...>) {
|
||||
return make_arithmetic_tuple(get<I>(t)..., (void(J),Int<0>{})...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Turn a ScaledBases<T,N> into a rank-M ArithmeticTuple
|
||||
// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0)
|
||||
template <int M, class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
static_assert(M > N, "Mismatched ranks");
|
||||
return detail::as_arithmetic_tuple(t.value(), make_seq<N>{}, make_seq<M-N-1>{});
|
||||
}
|
||||
|
||||
// Turn a ScaledBases<T,N> into a rank-N ArithmeticTuple
|
||||
// with N prefix 0s: (_0,_0,...N...,_0,T)
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
return as_arithmetic_tuple<N+1>(t);
|
||||
}
|
||||
|
||||
// Turn an ArithmeticTuple into a rank-M ArithmeticTuple
|
||||
// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0)
|
||||
template <int M, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
|
||||
static_assert(M >= sizeof...(T), "Mismatched ranks");
|
||||
return detail::as_arithmetic_tuple(t, make_seq<int(sizeof...(T))>{}, make_seq<M-int(sizeof...(T))>{});
|
||||
}
|
||||
|
||||
template <class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
{
|
||||
auto t = safe_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
}
|
||||
|
||||
template <class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
shape_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
{
|
||||
auto t = shape_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
@ -387,8 +334,7 @@ make_basis_like(Shape const& shape)
|
||||
{
|
||||
if constexpr (is_integral<Shape>::value) {
|
||||
return Int<1>{};
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
// Generate bases for each rank of shape
|
||||
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
|
||||
// Generate bases for each rank of si and add an i on front
|
||||
@ -408,6 +354,28 @@ make_basis_like(Shape const& shape)
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Arithmetic
|
||||
//
|
||||
|
||||
template <class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
{
|
||||
auto t = safe_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
}
|
||||
|
||||
template <class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
shape_div(ScaledBasis<T,M> const& b, U const& u)
|
||||
{
|
||||
auto t = shape_div(b.value(), u);
|
||||
return ScaledBasis<decltype(t),M>{t};
|
||||
}
|
||||
|
||||
// Equality
|
||||
template <class T, int N, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -432,7 +400,7 @@ operator==(T const&, ScaledBasis<U,M> const&) {
|
||||
}
|
||||
|
||||
// Abs
|
||||
template <int N, class T>
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
abs(ScaledBasis<T,N> const& e) {
|
||||
@ -440,7 +408,7 @@ abs(ScaledBasis<T,N> const& e) {
|
||||
}
|
||||
|
||||
// Multiplication
|
||||
template <class A, int N, class T>
|
||||
template <class A, class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(A const& a, ScaledBasis<T,N> const& e) {
|
||||
@ -448,7 +416,7 @@ operator*(A const& a, ScaledBasis<T,N> const& e) {
|
||||
return ScaledBasis<decltype(r),N>{r};
|
||||
}
|
||||
|
||||
template <int N, class T, class B>
|
||||
template <class T, int N, class B>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(ScaledBasis<T,N> const& e, B const& b) {
|
||||
@ -457,44 +425,25 @@ operator*(ScaledBasis<T,N> const& e, B const& b) {
|
||||
}
|
||||
|
||||
// Addition
|
||||
template <int N, class T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
|
||||
constexpr int R = cute::max(N+1, int(sizeof...(U)));
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <class... T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), M+1);
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <int N, class T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, tuple<U...> const& u) {
|
||||
constexpr int R = cute::max(N+1, int(sizeof...(U)));
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple(u);
|
||||
}
|
||||
|
||||
template <class... T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(tuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), M+1);
|
||||
return as_arithmetic_tuple(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <int N, class T, int M, class U>
|
||||
template <class T, int N, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
constexpr int R = cute::max(N+1,M+1);
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
return as_arithmetic_tuple(t) + as_arithmetic_tuple(u);
|
||||
}
|
||||
|
||||
template <class T, int N, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
|
||||
return as_arithmetic_tuple(t) + u;
|
||||
}
|
||||
|
||||
template <class... T, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
return t + as_arithmetic_tuple(u);
|
||||
}
|
||||
|
||||
template <auto t, class U, int M>
|
||||
|
||||
@ -56,10 +56,10 @@ fma(complex<T> & d,
|
||||
complex<T> const& b,
|
||||
complex<T> const& c)
|
||||
{
|
||||
d.real(fma( a.real(), b.real(), c.real()));
|
||||
d.imag(fma( a.real(), b.imag(), c.imag()));
|
||||
d.real(fma(-a.imag(), b.imag(), d.real()));
|
||||
d.imag(fma( a.imag(), b.real(), d.imag()));
|
||||
fma(d.real(), a.real(), b.real(), c.real());
|
||||
fma(d.imag(), a.real(), b.imag(), c.imag());
|
||||
fma(d.real(), -a.imag(), b.imag(), d.real());
|
||||
fma(d.imag(), a.imag(), b.real(), d.imag());
|
||||
}
|
||||
|
||||
/// Fused multiply-add for triplets
|
||||
|
||||
Reference in New Issue
Block a user