CUTLASS 3.2 (#1024)

* CUTLASS 3.2
This commit is contained in:
ANIKET SHIVAM
2023-08-07 14:50:32 -10:00
committed by GitHub
parent a0d787b746
commit 4575443d44
392 changed files with 47559 additions and 7940 deletions

View File

@ -77,6 +77,20 @@ as_arithmetic_tuple(tuple<T...> const& t) {
return ArithmeticTuple<T...>(t);
}
template <class T, __CUTE_REQUIRES(is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr
T const&
as_arithmetic_tuple(T const& t) {
return t;
}
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
return t;
}
//
// Numeric operators
//
@ -110,18 +124,26 @@ operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
// Special cases
//
template <class T, class... U>
template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(constant<T,0>, ArithmeticTuple<U...> const& u) {
return u;
operator+(C<t>, ArithmeticTuple<U...> const& u) {
if constexpr (t == 0) {
return u;
} else {
static_assert(t == 0, "Artihmetic tuple op+ error!");
}
}
template <class... T, class U>
template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, constant<U,0>) {
return t;
operator+(ArithmeticTuple<T...> const& t, C<u>) {
if constexpr (u == 0) {
return t;
} else {
static_assert(u == 0, "Artihmetic tuple op+ error!");
}
}
//
@ -159,11 +181,9 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter) {
//
// ArithmeticTuple "basis" elements
//
// Abstract value:
// A ScaledBasis<T,N> is a (at least) rank-N0 ArithmeticTuple:
// A ScaledBasis<T,N> is a (at least) rank-N+1 ArithmeticTuple:
// (_0,_0,...,T,_0,...)
// with value T in the Nth mode
template <class T, int N>
struct ScaledBasis : private tuple<T>
@ -188,16 +208,30 @@ struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
template <class T, int N>
struct is_integral<ScaledBasis<T,N>> : true_type {};
template <class T>
// Get the scalar T out of a ScaledBasis
template <class SB>
CUTE_HOST_DEVICE constexpr auto
basis_value(T const& e) {
return e;
basis_value(SB const& e)
{
if constexpr (is_scaled_basis<SB>::value) {
return basis_value(e.value());
} else {
return e;
}
CUTE_GCC_UNREACHABLE;
}
template <class T, int N>
// Apply the N... pack to another Tuple
template <class SB, class Tuple>
CUTE_HOST_DEVICE constexpr auto
basis_value(ScaledBasis<T,N> const& e) {
return basis_value(e.value());
basis_get(SB const& e, Tuple const& t)
{
if constexpr (is_scaled_basis<SB>::value) {
return basis_get(e.value(), get<SB::mode()>(t));
} else {
return t;
}
CUTE_GCC_UNREACHABLE;
}
namespace detail {
@ -217,6 +251,14 @@ struct Basis<N,Ns...> {
} // end namespace detail
// Shortcut for writing ScaledBasis<ScaledBasis<ScaledBasis<Int<1>, N0>, N1>, ...>
// E<> := _1
// E<0> := (_1,_0,_0,...)
// E<1> := (_0,_1,_0,...)
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
template <int... N>
using E = typename detail::Basis<N...>::type;
@ -248,6 +290,15 @@ as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
return detail::as_arithmetic_tuple(t.value(), make_seq<N>{}, make_seq<M-N-1>{});
}
// Turn a ScaledBases<T,N> into a rank-N ArithmeticTuple
// with N prefix 0s: (_0,_0,...N...,_0,T)
template <class T, int N>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
return as_arithmetic_tuple<N+1>(t);
}
// Turn an ArithmeticTuple into a rank-M ArithmeticTuple
// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0)
template <int M, class... T>
@ -258,7 +309,24 @@ as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
return detail::as_arithmetic_tuple(t, make_seq<int(sizeof...(T))>{}, make_seq<M-int(sizeof...(T))>{});
}
// Return...
template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
safe_div(ScaledBasis<T,M> const& b, U const& u)
{
auto t = safe_div(b.value(), u);
return ScaledBasis<decltype(t),M>{t};
}
template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
shape_div(ScaledBasis<T,M> const& b, U const& u)
{
auto t = shape_div(b.value(), u);
return ScaledBasis<decltype(t),M>{t};
}
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
@ -266,12 +334,21 @@ make_basis_like(Shape const& shape)
{
if constexpr (is_integral<Shape>::value) {
return Int<1>{};
} else {
}
else {
// Generate bases for each rank of shape
return transform(tuple_seq<Shape>{}, [&](auto I) {
// Generate bases for each rank of shape_i and add an i on front
constexpr int i = decltype(I)::value; // NOTE: nvcc workaround
return transform_leaf(make_basis_like(get<i>(shape)), [&](auto e) { return ScaledBasis<decltype(e),i>{}; });
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
// Generate bases for each rank of si and add an i on front
using I_type = decltype(I);
return transform_leaf(make_basis_like(si), [](auto e) {
// MSVC has trouble capturing variables as constexpr,
// so that they can be used as template arguments.
// This is exactly what the code needs to do with i, unfortunately.
// The work-around is to define i inside the inner lambda,
// by using just the type from the enclosing scope.
constexpr int i = I_type::value;
return ScaledBasis<decltype(e), i>{};
});
});
}
@ -279,20 +356,6 @@ make_basis_like(Shape const& shape)
}
// Equality
template <class T, int N, int M>
CUTE_HOST_DEVICE constexpr
auto
operator==(ScaledBasis<T,N>, Int<M>) {
return false_type{};
}
template <int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator==(Int<N>, ScaledBasis<U,M>) {
return false_type{};
}
template <class T, int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
@ -300,13 +363,37 @@ operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
return bool_constant<M == N>{} && t.value() == u.value();
}
// Not equal to anything else
template <class T, int N, class U>
CUTE_HOST_DEVICE constexpr
false_type
operator==(ScaledBasis<T,N> const&, U const&) {
return {};
}
template <class T, class U, int M>
CUTE_HOST_DEVICE constexpr
false_type
operator==(T const&, ScaledBasis<U,M> const&) {
return {};
}
// Abs
template <int N, class T>
CUTE_HOST_DEVICE constexpr
auto
abs(ScaledBasis<T,N> const& e) {
return ScaledBasis<decltype(abs(e.value())),N>{abs(e.value())};
}
// Multiplication
template <class A, int N, class T,
__CUTE_REQUIRES(cute::is_integral<A>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator*(A const& a, ScaledBasis<T,N> const& e) {
return ScaledBasis<decltype(a*e.value()),N>{a*e.value()};
auto r = a * e.value();
return ScaledBasis<decltype(r),N>{r};
}
template <int N, class T, class B,
@ -314,7 +401,8 @@ template <int N, class T, class B,
CUTE_HOST_DEVICE constexpr
auto
operator*(ScaledBasis<T,N> const& e, B const& b) {
return ScaledBasis<decltype(e.value()*b),N>{e.value()*b};
auto r = e.value() * b;
return ScaledBasis<decltype(r),N>{r};
}
// Addition
@ -334,6 +422,22 @@ operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
}
template <int N, class T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, tuple<U...> const& u) {
constexpr int R = cute::max(N+1, int(sizeof...(U)));
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple(u);
}
template <class... T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(tuple<T...> const& t, ScaledBasis<U,M> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), M+1);
return as_arithmetic_tuple(t) + as_arithmetic_tuple<R>(u);
}
template <int N, class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
@ -342,18 +446,26 @@ operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
}
template <class T, class U, int M>
template <auto t, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(constant<T,0>, ScaledBasis<U,M> const& u) {
return u;
operator+(C<t>, ScaledBasis<U,M> const& u) {
if constexpr (t == 0) {
return u;
} else {
static_assert(t == 0, "ScaledBasis op+ error!");
}
}
template <class T, int N, class U>
template <class T, int N, auto u>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, constant<U,0>) {
return t;
operator+(ScaledBasis<T,N> const& t, C<u>) {
if constexpr (u == 0) {
return t;
} else {
static_assert(u == 0, "ScaledBasis op+ error!");
}
}
//
@ -380,7 +492,7 @@ namespace CUTE_STL_NAMESPACE
template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
@ -390,7 +502,7 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
template <class... T>
struct tuple_size<const cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
@ -414,7 +526,7 @@ struct tuple_element;
template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
@ -424,7 +536,7 @@ struct tuple_element<I, cute::ArithmeticTuple<T...>>
template <class... T>
struct tuple_size<const cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>