@ -36,6 +36,7 @@
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/algorithm/tuple_algorithms.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -361,28 +362,75 @@ operator+(ScaledBasis<T,N> const& t, constant<U,0>) {
|
||||
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e) {
|
||||
printf("%d:", N); print(e.value());
|
||||
print(e.value()); printf("@%d", N);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T, int N>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e) {
|
||||
return os << N << ":" << e.value();
|
||||
return os << e.value() << "@" << N;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
|
||||
namespace std
|
||||
namespace CUTE_STL_NAMESPACE
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::ArithmeticTuple<T...>>
|
||||
: std::integral_constant<std::size_t, sizeof...(T)>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
: std::tuple_element<I, std::tuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace CUTE_STL_NAMESPACE
|
||||
|
||||
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
namespace std
|
||||
{
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
template <class... _Tp>
|
||||
struct tuple_size;
|
||||
|
||||
template<size_t _Ip, class... _Tp>
|
||||
struct tuple_element;
|
||||
#endif
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
||||
: cute::integral_constant<size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
@ -43,9 +43,11 @@ using cutlass::bfloat16_t;
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v)
|
||||
{
|
||||
return os << float(v);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
//#if defined(__CUDA_ARCH__)
|
||||
//# include <cuda/std/complex>
|
||||
@ -38,13 +38,37 @@
|
||||
//# include <complex>
|
||||
//#endif
|
||||
|
||||
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
|
||||
// on line 656 of thrust/detail/type_traits.h.
|
||||
// These pragmas suppress the warnings.
|
||||
// Suppress warnings for code in Thrust headers.
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
// We check for MSVC first, because MSVC also defines __GNUC__.
|
||||
// It's common for non-GCC compilers that emulate GCC's behavior
|
||||
// to define __GNUC__.
|
||||
//
|
||||
// thrust/complex.h triggers MSVC's warning on conversion
|
||||
// from double to float (or const float) ("possible loss of data").
|
||||
// MSVC treats this as an error by default (at least with
|
||||
// CUTLASS's default CMake configuration).
|
||||
#pragma warning( push )
|
||||
#pragma warning( disable : 4244 )
|
||||
#elif defined(__GNUC__)
|
||||
// With GCC + CUDA 11.4, builds show spurious "-Wconversion"
|
||||
// warnings on line 656 of thrust/detail/type_traits.h.
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wconversion"
|
||||
#endif
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/complex>
|
||||
#else
|
||||
#include <thrust/complex.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning( pop )
|
||||
#elif defined(__GNUC__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
@ -62,7 +86,11 @@ namespace cute
|
||||
//template <class T>
|
||||
//using complex = thrust::complex<T>;
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
using cuda::std::complex;
|
||||
#else
|
||||
using thrust::complex;
|
||||
#endif
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
@ -147,6 +175,7 @@ struct is_complex<complex<T>> {
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Display utilities
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
|
||||
{
|
||||
@ -159,5 +188,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
|
||||
return os << _r;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -46,10 +46,12 @@ namespace cute
|
||||
// Signed integers
|
||||
//
|
||||
|
||||
using int8_t = std::int8_t;
|
||||
using int16_t = std::int16_t;
|
||||
using int32_t = std::int32_t;
|
||||
using int64_t = std::int64_t;
|
||||
using int2_t = cute::int2b_t;
|
||||
using int4_t = cute::int4b_t;
|
||||
using int8_t = CUTE_STL_NAMESPACE::int8_t;
|
||||
using int16_t = CUTE_STL_NAMESPACE::int16_t;
|
||||
using int32_t = CUTE_STL_NAMESPACE::int32_t;
|
||||
using int64_t = CUTE_STL_NAMESPACE::int64_t;
|
||||
|
||||
template <int N> struct int_bit;
|
||||
template <> struct int_bit< 2> { using type = cute::int2b_t; };
|
||||
@ -72,10 +74,14 @@ using int_byte_t = typename int_byte<N>::type;
|
||||
// Unsigned integers
|
||||
//
|
||||
|
||||
using uint8_t = std::uint8_t;
|
||||
using uint16_t = std::uint16_t;
|
||||
using uint32_t = std::uint32_t;
|
||||
using uint64_t = std::uint64_t;
|
||||
using uint1_t = cute::uint1b_t;
|
||||
using uint2_t = cute::uint2b_t;
|
||||
using uint4_t = cute::uint4b_t;
|
||||
using uint8_t = CUTE_STL_NAMESPACE::uint8_t;
|
||||
using uint16_t = CUTE_STL_NAMESPACE::uint16_t;
|
||||
using uint32_t = CUTE_STL_NAMESPACE::uint32_t;
|
||||
using uint64_t = CUTE_STL_NAMESPACE::uint64_t;
|
||||
using uint128_t = cute::uint128_t;
|
||||
|
||||
template <int N> struct uint_bit;
|
||||
template <> struct uint_bit< 1> { using type = cute::uint1b_t; };
|
||||
@ -102,7 +108,7 @@ using uint_byte_t = typename uint_byte<N>::type;
|
||||
|
||||
template <class T>
|
||||
struct sizeof_bytes {
|
||||
static constexpr std::size_t value = sizeof(T);
|
||||
static constexpr size_t value = sizeof(T);
|
||||
};
|
||||
template <class T>
|
||||
static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
|
||||
@ -113,15 +119,15 @@ static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
|
||||
|
||||
template <class T>
|
||||
struct sizeof_bits {
|
||||
static constexpr std::size_t value = sizeof(T) * 8;
|
||||
static constexpr size_t value = sizeof(T) * 8;
|
||||
};
|
||||
template <>
|
||||
struct sizeof_bits<bool> {
|
||||
static constexpr std::size_t value = 1;
|
||||
static constexpr size_t value = 1;
|
||||
};
|
||||
template <int Bits, bool Signed>
|
||||
struct sizeof_bits<integer_subbyte<Bits,Signed>> {
|
||||
static constexpr std::size_t value = Bits;
|
||||
static constexpr size_t value = Bits;
|
||||
};
|
||||
template <class T>
|
||||
static constexpr int sizeof_bits_v = sizeof_bits<T>::value;
|
||||
|
||||
@ -30,34 +30,46 @@
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <utility> // std::integer_sequence
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
using std::integer_sequence;
|
||||
using std::make_integer_sequence;
|
||||
using CUTE_STL_NAMESPACE::integer_sequence;
|
||||
using CUTE_STL_NAMESPACE::make_integer_sequence;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class S, T Begin>
|
||||
struct make_integer_range_impl;
|
||||
struct range_impl;
|
||||
|
||||
template <class T, T... N, T Begin>
|
||||
struct make_integer_range_impl<T, integer_sequence<T, N...>, Begin> {
|
||||
struct range_impl<T, integer_sequence<T, N...>, Begin> {
|
||||
using type = integer_sequence<T, N+Begin...>;
|
||||
};
|
||||
|
||||
template <class S>
|
||||
struct reverse_impl;
|
||||
|
||||
template <class T, T... N>
|
||||
struct reverse_impl<integer_sequence<T, N...>> {
|
||||
using type = integer_sequence<T, sizeof...(N)-1-N...>;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, T Begin, T End>
|
||||
using make_integer_range = typename detail::make_integer_range_impl<
|
||||
using make_integer_range = typename detail::range_impl<
|
||||
T,
|
||||
make_integer_sequence<T, (End-Begin > 0) ? (End-Begin) : 0>,
|
||||
Begin>::type;
|
||||
|
||||
template <class T, T N>
|
||||
using make_integer_sequence_reverse = typename detail::reverse_impl<
|
||||
make_integer_sequence<T, N>>::type;
|
||||
|
||||
//
|
||||
// Common aliases
|
||||
//
|
||||
@ -70,19 +82,25 @@ using int_sequence = integer_sequence<int, Ints...>;
|
||||
template <int N>
|
||||
using make_int_sequence = make_integer_sequence<int, N>;
|
||||
|
||||
template <int N>
|
||||
using make_int_rsequence = make_integer_sequence_reverse<int, N>;
|
||||
|
||||
template <int Begin, int End>
|
||||
using make_int_range = make_integer_range<int, Begin, End>;
|
||||
|
||||
// index_sequence
|
||||
|
||||
template <std::size_t... Ints>
|
||||
using index_sequence = integer_sequence<std::size_t, Ints...>;
|
||||
template <size_t... Ints>
|
||||
using index_sequence = integer_sequence<size_t, Ints...>;
|
||||
|
||||
template <std::size_t N>
|
||||
using make_index_sequence = make_integer_sequence<std::size_t, N>;
|
||||
template <size_t N>
|
||||
using make_index_sequence = make_integer_sequence<size_t, N>;
|
||||
|
||||
template <std::size_t Begin, std::size_t End>
|
||||
using make_index_range = make_integer_range<std::size_t, Begin, End>;
|
||||
template <size_t N>
|
||||
using make_index_rsequence = make_integer_sequence_reverse<size_t, N>;
|
||||
|
||||
template <size_t Begin, size_t End>
|
||||
using make_index_range = make_integer_range<size_t, Begin, End>;
|
||||
|
||||
//
|
||||
// Shortcuts
|
||||
@ -94,46 +112,40 @@ using seq = int_sequence<Ints...>;
|
||||
template <int N>
|
||||
using make_seq = make_int_sequence<N>;
|
||||
|
||||
template <int N>
|
||||
using make_rseq = make_int_rsequence<N>;
|
||||
|
||||
template <int Min, int Max>
|
||||
using make_range = make_int_range<Min, Max>;
|
||||
|
||||
template <class Tuple>
|
||||
using tuple_seq = make_seq<std::tuple_size<std::remove_reference_t<Tuple>>::value>;
|
||||
|
||||
} // end namespace cute
|
||||
using tuple_seq = make_seq<tuple_size<remove_cvref_t<Tuple>>::value>;
|
||||
|
||||
template <class Tuple>
|
||||
using tuple_rseq = make_rseq<tuple_size<remove_cvref_t<Tuple>>::value>;
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::integer_sequence
|
||||
// Specialize cute::tuple-traits for std::integer_sequence
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
template <class T, T... Ints>
|
||||
struct tuple_size<integer_sequence<T, Ints...>>
|
||||
: cute::integral_constant<size_t, sizeof...(Ints)>
|
||||
{};
|
||||
|
||||
namespace cute
|
||||
template <size_t I, class T, T... Is>
|
||||
struct tuple_element<I, integer_sequence<T, Is...>>
|
||||
{
|
||||
constexpr static T idx[sizeof...(Is)] = {Is...};
|
||||
using type = cute::integral_constant<T, idx[I]>;
|
||||
};
|
||||
|
||||
template <std::size_t I, class T, T... Ints>
|
||||
template <size_t I, class T, T... Ints>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
std::tuple_element_t<I, integer_sequence<T, Ints...>>
|
||||
tuple_element_t<I, integer_sequence<T, Ints...>>
|
||||
get(integer_sequence<T, Ints...>) {
|
||||
static_assert(I < sizeof...(Ints), "Index out of range");
|
||||
return {};
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class T, T... Ints>
|
||||
struct tuple_size<cute::integer_sequence<T, Ints...>>
|
||||
: std::integral_constant<std::size_t, sizeof...(Ints)>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class T, T... Ints>
|
||||
struct tuple_element<I, cute::integer_sequence<T, Ints...>>
|
||||
: std::tuple_element<I, std::tuple<cute::integral_constant<T,Ints>...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
|
||||
@ -53,7 +53,7 @@ struct integer_subbyte
|
||||
static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte");
|
||||
|
||||
/// External type
|
||||
using xint_t = typename std::conditional<Signed, int, unsigned>::type;
|
||||
using xint_t = typename conditional<Signed, int, unsigned>::type;
|
||||
|
||||
/// Bitmask for truncation from larger integers
|
||||
static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1);
|
||||
@ -166,7 +166,7 @@ using bin1_t = bool;
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace std {
|
||||
namespace CUTE_STL_NAMESPACE {
|
||||
|
||||
template <>
|
||||
struct numeric_limits<cute::uint1b_t> {
|
||||
@ -230,4 +230,4 @@ struct numeric_limits<cute::uint4b_t> {
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
@ -39,7 +39,7 @@ namespace cute
|
||||
{
|
||||
|
||||
template <class T, T v>
|
||||
struct constant : std::integral_constant<T,v> {
|
||||
struct constant : CUTE_STL_NAMESPACE::integral_constant<T,v> {
|
||||
static constexpr T value = v;
|
||||
using value_type = T;
|
||||
using type = constant<T,v>;
|
||||
@ -56,7 +56,7 @@ using bool_constant = constant<bool,b>;
|
||||
using true_type = bool_constant<true>;
|
||||
using false_type = bool_constant<false>;
|
||||
|
||||
//
|
||||
//
|
||||
// Traits
|
||||
//
|
||||
|
||||
@ -64,14 +64,14 @@ using false_type = bool_constant<false>;
|
||||
// Use cute::is_integral<T> to match both built-in integral types AND constant<T,t>
|
||||
|
||||
template <class T>
|
||||
struct is_integral : bool_constant<std::is_integral<T>::value> {};
|
||||
struct is_integral : bool_constant<CUTE_STL_NAMESPACE::is_integral<T>::value> {};
|
||||
template <class T, T v>
|
||||
struct is_integral<constant<T,v>> : true_type {};
|
||||
|
||||
// is_static detects if an (abstract) value is defined completely by it's type (no members)
|
||||
|
||||
template <class T>
|
||||
struct is_static : bool_constant<std::is_empty<T>::value> {};
|
||||
struct is_static : bool_constant<is_empty<T>::value> {};
|
||||
|
||||
// is_constant detects if a type is a constant<T,v> and if v is equal to a value
|
||||
|
||||
@ -95,45 +95,51 @@ struct is_constant<n, constant<T,v> &&> : bool_constant<v == n> {};
|
||||
template <int v>
|
||||
using Int = constant<int,v>;
|
||||
|
||||
using _m32 = Int<-32>;
|
||||
using _m24 = Int<-24>;
|
||||
using _m16 = Int<-16>;
|
||||
using _m12 = Int<-12>;
|
||||
using _m10 = Int<-10>;
|
||||
using _m9 = Int<-9>;
|
||||
using _m8 = Int<-8>;
|
||||
using _m7 = Int<-7>;
|
||||
using _m6 = Int<-6>;
|
||||
using _m5 = Int<-5>;
|
||||
using _m4 = Int<-4>;
|
||||
using _m3 = Int<-3>;
|
||||
using _m2 = Int<-2>;
|
||||
using _m1 = Int<-1>;
|
||||
using _0 = Int<0>;
|
||||
using _1 = Int<1>;
|
||||
using _2 = Int<2>;
|
||||
using _3 = Int<3>;
|
||||
using _4 = Int<4>;
|
||||
using _5 = Int<5>;
|
||||
using _6 = Int<6>;
|
||||
using _7 = Int<7>;
|
||||
using _8 = Int<8>;
|
||||
using _9 = Int<9>;
|
||||
using _10 = Int<10>;
|
||||
using _12 = Int<12>;
|
||||
using _16 = Int<16>;
|
||||
using _24 = Int<24>;
|
||||
using _32 = Int<32>;
|
||||
using _64 = Int<64>;
|
||||
using _96 = Int<96>;
|
||||
using _128 = Int<128>;
|
||||
using _192 = Int<192>;
|
||||
using _256 = Int<256>;
|
||||
using _512 = Int<512>;
|
||||
using _1024 = Int<1024>;
|
||||
using _2048 = Int<2048>;
|
||||
using _4096 = Int<4096>;
|
||||
using _8192 = Int<8192>;
|
||||
using _m32 = Int<-32>;
|
||||
using _m24 = Int<-24>;
|
||||
using _m16 = Int<-16>;
|
||||
using _m12 = Int<-12>;
|
||||
using _m10 = Int<-10>;
|
||||
using _m9 = Int<-9>;
|
||||
using _m8 = Int<-8>;
|
||||
using _m7 = Int<-7>;
|
||||
using _m6 = Int<-6>;
|
||||
using _m5 = Int<-5>;
|
||||
using _m4 = Int<-4>;
|
||||
using _m3 = Int<-3>;
|
||||
using _m2 = Int<-2>;
|
||||
using _m1 = Int<-1>;
|
||||
using _0 = Int<0>;
|
||||
using _1 = Int<1>;
|
||||
using _2 = Int<2>;
|
||||
using _3 = Int<3>;
|
||||
using _4 = Int<4>;
|
||||
using _5 = Int<5>;
|
||||
using _6 = Int<6>;
|
||||
using _7 = Int<7>;
|
||||
using _8 = Int<8>;
|
||||
using _9 = Int<9>;
|
||||
using _10 = Int<10>;
|
||||
using _12 = Int<12>;
|
||||
using _16 = Int<16>;
|
||||
using _24 = Int<24>;
|
||||
using _32 = Int<32>;
|
||||
using _64 = Int<64>;
|
||||
using _96 = Int<96>;
|
||||
using _128 = Int<128>;
|
||||
using _192 = Int<192>;
|
||||
using _256 = Int<256>;
|
||||
using _512 = Int<512>;
|
||||
using _1024 = Int<1024>;
|
||||
using _2048 = Int<2048>;
|
||||
using _4096 = Int<4096>;
|
||||
using _8192 = Int<8192>;
|
||||
using _16384 = Int<16384>;
|
||||
using _32768 = Int<32768>;
|
||||
using _65536 = Int<65536>;
|
||||
using _131072 = Int<131072>;
|
||||
using _262144 = Int<262144>;
|
||||
using _524288 = Int<524288>;
|
||||
|
||||
/***************/
|
||||
/** Operators **/
|
||||
@ -198,7 +204,7 @@ CUTE_BINARY_OP(<=);
|
||||
//
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator*(constant<T, 0>, U) {
|
||||
@ -206,7 +212,7 @@ operator*(constant<T, 0>, U) {
|
||||
}
|
||||
|
||||
template <class U, class T,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator*(U, constant<T, 0>) {
|
||||
@ -214,7 +220,7 @@ operator*(U, constant<T, 0>) {
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator/(constant<T, 0>, U) {
|
||||
@ -222,7 +228,7 @@ operator/(constant<T, 0>, U) {
|
||||
}
|
||||
|
||||
template <class U, class T,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator%(U, constant<T, 1>) {
|
||||
@ -230,7 +236,7 @@ operator%(U, constant<T, 1>) {
|
||||
}
|
||||
|
||||
template <class U, class T,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator%(U, constant<T,-1>) {
|
||||
@ -238,7 +244,7 @@ operator%(U, constant<T,-1>) {
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator%(constant<T, 0>, U) {
|
||||
@ -246,7 +252,7 @@ operator%(constant<T, 0>, U) {
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator&(constant<T, 0>, U) {
|
||||
@ -254,7 +260,7 @@ operator&(constant<T, 0>, U) {
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator&(U, constant<T, 0>) {
|
||||
@ -262,7 +268,7 @@ operator&(U, constant<T, 0>) {
|
||||
}
|
||||
|
||||
template <class T, T t, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && !bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, false>
|
||||
operator&&(constant<T, t>, U) {
|
||||
@ -270,7 +276,7 @@ operator&&(constant<T, t>, U) {
|
||||
}
|
||||
|
||||
template <class T, T t, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && !bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, false>
|
||||
operator&&(U, constant<T, t>) {
|
||||
@ -278,7 +284,7 @@ operator&&(U, constant<T, t>) {
|
||||
}
|
||||
|
||||
template <class T, class U, T t,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, true>
|
||||
operator||(constant<T, t>, U) {
|
||||
@ -286,7 +292,7 @@ operator||(constant<T, t>, U) {
|
||||
}
|
||||
|
||||
template <class T, class U, T t,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value && bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, true>
|
||||
operator||(U, constant<T, t>) {
|
||||
@ -314,7 +320,7 @@ operator||(U, constant<T, t>) {
|
||||
} \
|
||||
\
|
||||
template <class T, T t, class U, \
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)> \
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto \
|
||||
OP (constant<T,t>, U u) { \
|
||||
@ -322,7 +328,7 @@ operator||(U, constant<T, t>) {
|
||||
} \
|
||||
\
|
||||
template <class T, class U, U u, \
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value)> \
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value)> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto \
|
||||
OP (T t, constant<U,u>) { \
|
||||
@ -356,7 +362,7 @@ safe_div(constant<T, t>, constant<U, u>) {
|
||||
}
|
||||
|
||||
template <class T, T t, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(constant<T, t>, U u) {
|
||||
@ -364,7 +370,7 @@ safe_div(constant<T, t>, U u) {
|
||||
}
|
||||
|
||||
template <class T, class U, U u,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(T t, constant<U, u>) {
|
||||
@ -376,7 +382,7 @@ safe_div(T t, constant<U, u>) {
|
||||
template <class TrueType, class FalseType>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
conditional_return(std::true_type, TrueType&& t, FalseType&&) {
|
||||
conditional_return(true_type, TrueType&& t, FalseType&&) {
|
||||
return static_cast<TrueType&&>(t);
|
||||
}
|
||||
|
||||
@ -385,7 +391,7 @@ conditional_return(std::true_type, TrueType&& t, FalseType&&) {
|
||||
template <class TrueType, class FalseType>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
conditional_return(std::false_type, TrueType&&, FalseType&& f) {
|
||||
conditional_return(false_type, TrueType&&, FalseType&& f) {
|
||||
return static_cast<FalseType&&>(f);
|
||||
}
|
||||
|
||||
@ -397,6 +403,18 @@ conditional_return(bool b, TrueType const& t, FalseType const& f) {
|
||||
return b ? t : f;
|
||||
}
|
||||
|
||||
// TrueType and FalseType don't require a common type
|
||||
template <bool b, class TrueType, class FalseType>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
conditional_return(TrueType const& t, FalseType const& f) {
|
||||
if constexpr (b) {
|
||||
return t;
|
||||
} else {
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
@ -406,9 +424,11 @@ CUTE_HOST_DEVICE void print(integral_constant<T,N> const&) {
|
||||
printf("_%d", N);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T, T N>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant<T,N> const&) {
|
||||
return os << "_" << N;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -30,16 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
@ -48,8 +42,8 @@ namespace cute
|
||||
//
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
|
||||
std::is_arithmetic<U>::value)>
|
||||
__CUTE_REQUIRES(is_arithmetic<T>::value &&
|
||||
is_arithmetic<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
max(T const& t, U const& u) {
|
||||
@ -57,8 +51,8 @@ max(T const& t, U const& u) {
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
|
||||
std::is_arithmetic<U>::value)>
|
||||
__CUTE_REQUIRES(is_arithmetic<T>::value &&
|
||||
is_arithmetic<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
min(T const& t, U const& u) {
|
||||
@ -66,11 +60,11 @@ min(T const& t, U const& u) {
|
||||
}
|
||||
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(std::is_arithmetic<T>::value)>
|
||||
__CUTE_REQUIRES(is_arithmetic<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
abs(T const& t) {
|
||||
if constexpr (std::is_signed<T>::value) {
|
||||
if constexpr (is_signed<T>::value) {
|
||||
return t < T(0) ? -t : t;
|
||||
} else {
|
||||
return t;
|
||||
@ -85,8 +79,8 @@ abs(T const& t) {
|
||||
|
||||
// Greatest common divisor of two integers
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value &&
|
||||
std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value &&
|
||||
CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
gcd(T t, U u) {
|
||||
@ -100,8 +94,8 @@ gcd(T t, U u) {
|
||||
|
||||
// Least common multiple of two integers
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value &&
|
||||
std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value &&
|
||||
CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lcm(T const& t, U const& u) {
|
||||
@ -133,11 +127,11 @@ template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
bit_width(T x) {
|
||||
static_assert(std::is_unsigned<T>::value, "Only to be used for unsigned types.");
|
||||
constexpr int N = (std::numeric_limits<T>::digits == 64 ? 6 :
|
||||
(std::numeric_limits<T>::digits == 32 ? 5 :
|
||||
(std::numeric_limits<T>::digits == 16 ? 4 :
|
||||
(std::numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
|
||||
static_assert(is_unsigned<T>::value, "Only to be used for unsigned types.");
|
||||
constexpr int N = (numeric_limits<T>::digits == 64 ? 6 :
|
||||
(numeric_limits<T>::digits == 32 ? 5 :
|
||||
(numeric_limits<T>::digits == 16 ? 4 :
|
||||
(numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
|
||||
T r = 0;
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
T shift = (x > ((T(1) << (T(1) << i))-1)) << i;
|
||||
@ -193,7 +187,7 @@ template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
rotl(T x, int s) {
|
||||
constexpr int N = std::numeric_limits<T>::digits;
|
||||
constexpr int N = numeric_limits<T>::digits;
|
||||
return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s);
|
||||
}
|
||||
|
||||
@ -202,7 +196,7 @@ template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
rotr(T x, int s) {
|
||||
constexpr int N = std::numeric_limits<T>::digits;
|
||||
constexpr int N = numeric_limits<T>::digits;
|
||||
return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s);
|
||||
}
|
||||
|
||||
@ -214,7 +208,7 @@ template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
countl_zero(T x) {
|
||||
return std::numeric_limits<T>::digits - bit_width(x);
|
||||
return numeric_limits<T>::digits - bit_width(x);
|
||||
}
|
||||
|
||||
// Counts the number of consecutive 1 bits, starting from the most significant bit
|
||||
@ -236,7 +230,7 @@ template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
countr_zero(T x) {
|
||||
return x == 0 ? std::numeric_limits<T>::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB
|
||||
return x == 0 ? numeric_limits<T>::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB
|
||||
}
|
||||
|
||||
// Counts the number of consecutive 1 bits, starting from the least significant bit
|
||||
@ -288,7 +282,7 @@ shiftr(T x, int s) {
|
||||
|
||||
// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero.
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(std::is_unsigned<T>::value)>
|
||||
__CUTE_REQUIRES(is_unsigned<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
int
|
||||
signum(T const& x) {
|
||||
@ -296,7 +290,7 @@ signum(T const& x) {
|
||||
}
|
||||
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(not std::is_unsigned<T>::value)>
|
||||
__CUTE_REQUIRES(not is_unsigned<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
int
|
||||
signum(T const& x) {
|
||||
@ -307,8 +301,8 @@ signum(T const& x) {
|
||||
// @pre t % u == 0
|
||||
// @result t / u
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value &&
|
||||
std::is_integral<U>::value)>
|
||||
__CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral<T>::value &&
|
||||
CUTE_STL_NAMESPACE::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(T const& t, U const& u) {
|
||||
|
||||
@ -43,9 +43,11 @@ using cutlass::tfloat32_t;
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v)
|
||||
{
|
||||
return os << float(v);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
Reference in New Issue
Block a user