CUTLASS 3.1 (#915)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2023-04-14 20:19:34 -07:00
committed by GitHub
parent 9b8166e3f0
commit d572cc1aab
482 changed files with 37184 additions and 16419 deletions

View File

@ -36,6 +36,7 @@
#include <cute/numeric/integral_constant.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/util/type_traits.hpp>
namespace cute
{
@ -361,28 +362,75 @@ operator+(ScaledBasis<T,N> const& t, constant<U,0>) {
// Prints a ScaledBasis using printf for the index N and the print() overload
// for the wrapped value.
// NOTE(review): this text is a rendered diff without +/- markers, so two
// alternative body lines appear back to back: one emitting "N:" before the
// value, one emitting "@N" after the value. Presumably the second is the
// post-change form -- confirm against the repository before relying on either.
template <class T, int N>
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e) {
printf("%d:", N); print(e.value());
print(e.value()); printf("@%d", N);
}
#if !defined(__CUDACC_RTC__)
// Streams a ScaledBasis to `os` and returns the stream.
// NOTE(review): two return statements are visible because the diff rendering
// keeps both variants ("N:value" and "value@N"); taken literally only the
// first could execute. Presumably the second is the post-change line --
// confirm against the repository.
template <class T, int N>
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e) {
return os << N << ":" << e.value();
return os << e.value() << "@" << N;
}
#endif
} // end namespace cute
// tuple_size / tuple_element specializations for cute::ArithmeticTuple, for
// both the plain and the const-qualified tuple type.
// NOTE(review): this is a rendered diff without +/- markers. Where two
// near-identical lines follow each other (a std:: spelling, then a
// CUTE_STL_NAMESPACE / cute:: spelling) they are presumably the pre- and
// post-change versions of a single line -- confirm against the repository.
namespace std
namespace CUTE_STL_NAMESPACE
{
// Element count is the number of types in the pack.
template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
: std::integral_constant<std::size_t, sizeof...(T)>
: cute::integral_constant<size_t, sizeof...(T)>
{};
// The I-th element type is whatever tuple_element reports for a tuple<T...>.
template <std::size_t I, class... T>
template <size_t I, class... T>
struct tuple_element<I, cute::ArithmeticTuple<T...>>
: std::tuple_element<I, std::tuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
// Same element count for a const ArithmeticTuple.
template <class... T>
struct tuple_size<const cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
{};
// Element type of a const ArithmeticTuple is looked up on a const tuple<T...>.
template <size_t I, class... T>
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace CUTE_STL_NAMESPACE
// When CUTE_STL_NAMESPACE_IS_CUDA_STD is defined, the ArithmeticTuple
// tuple_size / tuple_element specializations are additionally provided in
// namespace std.
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
namespace std
{
// Under __CUDACC_RTC__ the primary templates are declared here before being
// specialized. NOTE(review): presumably because the standard <tuple>
// declarations are not available in that environment -- confirm.
#if defined(__CUDACC_RTC__)
template <class... _Tp>
struct tuple_size;
template<size_t _Ip, class... _Tp>
struct tuple_element;
#endif
// Element count is the number of types in the pack.
template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
{};
// The I-th element type is delegated to CUTE_STL_NAMESPACE::tuple<T...>.
template <size_t I, class... T>
struct tuple_element<I, cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
// Same element count for a const ArithmeticTuple.
template <class... T>
struct tuple_size<const cute::ArithmeticTuple<T...>>
: cute::integral_constant<size_t, sizeof...(T)>
{};
// Element type of a const ArithmeticTuple is looked up on a const tuple<T...>.
template <size_t I, class... T>
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD

View File

@ -43,9 +43,11 @@ using cutlass::bfloat16_t;
// Display utilities
//
#if !defined(__CUDACC_RTC__)
// Host-side stream insertion for bfloat16_t: the value is converted to float
// and that float is what gets written to the stream.
CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v)
{
  float const widened = float(v);
  os << widened;
  return os;
}
#endif
} // end namespace cute

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
#pragma once
#include <cstdint>
#include <cute/util/type_traits.hpp>
//#if defined(__CUDA_ARCH__)
//# include <cuda/std/complex>
@ -38,13 +38,37 @@
//# include <complex>
//#endif
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
// on line 656 of thrust/detail/type_traits.h.
// These pragmas suppress the warnings.
// Suppress warnings for code in Thrust headers.
#if defined(_MSC_VER)
// We check for MSVC first, because MSVC also defines __GNUC__.
// It's common for non-GCC compilers that emulate GCC's behavior
// to define __GNUC__.
//
// thrust/complex.h triggers MSVC's warning on conversion
// from double to float (or const float) ("possible loss of data").
// MSVC treats this as an error by default (at least with
// CUTLASS's default CMake configuration).
#pragma warning( push )
#pragma warning( disable : 4244 )
#elif defined(__GNUC__)
// With GCC + CUDA 11.4, builds show spurious "-Wconversion"
// warnings on line 656 of thrust/detail/type_traits.h.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
#endif
#if defined(__CUDACC_RTC__)
#include <cuda/std/complex>
#else
#include <thrust/complex.h>
#endif
#if defined(_MSC_VER)
#pragma warning( pop )
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
#include <cute/config.hpp>
@ -62,7 +86,11 @@ namespace cute
//template <class T>
//using complex = thrust::complex<T>;
#if defined(__CUDACC_RTC__)
using cuda::std::complex;
#else
using thrust::complex;
#endif
template <class T>
CUTE_HOST_DEVICE
@ -147,6 +175,7 @@ struct is_complex<complex<T>> {
//////////////////////////////////////////////////////////////////////////////////////////////////
// Display utilities
#if !defined(__CUDACC_RTC__)
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
{
@ -159,5 +188,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
return os << _r;
}
}
#endif
} // end namespace cute

View File

@ -46,10 +46,12 @@ namespace cute
// Signed integers
//
using int8_t = std::int8_t;
using int16_t = std::int16_t;
using int32_t = std::int32_t;
using int64_t = std::int64_t;
using int2_t = cute::int2b_t;
using int4_t = cute::int4b_t;
using int8_t = CUTE_STL_NAMESPACE::int8_t;
using int16_t = CUTE_STL_NAMESPACE::int16_t;
using int32_t = CUTE_STL_NAMESPACE::int32_t;
using int64_t = CUTE_STL_NAMESPACE::int64_t;
template <int N> struct int_bit;
template <> struct int_bit< 2> { using type = cute::int2b_t; };
@ -72,10 +74,14 @@ using int_byte_t = typename int_byte<N>::type;
// Unsigned integers
//
using uint8_t = std::uint8_t;
using uint16_t = std::uint16_t;
using uint32_t = std::uint32_t;
using uint64_t = std::uint64_t;
using uint1_t = cute::uint1b_t;
using uint2_t = cute::uint2b_t;
using uint4_t = cute::uint4b_t;
using uint8_t = CUTE_STL_NAMESPACE::uint8_t;
using uint16_t = CUTE_STL_NAMESPACE::uint16_t;
using uint32_t = CUTE_STL_NAMESPACE::uint32_t;
using uint64_t = CUTE_STL_NAMESPACE::uint64_t;
using uint128_t = cute::uint128_t;
template <int N> struct uint_bit;
template <> struct uint_bit< 1> { using type = cute::uint1b_t; };
@ -102,7 +108,7 @@ using uint_byte_t = typename uint_byte<N>::type;
// Trait reporting sizeof(T) as a static member `value`.
// NOTE(review): two declarations of `value` are visible because this is a
// rendered diff without +/- markers; presumably the std::size_t line is the
// pre-change form and the size_t line the post-change form -- confirm.
template <class T>
struct sizeof_bytes {
static constexpr std::size_t value = sizeof(T);
static constexpr size_t value = sizeof(T);
};
// Variable-template shorthand for sizeof_bytes<T>::value; note it is declared
// as int, so the size_t value is converted.
template <class T>
static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
@ -113,15 +119,15 @@ static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
// Trait reporting the size of T in bits as a static member `value`.
// Primary template: sizeof(T) * 8.
// NOTE(review): each struct below shows two declarations of `value` because
// this is a rendered diff without +/- markers; presumably the std::size_t line
// is the pre-change form and the size_t line the post-change form -- confirm.
template <class T>
struct sizeof_bits {
static constexpr std::size_t value = sizeof(T) * 8;
static constexpr size_t value = sizeof(T) * 8;
};
// bool is reported as a single bit.
template <>
struct sizeof_bits<bool> {
static constexpr std::size_t value = 1;
static constexpr size_t value = 1;
};
// integer_subbyte reports its Bits template argument.
template <int Bits, bool Signed>
struct sizeof_bits<integer_subbyte<Bits,Signed>> {
static constexpr std::size_t value = Bits;
static constexpr size_t value = Bits;
};
// Variable-template shorthand for sizeof_bits<T>::value; declared as int, so
// the size_t value is converted.
template <class T>
static constexpr int sizeof_bits_v = sizeof_bits<T>::value;

View File

@ -30,34 +30,46 @@
**************************************************************************************************/
#pragma once
#include <utility> // std::integer_sequence
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
using std::integer_sequence;
using std::make_integer_sequence;
using CUTE_STL_NAMESPACE::integer_sequence;
using CUTE_STL_NAMESPACE::make_integer_sequence;
// Implementation helpers for the integer-sequence aliases.
// NOTE(review): this is a rendered diff without +/- markers; the
// make_integer_range_impl / range_impl line pairs are presumably the pre- and
// post-rename spellings of the same declaration -- confirm against the
// repository.
namespace detail {
// Primary template, declared only; the specialization below supplies `type`.
template <class T, class S, T Begin>
struct make_integer_range_impl;
struct range_impl;
// Adds Begin to every value of the given integer_sequence.
template <class T, T... N, T Begin>
struct make_integer_range_impl<T, integer_sequence<T, N...>, Begin> {
struct range_impl<T, integer_sequence<T, N...>, Begin> {
using type = integer_sequence<T, N+Begin...>;
};
// Primary template, declared only; the specialization below supplies `type`.
template <class S>
struct reverse_impl;
// Maps every value N of the sequence to (count - 1 - N), where count is the
// number of values in the sequence.
template <class T, T... N>
struct reverse_impl<integer_sequence<T, N...>> {
using type = integer_sequence<T, sizeof...(N)-1-N...>;
};
} // end namespace detail
// integer_sequence of the values Begin, Begin+1, ... with (End - Begin)
// entries; the length is clamped to 0 when End - Begin is not positive.
// NOTE(review): two `using make_integer_range = ...` opening lines are visible
// because this is a rendered diff without +/- markers; presumably the
// range_impl spelling is the post-change one -- confirm.
template <class T, T Begin, T End>
using make_integer_range = typename detail::make_integer_range_impl<
using make_integer_range = typename detail::range_impl<
T,
make_integer_sequence<T, (End-Begin > 0) ? (End-Begin) : 0>,
Begin>::type;
// make_integer_sequence<T, N> passed through detail::reverse_impl, i.e. the
// values N-1, N-2, ..., 0.
template <class T, T N>
using make_integer_sequence_reverse = typename detail::reverse_impl<
make_integer_sequence<T, N>>::type;
//
// Common aliases
//
@ -70,19 +82,25 @@ using int_sequence = integer_sequence<int, Ints...>;
template <int N>
using make_int_sequence = make_integer_sequence<int, N>;
template <int N>
using make_int_rsequence = make_integer_sequence_reverse<int, N>;
template <int Begin, int End>
using make_int_range = make_integer_range<int, Begin, End>;
// index_sequence
template <std::size_t... Ints>
using index_sequence = integer_sequence<std::size_t, Ints...>;
template <size_t... Ints>
using index_sequence = integer_sequence<size_t, Ints...>;
template <std::size_t N>
using make_index_sequence = make_integer_sequence<std::size_t, N>;
template <size_t N>
using make_index_sequence = make_integer_sequence<size_t, N>;
template <std::size_t Begin, std::size_t End>
using make_index_range = make_integer_range<std::size_t, Begin, End>;
template <size_t N>
using make_index_rsequence = make_integer_sequence_reverse<size_t, N>;
template <size_t Begin, size_t End>
using make_index_range = make_integer_range<size_t, Begin, End>;
//
// Shortcuts
@ -94,46 +112,40 @@ using seq = int_sequence<Ints...>;
template <int N>
using make_seq = make_int_sequence<N>;
template <int N>
using make_rseq = make_int_rsequence<N>;
template <int Min, int Max>
using make_range = make_int_range<Min, Max>;
template <class Tuple>
using tuple_seq = make_seq<std::tuple_size<std::remove_reference_t<Tuple>>::value>;
} // end namespace cute
using tuple_seq = make_seq<tuple_size<remove_cvref_t<Tuple>>::value>;
template <class Tuple>
using tuple_rseq = make_rseq<tuple_size<remove_cvref_t<Tuple>>::value>;
//
// Specialize tuple-related functionality for cute::integer_sequence
// Specialize cute::tuple-traits for std::integer_sequence
//
#include <tuple>
#include <cute/numeric/integral_constant.hpp>
template <class T, T... Ints>
struct tuple_size<integer_sequence<T, Ints...>>
: cute::integral_constant<size_t, sizeof...(Ints)>
{};
namespace cute
// tuple_element for an integer_sequence: `type` is an integral_constant
// holding the I-th value of the sequence.
template <size_t I, class T, T... Is>
struct tuple_element<I, integer_sequence<T, Is...>>
{
// The pack is expanded into a constexpr array so it can be indexed by I.
constexpr static T idx[sizeof...(Is)] = {Is...};
using type = cute::integral_constant<T, idx[I]>;
};
template <std::size_t I, class T, T... Ints>
template <size_t I, class T, T... Ints>
CUTE_HOST_DEVICE constexpr
std::tuple_element_t<I, integer_sequence<T, Ints...>>
tuple_element_t<I, integer_sequence<T, Ints...>>
get(integer_sequence<T, Ints...>) {
static_assert(I < sizeof...(Ints), "Index out of range");
return {};
}
} // end namespace cute
namespace std
{
template <class T, T... Ints>
struct tuple_size<cute::integer_sequence<T, Ints...>>
: std::integral_constant<std::size_t, sizeof...(Ints)>
{};
template <std::size_t I, class T, T... Ints>
struct tuple_element<I, cute::integer_sequence<T, Ints...>>
: std::tuple_element<I, std::tuple<cute::integral_constant<T,Ints>...>>
{};
} // end namespace std

View File

@ -53,7 +53,7 @@ struct integer_subbyte
static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte");
/// External type
using xint_t = typename std::conditional<Signed, int, unsigned>::type;
using xint_t = typename conditional<Signed, int, unsigned>::type;
/// Bitmask for truncation from larger integers
static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1);
@ -166,7 +166,7 @@ using bin1_t = bool;
#include <limits>
namespace std {
namespace CUTE_STL_NAMESPACE {
template <>
struct numeric_limits<cute::uint1b_t> {
@ -230,4 +230,4 @@ struct numeric_limits<cute::uint4b_t> {
} // namespace std
#endif
#endif // !defined(__CUDACC_RTC__)

View File

@ -39,7 +39,7 @@ namespace cute
{
template <class T, T v>
struct constant : std::integral_constant<T,v> {
struct constant : CUTE_STL_NAMESPACE::integral_constant<T,v> {
static constexpr T value = v;
using value_type = T;
using type = constant<T,v>;
@ -56,7 +56,7 @@ using bool_constant = constant<bool,b>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
//
//
// Traits
//
@ -64,14 +64,14 @@ using false_type = bool_constant<false>;
// Use cute::is_integral<T> to match both built-in integral types AND constant<T,t>
template <class T>
struct is_integral : bool_constant<std::is_integral<T>::value> {};
struct is_integral : bool_constant<CUTE_STL_NAMESPACE::is_integral<T>::value> {};
template <class T, T v>
struct is_integral<constant<T,v>> : true_type {};
// is_static detects if an (abstract) value is defined completely by its type (no members)
template <class T>
struct is_static : bool_constant<std::is_empty<T>::value> {};
struct is_static : bool_constant<is_empty<T>::value> {};
// is_constant detects if a type is a constant<T,v> and if v is equal to a value
@ -95,45 +95,51 @@ struct is_constant<n, constant<T,v> &&> : bool_constant<v == n> {};
template <int v>
using Int = constant<int,v>;
using _m32 = Int<-32>;
using _m24 = Int<-24>;
using _m16 = Int<-16>;
using _m12 = Int<-12>;
using _m10 = Int<-10>;
using _m9 = Int<-9>;
using _m8 = Int<-8>;
using _m7 = Int<-7>;
using _m6 = Int<-6>;
using _m5 = Int<-5>;
using _m4 = Int<-4>;
using _m3 = Int<-3>;
using _m2 = Int<-2>;
using _m1 = Int<-1>;
using _0 = Int<0>;
using _1 = Int<1>;
using _2 = Int<2>;
using _3 = Int<3>;
using _4 = Int<4>;
using _5 = Int<5>;
using _6 = Int<6>;
using _7 = Int<7>;
using _8 = Int<8>;
using _9 = Int<9>;
using _10 = Int<10>;
using _12 = Int<12>;
using _16 = Int<16>;
using _24 = Int<24>;
using _32 = Int<32>;
using _64 = Int<64>;
using _96 = Int<96>;
using _128 = Int<128>;
using _192 = Int<192>;
using _256 = Int<256>;
using _512 = Int<512>;
using _1024 = Int<1024>;
using _2048 = Int<2048>;
using _4096 = Int<4096>;
using _8192 = Int<8192>;
using _m32 = Int<-32>;
using _m24 = Int<-24>;
using _m16 = Int<-16>;
using _m12 = Int<-12>;
using _m10 = Int<-10>;
using _m9 = Int<-9>;
using _m8 = Int<-8>;
using _m7 = Int<-7>;
using _m6 = Int<-6>;
using _m5 = Int<-5>;
using _m4 = Int<-4>;
using _m3 = Int<-3>;
using _m2 = Int<-2>;
using _m1 = Int<-1>;
using _0 = Int<0>;
using _1 = Int<1>;
using _2 = Int<2>;
using _3 = Int<3>;
using _4 = Int<4>;
using _5 = Int<5>;
using _6 = Int<6>;
using _7 = Int<7>;
using _8 = Int<8>;
using _9 = Int<9>;
using _10 = Int<10>;
using _12 = Int<12>;
using _16 = Int<16>;
using _24 = Int<24>;
using _32 = Int<32>;
using _64 = Int<64>;
using _96 = Int<96>;
using _128 = Int<128>;
using _192 = Int<192>;
using _256 = Int<256>;
using _512 = Int<512>;
using _1024 = Int<1024>;
using _2048 = Int<2048>;
using _4096 = Int<4096>;
using _8192 = Int<8192>;
using _16384 = Int<16384>;
using _32768 = Int<32768>;
using _65536 = Int<65536>;
using _131072 = Int<131072>;
using _262144 = Int<262144>;
using _524288 = Int<524288>;
/***************/
/** Operators **/
@ -198,7 +204,7 @@ CUTE_BINARY_OP(<=);
//
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator*(constant<T, 0>, U) {
@ -206,7 +212,7 @@ operator*(constant<T, 0>, U) {
}
template <class U, class T,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator*(U, constant<T, 0>) {
@ -214,7 +220,7 @@ operator*(U, constant<T, 0>) {
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator/(constant<T, 0>, U) {
@ -222,7 +228,7 @@ operator/(constant<T, 0>, U) {
}
template <class U, class T,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator%(U, constant<T, 1>) {
@ -230,7 +236,7 @@ operator%(U, constant<T, 1>) {
}
template <class U, class T,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator%(U, constant<T,-1>) {
@ -238,7 +244,7 @@ operator%(U, constant<T,-1>) {
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator%(constant<T, 0>, U) {
@ -246,7 +252,7 @@ operator%(constant<T, 0>, U) {
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator&(constant<T, 0>, U) {
@ -254,7 +260,7 @@ operator&(constant<T, 0>, U) {
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator&(U, constant<T, 0>) {
@ -262,7 +268,7 @@ operator&(U, constant<T, 0>) {
}
template <class T, T t, class U,
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && !bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, false>
operator&&(constant<T, t>, U) {
@ -270,7 +276,7 @@ operator&&(constant<T, t>, U) {
}
template <class T, T t, class U,
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && !bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, false>
operator&&(U, constant<T, t>) {
@ -278,7 +284,7 @@ operator&&(U, constant<T, t>) {
}
template <class T, class U, T t,
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, true>
operator||(constant<T, t>, U) {
@ -286,7 +292,7 @@ operator||(constant<T, t>, U) {
}
template <class T, class U, T t,
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, true>
operator||(U, constant<T, t>) {
@ -314,7 +320,7 @@ operator||(U, constant<T, t>) {
} \
\
template <class T, T t, class U, \
__CUTE_REQUIRES(std::is_integral<U>::value)> \
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)> \
CUTE_HOST_DEVICE constexpr \
auto \
OP (constant<T,t>, U u) { \
@ -322,7 +328,7 @@ operator||(U, constant<T, t>) {
} \
\
template <class T, class U, U u, \
__CUTE_REQUIRES(std::is_integral<T>::value)> \
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value)> \
CUTE_HOST_DEVICE constexpr \
auto \
OP (T t, constant<U,u>) { \
@ -356,7 +362,7 @@ safe_div(constant<T, t>, constant<U, u>) {
}
template <class T, T t, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(constant<T, t>, U u) {
@ -364,7 +370,7 @@ safe_div(constant<T, t>, U u) {
}
template <class T, class U, U u,
__CUTE_REQUIRES(std::is_integral<T>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(T t, constant<U, u>) {
@ -376,7 +382,7 @@ safe_div(T t, constant<U, u>) {
// Overload selected by a true_type tag: forwards the first value `t` (with its
// original value category, via decltype(auto)) and ignores the second.
// NOTE(review): two signature lines are visible because this is a rendered
// diff without +/- markers (std::true_type vs. true_type); presumably the
// unqualified one is post-change -- confirm.
template <class TrueType, class FalseType>
CUTE_HOST_DEVICE constexpr
decltype(auto)
conditional_return(std::true_type, TrueType&& t, FalseType&&) {
conditional_return(true_type, TrueType&& t, FalseType&&) {
return static_cast<TrueType&&>(t);
}
@ -385,7 +391,7 @@ conditional_return(std::true_type, TrueType&& t, FalseType&&) {
// Overload selected by a false_type tag: forwards the second value `f` (with
// its original value category, via decltype(auto)) and ignores the first.
// NOTE(review): two signature lines are visible because this is a rendered
// diff without +/- markers (std::false_type vs. false_type); presumably the
// unqualified one is post-change -- confirm.
template <class TrueType, class FalseType>
CUTE_HOST_DEVICE constexpr
decltype(auto)
conditional_return(std::false_type, TrueType&&, FalseType&& f) {
conditional_return(false_type, TrueType&&, FalseType&& f) {
return static_cast<FalseType&&>(f);
}
@ -397,6 +403,18 @@ conditional_return(bool b, TrueType const& t, FalseType const& f) {
return b ? t : f;
}
// Compile-time selection between two values whose types need not be related.
// The template flag `b` decides which argument is returned; the branch that
// is not taken is discarded, so the deduced (by-value) return type is that of
// the chosen argument alone.
template <bool b, class TrueType, class FalseType>
CUTE_HOST_DEVICE constexpr
auto
conditional_return(TrueType const& t, FalseType const& f) {
  if constexpr (!b) {
    return f;
  } else {
    return t;
  }
}
//
// Display utilities
//
@ -406,9 +424,11 @@ CUTE_HOST_DEVICE void print(integral_constant<T,N> const&) {
printf("_%d", N);
}
#if !defined(__CUDACC_RTC__)
// Host-side stream insertion for a static integer: writes an underscore and
// then the compile-time value N, and hands the stream back for chaining.
template <class T, T N>
CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant<T,N> const&) {
  os << "_";
  os << N;
  return os;
}
#endif
} // end namespace cute

View File

@ -30,16 +30,10 @@
**************************************************************************************************/
#pragma once
#include <limits>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
namespace cute
{
@ -48,8 +42,8 @@ namespace cute
//
template <class T, class U,
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
std::is_arithmetic<U>::value)>
__CUTE_REQUIRES(is_arithmetic<T>::value &&
is_arithmetic<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
max(T const& t, U const& u) {
@ -57,8 +51,8 @@ max(T const& t, U const& u) {
}
template <class T, class U,
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
std::is_arithmetic<U>::value)>
__CUTE_REQUIRES(is_arithmetic<T>::value &&
is_arithmetic<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
min(T const& t, U const& u) {
@ -66,11 +60,11 @@ min(T const& t, U const& u) {
}
template <class T,
__CUTE_REQUIRES(std::is_arithmetic<T>::value)>
__CUTE_REQUIRES(is_arithmetic<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
abs(T const& t) {
if constexpr (std::is_signed<T>::value) {
if constexpr (is_signed<T>::value) {
return t < T(0) ? -t : t;
} else {
return t;
@ -85,8 +79,8 @@ abs(T const& t) {
// Greatest common divisor of two integers
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<T>::value &&
std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value &&
CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
gcd(T t, U u) {
@ -100,8 +94,8 @@ gcd(T t, U u) {
// Least common multiple of two integers
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<T>::value &&
std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value &&
CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
lcm(T const& t, U const& u) {
@ -133,11 +127,11 @@ template <class T>
CUTE_HOST_DEVICE constexpr
T
bit_width(T x) {
static_assert(std::is_unsigned<T>::value, "Only to be used for unsigned types.");
constexpr int N = (std::numeric_limits<T>::digits == 64 ? 6 :
(std::numeric_limits<T>::digits == 32 ? 5 :
(std::numeric_limits<T>::digits == 16 ? 4 :
(std::numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
static_assert(is_unsigned<T>::value, "Only to be used for unsigned types.");
constexpr int N = (numeric_limits<T>::digits == 64 ? 6 :
(numeric_limits<T>::digits == 32 ? 5 :
(numeric_limits<T>::digits == 16 ? 4 :
(numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
T r = 0;
for (int i = N - 1; i >= 0; --i) {
T shift = (x > ((T(1) << (T(1) << i))-1)) << i;
@ -193,7 +187,7 @@ template <class T>
// Rotates x left by s bit positions, where N is numeric_limits<T>::digits.
// s == 0 returns x unchanged; a negative s is handled by calling rotr with -s.
// NOTE(review): no reduction of s modulo N is visible here, so s >= N would
// produce an out-of-range shift count -- presumably callers keep |s| < N;
// confirm.
// NOTE(review): the two `constexpr int N` lines come from the diff rendering
// (std::numeric_limits vs. unqualified numeric_limits); presumably the
// unqualified one is post-change -- confirm.
CUTE_HOST_DEVICE constexpr
T
rotl(T x, int s) {
constexpr int N = std::numeric_limits<T>::digits;
constexpr int N = numeric_limits<T>::digits;
return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s);
}
@ -202,7 +196,7 @@ template <class T>
// Rotates x right by s bit positions, where N is numeric_limits<T>::digits.
// s == 0 returns x unchanged; a negative s is handled by calling rotl with -s.
// NOTE(review): no reduction of s modulo N is visible here, so s >= N would
// produce an out-of-range shift count -- presumably callers keep |s| < N;
// confirm.
// NOTE(review): the two `constexpr int N` lines come from the diff rendering
// (std::numeric_limits vs. unqualified numeric_limits); presumably the
// unqualified one is post-change -- confirm.
CUTE_HOST_DEVICE constexpr
T
rotr(T x, int s) {
constexpr int N = std::numeric_limits<T>::digits;
constexpr int N = numeric_limits<T>::digits;
return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s);
}
@ -214,7 +208,7 @@ template <class T>
CUTE_HOST_DEVICE constexpr
T
countl_zero(T x) {
return std::numeric_limits<T>::digits - bit_width(x);
return numeric_limits<T>::digits - bit_width(x);
}
// Counts the number of consecutive 1 bits, starting from the most significant bit
@ -236,7 +230,7 @@ template <class T>
CUTE_HOST_DEVICE constexpr
T
countr_zero(T x) {
return x == 0 ? std::numeric_limits<T>::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB
return x == 0 ? numeric_limits<T>::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB
}
// Counts the number of consecutive 1 bits, starting from the least significant bit
@ -288,7 +282,7 @@ shiftr(T x, int s) {
// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero.
template <class T,
__CUTE_REQUIRES(std::is_unsigned<T>::value)>
__CUTE_REQUIRES(is_unsigned<T>::value)>
CUTE_HOST_DEVICE constexpr
int
signum(T const& x) {
@ -296,7 +290,7 @@ signum(T const& x) {
}
template <class T,
__CUTE_REQUIRES(not std::is_unsigned<T>::value)>
__CUTE_REQUIRES(not is_unsigned<T>::value)>
CUTE_HOST_DEVICE constexpr
int
signum(T const& x) {
@ -307,8 +301,8 @@ signum(T const& x) {
// @pre t % u == 0
// @result t / u
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<T>::value &&
std::is_integral<U>::value)>
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value &&
CUTE_STL_NAMESPACE::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(T const& t, U const& u) {

View File

@ -43,9 +43,11 @@ using cutlass::tfloat32_t;
// Display utilities
//
#if !defined(__CUDACC_RTC__)
// Host-side stream insertion for tfloat32_t: the value is converted to float
// and that float is what gets written to the stream.
CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v)
{
  float const widened = float(v);
  os << widened;
  return os;
}
#endif
} // end namespace cute