Updates for CUTLASS 3.5.0 (#1468)

This commit is contained in:
Vijay Thakkar
2024-04-11 21:33:40 -04:00
committed by GitHub
parent a40e08e9d5
commit 7d49e6c7e2
171 changed files with 7526 additions and 1888 deletions

View File

@ -73,25 +73,17 @@ make_arithmetic_tuple(T const&... t) {
return ArithmeticTuple<T...>(t...);
}
template <class... T>
template <class T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(tuple<T...> const& t) {
return ArithmeticTuple<T...>(t);
}
template <class T, __CUTE_REQUIRES(is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr
T const&
as_arithmetic_tuple(T const& t) {
return t;
}
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
return t;
if constexpr (is_tuple<T>::value) {
return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); },
[](auto const&... a){ return make_arithmetic_tuple(a...); },
tuple_seq<T>{});
} else {
return t;
}
}
//
@ -289,6 +281,26 @@ basis_get(SB const& e, Tuple const& t)
namespace detail {

// Build the ArithmeticTuple (_0,...,_0,t): one static zero for every index in
// the sequence argument, followed by t as the final element.
template <class T, int... I>
CUTE_HOST_DEVICE constexpr
auto
to_atuple_i(T const& t, seq<I...>) {
  // Each I is expanded only to repeat the zero; its value is discarded.
  return make_arithmetic_tuple((static_cast<void>(I), Int<0>{})...,
                               t);
}

} // end namespace detail
// Convert a ScaledBasis<T,N> into a rank-(N+1) ArithmeticTuple:
// N leading static zeros followed by the converted value, (_0,_0,...N...,_0,T).
template <class T, int N>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
  // Convert the wrapped value first, then prepend the N zeros.
  auto const& converted = as_arithmetic_tuple(t.value());
  return detail::to_atuple_i(converted, make_seq<N>{});
}
namespace detail {
template <int... Ns>
struct Basis;
@ -315,71 +327,6 @@ struct Basis<N,Ns...> {
// E<N...> is shorthand for the nested `type` of detail::Basis<N...>.
template <int... N>
using E = typename detail::Basis<N...>::type;
namespace detail {

// Place a single value between runs of static zeros: (_0,...,_0,t,_0,...,_0).
// One leading zero is emitted per index in the first sequence and one trailing
// zero per index in the second sequence.
template <class T, int... I, int... J>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(T const& t, seq<I...>, seq<J...>) {
  return make_arithmetic_tuple((static_cast<void>(I), Int<0>{})...,
                               t,
                               (static_cast<void>(J), Int<0>{})...);
}

// Rebuild an ArithmeticTuple from the elements get<I>(t)... and append one
// static zero per index in the second sequence.
template <class... T, int... I, int... J>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ArithmeticTuple<T...> const& t, seq<I...>, seq<J...>) {
  return make_arithmetic_tuple(get<I>(t)...,
                               (static_cast<void>(J), Int<0>{})...);
}

} // end namespace detail
// Convert a ScaledBasis<T,N> into a rank-M ArithmeticTuple holding the value at
// position N and static zeros elsewhere: (_0,_0,...N...,_0,T,_0,...,_0,_0).
// Requires M > N.
template <int M, class T, int N>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
  static_assert(M > N, "Mismatched ranks");
  using LeadingZeros  = make_seq<N>;      // N zeros before the value
  using TrailingZeros = make_seq<M-N-1>;  // remaining slots after the value
  return detail::as_arithmetic_tuple(t.value(), LeadingZeros{}, TrailingZeros{});
}
// Convert a ScaledBasis<T,N> into the smallest ArithmeticTuple that can hold it:
// rank N+1, with N leading static zeros, (_0,_0,...N...,_0,T).
template <class T, int N>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
  constexpr int Rank = N + 1;
  return as_arithmetic_tuple<Rank>(t);
}
// Extend an ArithmeticTuple to rank M by appending static zeros:
//   (t0,t1,t2,...,_0,...,_0,_0)
// M must be at least the number of elements already in t.
template <int M, class... T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
  // Compare as int: sizeof...(T) is a size_t, so a negative M would otherwise be
  // converted to a huge unsigned value and slip past the assertion.
  constexpr int Rank = int(sizeof...(T));
  static_assert(M >= Rank, "Mismatched ranks");
  return detail::as_arithmetic_tuple(t, make_seq<Rank>{}, make_seq<M-Rank>{});
}
// safe_div for a ScaledBasis: apply safe_div to the wrapped value and re-wrap
// the result in a ScaledBasis carrying the same index M.
template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
safe_div(ScaledBasis<T,M> const& b, U const& u)
{
  auto quotient = safe_div(b.value(), u);
  using Quotient = decltype(quotient);
  return ScaledBasis<Quotient,M>{quotient};
}
// shape_div for a ScaledBasis: apply shape_div to the wrapped value and re-wrap
// the result in a ScaledBasis carrying the same index M.
template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
shape_div(ScaledBasis<T,M> const& b, U const& u)
{
  auto quotient = shape_div(b.value(), u);
  using Quotient = decltype(quotient);
  return ScaledBasis<Quotient,M>{quotient};
}
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
@ -387,8 +334,7 @@ make_basis_like(Shape const& shape)
{
if constexpr (is_integral<Shape>::value) {
return Int<1>{};
}
else {
} else {
// Generate bases for each rank of shape
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
// Generate bases for each rank of si and add an i on front
@ -408,6 +354,28 @@ make_basis_like(Shape const& shape)
CUTE_GCC_UNREACHABLE;
}
//
// Arithmetic
//
// safe_div for a ScaledBasis: apply safe_div to the wrapped value and re-wrap
// the result in a ScaledBasis carrying the same index M.
template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
safe_div(ScaledBasis<T,M> const& b, U const& u)
{
  auto quotient = safe_div(b.value(), u);
  using Quotient = decltype(quotient);
  return ScaledBasis<Quotient,M>{quotient};
}
// shape_div for a ScaledBasis: apply shape_div to the wrapped value and re-wrap
// the result in a ScaledBasis carrying the same index M.
template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
shape_div(ScaledBasis<T,M> const& b, U const& u)
{
  auto quotient = shape_div(b.value(), u);
  using Quotient = decltype(quotient);
  return ScaledBasis<Quotient,M>{quotient};
}
// Equality
template <class T, int N, class U, int M>
CUTE_HOST_DEVICE constexpr
@ -432,7 +400,7 @@ operator==(T const&, ScaledBasis<U,M> const&) {
}
// Abs
template <int N, class T>
template <class T, int N>
CUTE_HOST_DEVICE constexpr
auto
abs(ScaledBasis<T,N> const& e) {
@ -440,7 +408,7 @@ abs(ScaledBasis<T,N> const& e) {
}
// Multiplication
template <class A, int N, class T>
template <class A, class T, int N>
CUTE_HOST_DEVICE constexpr
auto
operator*(A const& a, ScaledBasis<T,N> const& e) {
@ -448,7 +416,7 @@ operator*(A const& a, ScaledBasis<T,N> const& e) {
return ScaledBasis<decltype(r),N>{r};
}
template <int N, class T, class B>
template <class T, int N, class B>
CUTE_HOST_DEVICE constexpr
auto
operator*(ScaledBasis<T,N> const& e, B const& b) {
@ -457,44 +425,25 @@ operator*(ScaledBasis<T,N> const& e, B const& b) {
}
// Addition
// ScaledBasis + ArithmeticTuple: expand both operands to a common rank, then add.
template <int N, class T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
  // The basis needs N+1 slots; the tuple needs one slot per element.
  constexpr int CommonRank = cute::max(N+1, int(sizeof...(U)));
  auto const& lhs = as_arithmetic_tuple<CommonRank>(t);
  auto const& rhs = as_arithmetic_tuple<CommonRank>(u);
  return lhs + rhs;
}
// ArithmeticTuple + ScaledBasis: expand both operands to a common rank, then add.
template <class... T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
  // The tuple needs one slot per element; the basis needs M+1 slots.
  constexpr int CommonRank = cute::max(int(sizeof...(T)), M+1);
  auto const& lhs = as_arithmetic_tuple<CommonRank>(t);
  auto const& rhs = as_arithmetic_tuple<CommonRank>(u);
  return lhs + rhs;
}
// ScaledBasis + tuple: the basis is expanded to the larger of the two ranks,
// the tuple is converted as-is, and the two results are added.
template <int N, class T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, tuple<U...> const& u) {
  constexpr int CommonRank = cute::max(N+1, int(sizeof...(U)));
  auto const& lhs = as_arithmetic_tuple<CommonRank>(t);
  auto const& rhs = as_arithmetic_tuple(u);  // no rank argument for the plain tuple
  return lhs + rhs;
}
// tuple + ScaledBasis: the tuple is converted as-is, the basis is expanded to
// the larger of the two ranks, and the two results are added.
template <class... T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(tuple<T...> const& t, ScaledBasis<U,M> const& u) {
  constexpr int CommonRank = cute::max(int(sizeof...(T)), M+1);
  auto const& lhs = as_arithmetic_tuple(t);  // no rank argument for the plain tuple
  auto const& rhs = as_arithmetic_tuple<CommonRank>(u);
  return lhs + rhs;
}
template <int N, class T, int M, class U>
template <class T, int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
constexpr int R = cute::max(N+1,M+1);
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
return as_arithmetic_tuple(t) + as_arithmetic_tuple(u);
}
// ScaledBasis + ArithmeticTuple: convert the basis to an arithmetic tuple and
// add the right-hand tuple unchanged.
template <class T, int N, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
  auto const& lhs = as_arithmetic_tuple(t);
  return lhs + u;
}
// ArithmeticTuple + ScaledBasis: convert the basis to an arithmetic tuple and
// add it to the left-hand tuple unchanged.
template <class... T, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
  auto const& rhs = as_arithmetic_tuple(u);
  return t + rhs;
}
template <auto t, class U, int M>

View File

@ -56,10 +56,10 @@ fma(complex<T> & d,
complex<T> const& b,
complex<T> const& c)
{
d.real(fma( a.real(), b.real(), c.real()));
d.imag(fma( a.real(), b.imag(), c.imag()));
d.real(fma(-a.imag(), b.imag(), d.real()));
d.imag(fma( a.imag(), b.real(), d.imag()));
fma(d.real(), a.real(), b.real(), c.real());
fma(d.imag(), a.real(), b.imag(), c.imag());
fma(d.real(), -a.imag(), b.imag(), d.real());
fma(d.imag(), a.imag(), b.real(), d.imag());
}
/// Fused multiply-add for triplets