Updates for 3.4 release. (#1305)

ANIKET SHIVAM
2024-01-16 10:42:51 -08:00
committed by GitHub
parent acba5beee5
commit 2f589ffa76
166 changed files with 5996 additions and 4702 deletions

View File

@ -108,6 +108,28 @@ CUTE_NAMED_UNARY_OP(conjugate, cute::conj);
#undef CUTE_RIGHT_UNARY_OP
#undef CUTE_NAMED_UNARY_OP
template <int Shift_>
struct shift_right_const {
static constexpr int Shift = Shift_;
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg) >> Shift;
}
};
template <int Shift_>
struct shift_left_const {
static constexpr int Shift = Shift_;
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg) << Shift;
}
};
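A minimal usage sketch of the new functors (not part of the diff): any integral argument works, since they simply wrap the shift operators.

cute::shift_right_const<3> div8{};
cute::shift_left_const<3>  mul8{};
int q = div8(40);   // 40 >> 3 == 5
int p = mul8(q);    //  5 << 3 == 40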
/************/
/** Binary **/
/************/

View File

@ -604,8 +604,7 @@ unwrap(T const& t)
}
//
// Flatten a hierarchical tuple to a tuple of depth one.
//
// Flatten and Unflatten
//
template <class T>
@ -614,13 +613,15 @@ struct is_flat : true_type {};
template <class... Ts>
struct is_flat<tuple<Ts...>> : bool_constant<(true && ... && (not is_tuple<Ts>::value))> {};
// Flatten a hierarchical tuple to a tuple of depth one
// and wrap non-tuples into a rank-1 tuple.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten_to_tuple(T const& t)
{
if constexpr (is_tuple<T>::value) {
if constexpr (is_flat<T>::value) {
if constexpr (is_flat<T>::value) { // Shortcut for perf
return t;
} else {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
@ -632,13 +633,15 @@ flatten_to_tuple(T const& t)
CUTE_GCC_UNREACHABLE;
}
// Flatten a hierarchical tuple to a tuple of depth one
// and leave non-tuples untouched.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten(T const& t)
{
if constexpr (is_tuple<T>::value) {
if constexpr (is_flat<T>::value) {
if constexpr (is_flat<T>::value) { // Shortcut for perf
return t;
} else {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
@ -650,6 +653,43 @@ flatten(T const& t)
CUTE_GCC_UNREACHABLE;
}
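A behavioral sketch of the distinction the two comments draw (the values follow directly from the branches shown):

auto t  = cute::make_tuple(Int<2>{}, cute::make_tuple(Int<3>{}, Int<4>{}));
auto f  = cute::flatten(t);                  // (_2,_3,_4)
auto s  = cute::flatten(Int<5>{});           // _5   (non-tuple untouched)
auto s2 = cute::flatten_to_tuple(Int<5>{});  // (_5) (wrapped into a rank-1 tuple)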
namespace detail {
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
if constexpr (is_tuple<TargetProfile>::value) {
return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
auto [result, remaining_tuple] = v;
auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
return cute::make_tuple(append(result, sub_result), sub_tuple);
});
} else {
return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
// Unflatten a flat tuple into a hierarchical tuple
// @pre flatten(@a flat_tuple) == @a flat_tuple
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
return unflatten_tuple;
}
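A round-trip sketch of the pre/post-conditions above (illustrative values; only the tuple structure of the profile matters, its leaf values are ignored):

auto flat    = cute::make_tuple(Int<1>{}, Int<2>{}, Int<3>{}, Int<4>{});
auto profile = cute::make_tuple(0, cute::make_tuple(0, 0), 0);
auto hier    = cute::unflatten(flat, profile);   // (_1,(_2,_3),_4)
// congruent(hier, profile) and flatten(hier) == flat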
//
// insert and remove and replace
//
@ -728,6 +768,18 @@ replace_back(T const& t, X const& x)
// Make a tuple of Xs of tuple_size N
//
template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
tuple_repeat(X const& x)
{
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
}
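A sketch of how the new tuple_repeat differs from the existing repeat below (assuming repeat<1> returns its argument unwrapped, as in prior releases):

auto a = cute::tuple_repeat<3>(Int<1>{});  // (_1,_1,_1)
auto b = cute::tuple_repeat<1>(Int<1>{});  // (_1)  always a tuple
auto c = cute::repeat<1>(Int<1>{});        // _1    not wrapped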
//
// Make repeated Xs of rank N
//
template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
@ -743,7 +795,7 @@ repeat(X const& x)
}
//
// Make a tuple of Xs the same profile as tuple
// Make a tuple of Xs the same profile as tuple T
//
template <class T, class X>
@ -864,48 +916,6 @@ prepend(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
//
// Unflatten a flat tuple into a hierarchical one
// unflatten(x, flatten(x)) == x
//
namespace detail {
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
if constexpr (is_tuple<TargetProfile>::value) {
return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
auto [result, remaining_tuple] = v;
auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
return cute::make_tuple(append(result, sub_result), sub_tuple);
});
} else {
return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
// @pre flatten(@a flat_tuple) == @a flat_tuple
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
{
auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
return unflatten_tuple;
}
//
// Inclusive scan (prefix sum)
//

View File

@ -63,7 +63,7 @@ initialize_barrier(uint64_t& smem_barrier, // 64 bits user-mange
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n"
asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n"
:: "r"(smem_int_ptr),
"r"(thread_count));
#endif
@ -77,7 +77,7 @@ set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-mange
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n"
asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
:: "r"(smem_int_ptr),
"r"(bytes));
#endif
@ -95,7 +95,7 @@ wait_barrier(uint64_t& smem_barrier, // 64 bits user-mange
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
@ -116,7 +116,7 @@ arrive_barrier(uint64_t& smem_barrier) // 64 bits user-mang
asm volatile(
"{\n"
".reg .b64 state; \n"
"mbarrier.arrive.shared.b64 state, [%0];\n"
"mbarrier.arrive.shared::cta.b64 state, [%0];\n"
"}\n"
:: "r"(smem_int_ptr));
#endif
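The four changes above only spell out the ::cta state-space qualifier, which is already the default for .shared, so generated code is unchanged. A hedged sketch of how these helpers are typically sequenced (the kernel scaffolding and byte count are illustrative, not from this file; requires CUTE_ARCH_TMA_SM90_ENABLED):

__global__ void barrier_example(uint32_t expected_bytes)
{
  __shared__ uint64_t smem_barrier;
  if (threadIdx.x == 0) {
    cute::initialize_barrier(smem_barrier, /*thread_count=*/1);
    cute::set_barrier_transaction_bytes(smem_barrier, expected_bytes);
  }
  __syncthreads();
  // ... issue TMA loads that commit expected_bytes to smem_barrier ...
  cute::wait_barrier(smem_barrier, /*phase=*/0);  // spins on try_wait.parity
}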

View File

@ -854,11 +854,12 @@ rs_op_selector()
// FP32 accumulator
else if constexpr (is_same_v<ElementC, float>) {
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
// FP16 inputs
if constexpr (is_same_v<ElementA, half_t>) {
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
@ -891,6 +892,7 @@ rs_op_selector()
// BF16 inputs
else if constexpr (is_same_v<ElementA, bfloat16_t>) {
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
@ -925,6 +927,7 @@ rs_op_selector()
else if constexpr (is_same_v<ElementA, tfloat32_t>) {
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x8_F32TF32TF32_RS_TN<Args...>{};
@ -1023,7 +1026,7 @@ rs_op_selector()
return SM90_64x8x32_F32E4M3E5M2_RS_TN<Args...>{};
}
else {
static_aRSert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}

View File

@ -65,9 +65,9 @@ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
using ValType = CopyInternalType;
using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
using ValLayoutRef = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutRef{}));
using ValLayoutSrc = decltype(recast_layout<uint1_t, ValType>(BitLayoutSrc{}));
using ValLayoutDst = decltype(recast_layout<uint1_t, ValType>(BitLayoutDst{}));
using ValLayoutRef = decltype(recast_layout<uint1_t, ValType>(BitLayoutRef{}));
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType.");
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType.");
@ -479,20 +479,24 @@ make_tiled_copy(Copy_Atom<Args...> const& copy_atom,
ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx
ValLayout const& val_layout = {}) // (m,n) -> val_idx
{
constexpr int R = cute::max(rank_v<ThrLayout>, rank_v<ValLayout>);
auto thr_layout_mn = append<R>(thr_layout, Layout<_1>{});
auto val_layout_mn = append<R>(val_layout, Layout<_1>{});
// Take the raked_products to compute the Layout_MN
auto layout_mn = raked_product(thr_layout_mn, val_layout_mn);
// (M,N) -> (thr_idx, val_idx)
auto layout_mn = raked_product(thr_layout, val_layout);
// (thr_idx, val_idx) -> (M,N)
auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
// print("thr_layout: "); print(thr_layout_mn); print("\n");
// print("val_layout: "); print(val_layout_mn); print("\n");
// print("layout_mn : "); print(layout_mn); print("\n");
// print("layout_tv : "); print(layout_tv); print("\n");
// Tiler for extracting relevant elements
// (M,N) -> tensor coord
auto tiler = product_each(shape(layout_mn));
return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn)));
#if 0
print("thr_layout: "); print(thr_layout); print("\n");
print("val_layout: "); print(val_layout); print("\n");
print("layout_mn : "); print(layout_mn); print("\n");
print("layout_tv : "); print(layout_tv); print("\n");
print("tiler : "); print(tiler); print("\n");
#endif
return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
}
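An illustrative call (assumed types, consistent with the (m,n) -> thr_idx / val_idx contract in the signature): 32 threads in an 8x4 arrangement, each handling a 1x2 slab of values, cover an (8,8) tile via the raked_product above.

using namespace cute;
auto tiled_copy = make_tiled_copy(
    Copy_Atom<UniversalCopy<float>, float>{},
    Layout<Shape<_8,_4>>{},    // thr_layout: (m,n) -> thr_idx
    Layout<Shape<_1,_2>>{});   // val_layout: (m,n) -> val_idx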
/** Produce a TiledCopy from thread and value offset maps.
@ -622,7 +626,7 @@ print(Copy_Atom<Copy_Traits<Args...>, T> const&)
print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n");
print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n");
print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n");
print(" ValueType: %db\n", int(sizeof_bits<typename Atom::ValType>::value));
print(" ValueType: "); print(sizeof_bits<typename Atom::ValType>::value); print("b\n");
}
template <class Atom, class... Args>
@ -755,6 +759,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and
#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/copy_traits_sm80.hpp>
#include <cute/atom/copy_traits_sm90.hpp>
// Config
#if (__CUDACC_VER_MAJOR__ >= 12)
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED

View File

@ -673,15 +673,14 @@ fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, /
// Trivial contribution of this gmem mode to this tma mode
auto ej = unwrap(get<i>(tma_gbasis_stride));
gmem_prob_shape[i] = basis_get(ej, gmem_shape);
gmem_prob_stride[i] = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
gmem_prob_stride[i] = basis_get(ej, gmem_stride);
} else {
// Apply a recurrence to each gmem mode that contributes to this tma mode
for_each(get<i>(tma_gbasis_stride), [&](auto ej) {
// Problem shape
uint64_t shape_j = basis_get(ej, gmem_shape);
// Problem stride (in bytes)
uint64_t stride_j = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
uint64_t stride_j = basis_get(ej, gmem_stride);
uint64_t old_stride = gmem_prob_stride[i];
gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);
@ -764,8 +763,14 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32
// TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element).
assert(gmem_prob_stride[0] == sizeof(TmaInternalType) && "Majorness of smem doesn't match majorness of gmem");
assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem");
// Convert element strides to byte strides
for(uint64_t& stride : gmem_prob_stride) {
stride = (stride * sizeof_bits_v<TmaInternalType>) / 8;
}
// Assert the byte strides. The TMA descriptor uses byte strides.
assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40
assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40
@ -866,8 +871,8 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
}
#endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)
auto recast_ratio = cute::ratio(Int<sizeof_bits<typename GEngine::value_type>::value>{},
Int<sizeof_bits< TmaInternalType>::value>{});
auto recast_ratio = cute::trait_ratio(sizeof_bits<typename GEngine::value_type>{},
sizeof_bits< TmaInternalType>{});
auto gbasis = make_basis_like(shape(gtensor));
@ -943,7 +948,7 @@ make_tma_copy_atom(CopyOp,
// Construct the Copy_Traits
//
constexpr int num_bits_per_tma = decltype(size(tma_gbasis))::value * sizeof_bits_v<TmaInternalType>;
constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits<TmaInternalType>::value;
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
using Atom = Copy_Atom<Traits, typename GEngine::value_type>;
@ -985,7 +990,7 @@ make_tma_copy_tiled(CopyOp const& copy_op,
[[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map));
auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / Int<sizeof_bits_v<typename GEngine::value_type>>{};
auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value<sizeof_bits<typename GEngine::value_type>>();
// smem idx -> smem coord
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
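Worked arithmetic for the deferred element-to-byte conversion above (illustrative values, not from the diff): with TmaInternalType = cute::half_t and an element stride of 64,

uint64_t elem_stride = 64;  // hypothetical element stride
uint64_t byte_stride = elem_stride * cute::sizeof_bits_v<cute::half_t> / 8;  // 128
assert(byte_stride % 16 == 0);  // TMA descriptor strides must be 16B multiples

Keeping gmem_prob_stride in elements until descriptor creation lets the gcd recurrence and the stride[0] == 1 majorness check run in element units.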

View File

@ -55,10 +55,10 @@ struct MMA_Atom<MMA_Traits<Args...>>
using Traits = MMA_Traits<Args...>;
// Element value types from the MMA_Traits
using ValTypeD = typename Traits::ElementDVal;
using ValTypeA = typename Traits::ElementAVal;
using ValTypeB = typename Traits::ElementBVal;
using ValTypeC = typename Traits::ElementCVal;
using ValTypeD = typename Traits::ValTypeD;
using ValTypeA = typename Traits::ValTypeA;
using ValTypeB = typename Traits::ValTypeB;
using ValTypeC = typename Traits::ValTypeC;
// Thr-Val layouts from the MMA_Traits
using Shape_MNK = typename Traits::Shape_MNK;

View File

@ -50,14 +50,14 @@ struct supports_output_scaling<X, void_t<decltype(declval<X>().accumulate_)>> {
/**
* concept MMA_Traits
* {
* using ElementDVal = // Logical A-value type
* using ElementAVal = // Logical B-value type
* using ElementBVal = // Logical C-value type
* using ElementCVal = // Logical D-value type (NOTE: Not used? Assumed == ElementDVal)
* using ValTypeD = // Logical D-value type
* using ValTypeA = // Logical A-value type
* using ValTypeB = // Logical B-value type
* using ValTypeC = // Logical C-value type (NOTE: Not used? Assumed == ValTypeD)
*
* using ElementAFrg = // A-type consumed by MMA (if ommitted, same as ElementAVal)
* using ElementBFrg = // B_type consumed by MMA (if ommitted, same as ElementBVal)
* using ElementCFrg = // C_type consumed by MMA (if ommitted, same as ElementCVal)
* using FrgTypeA = // A-type consumed by MMA (if omitted, same as ValTypeA)
* using FrgTypeB = // B-type consumed by MMA (if omitted, same as ValTypeB)
* using FrgTypeC = // C-type consumed by MMA (if omitted, same as ValTypeC)
*
* using Shape_MNK = // Logical MxNxK shape of the MMA
*
@ -78,10 +78,10 @@ struct MMA_Traits
template <class D, class A, class B, class C>
struct MMA_Traits<UniversalFMA<D,A,B,C>>
{
using ElementDVal = D;
using ElementAVal = A;
using ElementBVal = B;
using ElementCVal = C;
using ValTypeD = D;
using ValTypeA = A;
using ValTypeB = B;
using ValTypeC = C;
// Logical shape of the MMA
using Shape_MNK = Shape<_1,_1,_1>;
@ -209,19 +209,19 @@ mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits,
namespace detail {
template <class X, class = void>
struct FrgTypeA_or_Default { using type = typename X::ElementAVal; };
struct FrgTypeA_or_Default { using type = typename X::ValTypeA; };
template <class X>
struct FrgTypeA_or_Default<X,void_t<typename X::ElementAFrg>> { using type = typename X::ElementAFrg; };
struct FrgTypeA_or_Default<X,void_t<typename X::FrgTypeA>> { using type = typename X::FrgTypeA; };
template <class X, class = void>
struct FrgTypeB_or_Default { using type = typename X::ElementBVal; };
struct FrgTypeB_or_Default { using type = typename X::ValTypeB; };
template <class X>
struct FrgTypeB_or_Default<X,void_t<typename X::ElementBFrg>> { using type = typename X::ElementBFrg; };
struct FrgTypeB_or_Default<X,void_t<typename X::FrgTypeB>> { using type = typename X::FrgTypeB; };
template <class X, class = void>
struct FrgTypeC_or_Default { using type = typename X::ElementCVal; };
struct FrgTypeC_or_Default { using type = typename X::ValTypeC; };
template <class X>
struct FrgTypeC_or_Default<X,void_t<typename X::ElementCFrg>> { using type = typename X::ElementCFrg; };
struct FrgTypeC_or_Default<X,void_t<typename X::FrgTypeC>> { using type = typename X::FrgTypeC; };
} // end namespace detail
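A sketch of how the renamed defaults resolve (MyTraits is hypothetical; standard <type_traits> assumed): a traits type that omits FrgTypeA falls back to its ValTypeA.

#include <type_traits>
struct MyTraits { using ValTypeA = cute::half_t; };
static_assert(std::is_same_v<
    cute::detail::FrgTypeA_or_Default<MyTraits>::type,
    cute::half_t>);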

View File

@ -41,10 +41,10 @@ namespace cute
template <>
struct MMA_Traits<SM61_DP4A>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_1,_1,_4>;
using ThrID = Layout<_1>;
@ -58,10 +58,10 @@ struct MMA_Traits<SM61_DP4A>
template <>
struct MMA_Traits<SM61_DP2A>
{
using ElementDVal = int32_t;
using ElementAVal = int16_t;
using ElementBVal = int16_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int16_t;
using ValTypeB = int16_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_1,_1,_2>;
using ThrID = Layout<_1>;

View File

@ -63,10 +63,10 @@ using SM70_8x8_32b = Layout<Shape <Shape <_2, _2,_2>,Shape <_2,_2, _2>>,
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using ValTypeD = half_t;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -80,10 +80,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TN>
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NT>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using ValTypeD = half_t;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -97,10 +97,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NT>
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using ValTypeD = half_t;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -114,10 +114,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NN>
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TT>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using ValTypeD = half_t;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -131,10 +131,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TT>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -148,10 +148,10 @@ struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TN>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -165,10 +165,10 @@ struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
@ -182,10 +182,10 @@ struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NN>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TT>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;

View File

@ -41,10 +41,10 @@ namespace cute
template <>
struct MMA_Traits<SM75_16x8x8_F32F16F16F32_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
@ -61,10 +61,10 @@ struct MMA_Traits<SM75_16x8x8_F32F16F16F32_TN>
template <>
struct MMA_Traits<SM75_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_8,_8,_16>;
using ThrID = Layout<_32>;

View File

@ -66,10 +66,10 @@ using SM80_16x8_Row = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
template <>
struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using ValTypeD = half_t;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = half_t;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
@ -81,10 +81,10 @@ struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
template <>
struct MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using ValTypeD = half_t;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = half_t;
using Shape_MNK = Shape<_16,_8,_16>;
using ThrID = Layout<_32>;
@ -103,20 +103,20 @@ template <>
struct MMA_Traits<SM80_16x8x8_F32F16F16F32_TN>
: MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
};
template <>
struct MMA_Traits<SM80_16x8x16_F32F16F16F32_TN>
: MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
};
///////////////////////////////////////////////////////////////////////////////
@ -127,20 +127,20 @@ template <>
struct MMA_Traits<SM80_16x8x8_F32BF16BF16F32_TN>
: MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = bfloat16_t;
using ElementBVal = bfloat16_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
};
template <>
struct MMA_Traits<SM80_16x8x16_F32BF16BF16F32_TN>
: MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = bfloat16_t;
using ElementBVal = bfloat16_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
};
///////////////////////////////////////////////////////////////////////////////
@ -150,10 +150,10 @@ struct MMA_Traits<SM80_16x8x16_F32BF16BF16F32_TN>
template <>
struct MMA_Traits<SM80_16x8x4_F32TF32TF32F32_TN>
{
using ElementDVal = float;
using ElementAVal = cutlass::tfloat32_t;
using ElementBVal = cutlass::tfloat32_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = cutlass::tfloat32_t;
using ValTypeB = cutlass::tfloat32_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16,_8,_4>;
using ThrID = Layout<_32>;
@ -166,10 +166,10 @@ struct MMA_Traits<SM80_16x8x4_F32TF32TF32F32_TN>
template <>
struct MMA_Traits<SM80_16x8x8_F32TF32TF32F32_TN>
{
using ElementDVal = float;
using ElementAVal = cutlass::tfloat32_t;
using ElementBVal = cutlass::tfloat32_t;
using ElementCVal = float;
using ValTypeD = float;
using ValTypeA = cutlass::tfloat32_t;
using ValTypeB = cutlass::tfloat32_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
@ -187,10 +187,10 @@ struct MMA_Traits<SM80_16x8x8_F32TF32TF32F32_TN>
template <>
struct MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using ValTypeD = double;
using ValTypeA = double;
using ValTypeB = double;
using ValTypeC = double;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = Layout<_32>;
@ -204,10 +204,10 @@ template <>
struct MMA_Traits<SM80_8x8x4_C64C64C64C64_TN>
: MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
using ValTypeD = complex<double>;
using ValTypeA = complex<double>;
using ValTypeB = complex<double>;
using ValTypeC = complex<double>;
};
// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts
@ -215,10 +215,10 @@ template <>
struct MMA_Traits<SM80_8x8x4_GC64C64C64GC64_TN>
: MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
using ValTypeA = complex<double>;
using ValTypeB = complex<double>;
using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
};
///////////////////////////////////////////////////////////////////////////////
@ -228,10 +228,10 @@ struct MMA_Traits<SM80_8x8x4_GC64C64C64GC64_TN>
template <>
struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_8,_8,_16>;
using ThrID = Layout<_32>;
@ -247,10 +247,10 @@ struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN_SATURATE>
template <>
struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16,_8,_16>;
using ThrID = Layout<_32>;
@ -267,10 +267,10 @@ struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN_SATURATE>
template <>
struct MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16,_8,_32>;
using ThrID = Layout<_32>;
@ -293,10 +293,10 @@ template <>
struct MMA_Traits<SM80_8x8x16_S32S8U8S32_TN>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = uint8_t;
using ValTypeC = int32_t;
};
template <>
@ -307,10 +307,10 @@ template <>
struct MMA_Traits<SM80_16x8x16_S32S8U8S32_TN>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = uint8_t;
using ValTypeC = int32_t;
};
template <>
@ -321,10 +321,10 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32S8U8S32_TN>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = int8_t;
using ValTypeB = uint8_t;
using ValTypeC = int32_t;
};
template <>
@ -339,10 +339,10 @@ template <>
struct MMA_Traits<SM80_8x8x16_S32U8S8S32_TN>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = uint8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
};
template <>
@ -353,10 +353,10 @@ template <>
struct MMA_Traits<SM80_16x8x16_S32U8S8S32_TN>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = uint8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
};
template <>
@ -367,10 +367,10 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32U8S8S32_TN>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = uint8_t;
using ValTypeB = int8_t;
using ValTypeC = int32_t;
};
template <>
@ -385,10 +385,10 @@ template <>
struct MMA_Traits<SM80_8x8x16_S32U8U8S32_TN>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = uint8_t;
using ValTypeB = uint8_t;
using ValTypeC = int32_t;
};
template <>
@ -399,10 +399,10 @@ template <>
struct MMA_Traits<SM80_16x8x16_S32U8U8S32_TN>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = uint8_t;
using ValTypeB = uint8_t;
using ValTypeC = int32_t;
};
template <>
@ -413,10 +413,10 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = uint8_t;
using ValTypeB = uint8_t;
using ValTypeC = int32_t;
};
template <>
@ -430,10 +430,10 @@ struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN_SATURATE>
template <>
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
{
using ElementDVal = int32_t;
using ElementAVal = cute::uint1b_t;
using ElementBVal = cute::uint1b_t;
using ElementCVal = int32_t;
using ValTypeD = int32_t;
using ValTypeA = cute::uint1b_t;
using ValTypeB = cute::uint1b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16,_8,_256>;
using ThrID = Layout<_32>;

View File

@ -44,10 +44,10 @@ namespace cute {
template <>
struct MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using ValTypeD = double;
using ValTypeA = double;
using ValTypeB = double;
using ValTypeC = double;
using Shape_MNK = Shape<_16,_8,_4>;
using ThrID = Layout<_32>;
@ -62,10 +62,10 @@ struct MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
template <>
struct MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using ValTypeD = double;
using ValTypeA = double;
using ValTypeB = double;
using ValTypeC = double;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
@ -80,10 +80,10 @@ struct MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
template <>
struct MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using ValTypeD = double;
using ValTypeA = double;
using ValTypeB = double;
using ValTypeC = double;
using Shape_MNK = Shape<_16,_8,_16>;
using ThrID = Layout<_32>;
@ -103,30 +103,30 @@ template <>
struct MMA_Traits<SM90_16x8x4_C64C64C64C64_TN>
: MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
using ValTypeD = complex<double>;
using ValTypeA = complex<double>;
using ValTypeB = complex<double>;
using ValTypeC = complex<double>;
};
template <>
struct MMA_Traits<SM90_16x8x8_C64C64C64C64_TN>
: MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
using ValTypeD = complex<double>;
using ValTypeA = complex<double>;
using ValTypeB = complex<double>;
using ValTypeC = complex<double>;
};
template <>
struct MMA_Traits<SM90_16x8x16_C64C64C64C64_TN>
: MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
using ValTypeD = complex<double>;
using ValTypeA = complex<double>;
using ValTypeB = complex<double>;
using ValTypeC = complex<double>;
};
} // end namespace cute

File diff suppressed because it is too large

View File

@ -479,8 +479,9 @@ weakly_congruent(IntTupleA const& a, IntTupleB const& b)
template <class A, class B>
using is_weakly_congruent = decltype(weakly_congruent(declval<A>(), declval<B>()));
/** Test if Shape B is compatible with Shape A:
* Any coordinate into A can also be used as a coordinate into B
/** Test if Shape A is compatible with Shape B:
* the sizes of A and B are the same, and
* any coordinate into A can also be used as a coordinate into B
* compatible is a partial order on A and B: A <= B
*/
template <class IntTupleA, class IntTupleB>
@ -509,8 +510,8 @@ compatible(IntTupleA const& a, IntTupleB const& b)
template <class A, class B>
using is_compatible = decltype(compatible(declval<A>(), declval<B>()));
/** Test if Shape B is weakly compatible with Shape A:
* Shape B is a multiple of a shape that is compatible with Shape A
/** Test if Shape A is weakly compatible with Shape B:
* there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B)
* weakly_compatible is a partial order on A and B: A <= B
*/
template <class IntTupleA, class IntTupleB>
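Small static examples of the two orderings as now documented (assumed standard CuTe semantics, not from the diff):

using namespace cute;
// compatible: equal sizes, and 1-D coordinates into _12 also index (3,4).
static_assert(compatible(Int<12>{}, make_shape(Int<3>{}, Int<4>{})));
// Not symmetric: a 2-D coordinate cannot index a scalar shape.
static_assert(!compatible(make_shape(Int<3>{}, Int<4>{}), Int<12>{}));
// weakly_compatible: (8,2) has size 16, a multiple of 4.
static_assert(weakly_compatible(Int<4>{}, make_shape(Int<8>{}, Int<2>{})));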

View File

@ -36,6 +36,8 @@
#include <cute/int_tuple.hpp>
#include <cute/stride.hpp>
#include <cute/numeric/arithmetic_tuple.hpp>
#include <cute/numeric/integral_ratio.hpp>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
@ -167,16 +169,6 @@ struct Layout
return operator()(make_coord(c0,c1,cs...));
}
// Map a linear index to a hier ND logical coordinate
// NOTE: Dangerous and error-prone
template <class Int>
CUTE_HOST_DEVICE constexpr
auto
operator[](Int const& linear_idx) const {
static_assert(is_integral<Int>::value);
return get_hier_coord(linear_idx);
}
//
// Compose
//
@ -305,11 +297,24 @@ struct Layout
#endif
};
// Equality, return a static or dynamic boolean
template <class ShapeA, class StrideA,
class ShapeB, class StrideB>
CUTE_HOST_DEVICE constexpr
auto
operator==(Layout<ShapeA,StrideA> const& layoutA, Layout<ShapeB,StrideB> const& layoutB)
{
return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride();
}
template <class Layout>
struct is_layout : false_type {};
template <class Shape, class Stride>
struct is_layout<Layout<Shape,Stride>> : true_type {};
//
// Layout construction
//
template <class Shape, class Stride,
__CUTE_REQUIRES((is_tuple<Shape >::value || is_integral<Shape >::value) &&
@ -446,51 +451,59 @@ make_identity_layout(Shape const& shape)
// Operations to manipulate Layouts like a tuple of pairs
//
// Return the Is...th sublayout.
// For Is... = <I0,I1,...,IN>, equivalent to get<IN>(...get<I1>(get<I0>(layout)))
template <size_t... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
get(Layout<Shape,Stride> const& layout)
{
// Let the static_asserts in get<I>(shape|stride) catch problems
return make_layout(get<Is...>(layout.shape()), get<Is...>(layout.stride()));
return make_layout(get<Is...>(layout.shape()),
get<Is...>(layout.stride()));
}
// Return a new layout with only the modes in the range [B,E)
template <int B, int E, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
take(Layout<Shape,Stride> const& layout)
{
// Let the static_asserts in take<B,E>(shape|stride) catch problems
return make_layout(take<B,E>(layout.shape()), take<B,E>(layout.stride()));
static_assert(B < E, "take: empty range error");
static_assert(0 <= B && E <= Layout<Shape,Stride>::rank, "take: range out of bounds");
return make_layout(take<B,E>(layout.shape()),
take<B,E>(layout.stride()));
}
//
// Select layout modes according to an index sequence.
//
template <int... I, class Shape, class Stride>
// Return a new layout with only the modes Is... = <I0,I1,...,IN>
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
select(Layout<Shape,Stride> const& layout)
{
return make_layout(select<I...>(layout.shape()),
select<I...>(layout.stride()));
return make_layout(select<Is...>(layout.shape()),
select<Is...>(layout.stride()));
}
// Return a layout with depth at most 1
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
flatten(Layout<Shape,Stride> const& layout)
{
return make_layout(flatten(layout.shape()), flatten(layout.stride()));
return make_layout(flatten(layout.shape()),
flatten(layout.stride()));
}
// Return a layout whose profile is congruent to TargetProfile
// @pre Input layout is flat, flatten(@a layout) == @a layout
// @pre Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile))
// @post congruent(@a result, @a target_profile)
template <class Shape, class Stride, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(Layout<Shape,Stride> const& layout, TargetProfile const& target_profile)
{
return make_layout(unflatten(layout.shape(), target_profile),
return make_layout(unflatten(layout.shape(), target_profile),
unflatten(layout.stride(), target_profile));
}
@ -498,7 +511,7 @@ unflatten(Layout<Shape,Stride> const& layout, TargetProfile const& target_profil
// Utilities
//
// Return the layout of a mode
// Return the sublayout of mode I...
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
@ -609,17 +622,6 @@ using cosize_t = decltype(cosize(declval<Layout>()));
template <class Layout>
static constexpr int cosize_v = cosize_t<Layout>::value;
// Equality
// Return a static or dynamic boolean
template <class ShapeA, class StrideA,
class ShapeB, class StrideB>
CUTE_HOST_DEVICE constexpr
auto
operator==(Layout<ShapeA,StrideA> const& layoutA, Layout<ShapeB,StrideB> const& layoutB)
{
return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride();
}
// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
@ -762,8 +764,11 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride,
} // end namespace detail
// Combine all the modes that are possible to combine
// Does not respect the profile of the layout, but does preserve total size
// "Simplify" the layout by combining modes that are possible to combine
// Does not respect the shape of the layout, but does preserve total size
// @post size(@a result) == size(@a layout)
// @post depth(@a result) <= 1
// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i)
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@ -894,7 +899,7 @@ group(Layout<Shape,Stride> const& layout)
// Composition of two layouts: lhs o rhs
// @post compatible(rhs, result)
// @post result(c) = lhs(rhs(c))
// for all c in the domain of result
// for all c in the domain of rhs
//
namespace detail {
@ -984,19 +989,19 @@ composition(Layout<LShape,LStride> const& lhs,
return detail::composition_impl(lhs, rhs.shape(), rhs.stride());
}
template <class LShape, class LStride, class IntTuple>
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
composition(Layout<LShape,LStride> const& lhs,
IntTuple const& rhs)
Tiler const& rhs)
{
if constexpr (is_tuple<IntTuple>::value) {
static_assert(tuple_size<IntTuple>::value <= Layout<LShape,LStride>::rank);
if constexpr (is_tuple<Tiler>::value) {
static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank);
// Drop any modes of lhs that aren't hit by rhs
return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq<tuple_size<IntTuple>::value>{}, seq<>{}, seq<>{});
} else if constexpr (is_underscore<IntTuple>::value) {
return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq<tuple_size<Tiler>::value>{}, seq<>{}, seq<>{});
} else if constexpr (is_underscore<Tiler>::value) {
return lhs;
} else if constexpr (is_integral<IntTuple>::value) {
} else if constexpr (is_integral<Tiler>::value) {
return detail::composition_impl(lhs, rhs, Int<1>{});
}
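A worked instance of the corrected post-condition (values from the standard CuTe composition example, not from this diff):

using namespace cute;
auto A = make_layout(make_shape(Int<6>{}, Int<2>{}), make_stride(Int<8>{}, Int<2>{}));
auto B = make_layout(make_shape(Int<4>{}, Int<3>{}), make_stride(Int<3>{}, Int<1>{}));
auto R = composition(A, B);
// For every coordinate c in the domain of B (not of R): R(c) == A(B(c))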
@ -1041,19 +1046,25 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi)
auto [shape, stride, result_shape, result_stride] = init;
auto min_stride = cute::min(stride);
auto min_idx = find(stride, min_stride);
auto new_shape = min_stride / get<i>(result_stride);
auto new_stride = get<min_idx>(shape) * min_stride;
static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement.");
return cute::make_tuple(remove<min_idx>(shape), // Remove the min_idx from shape
remove<min_idx>(stride), // Remove the min_idx from stride
append(result_shape , min_stride / get<i>(result_stride)), // new shape = min_stride / last_stride
append(result_stride, get<min_idx>(shape) * min_stride)); // new stride = curr_shape * min_stride
return cute::make_tuple(remove<min_idx>(shape), // Remove the min_idx from shape
remove<min_idx>(stride), // Remove the min_idx from stride
append(result_shape , new_shape ), // new shape = min_stride / last_stride
append(result_stride, new_stride)); // new stride = curr_shape * min_stride
});
// Append the last shape mode
auto result_shape = append(result_shape_, get<0>(stride_) / get<R-1>(result_stride)); // new shape = min_stride / last_stride
auto new_shape = get<0>(stride_) / get<R-1>(result_stride);
static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement.");
auto result_shape = append(result_shape_, new_shape); // new shape = min_stride / last_stride
// Compute the rest_shape and rest_stride
auto rest_stride = get<0>(shape_) * get<0>(stride_);
auto rest_shape = ceil_div(cosize_hi, rest_stride);
// Jump into coalesce and append (rest_shape, rest_stride)
return detail::bw_coalesce<R-1>(result_shape, result_stride, rest_shape, rest_stride);
}
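An illustrative complement (values assumed, standard semantics), for which the new non-injectivity static_asserts stay silent:

using namespace cute;
auto c = cute::complement(Layout<_4,_1>{}, Int<24>{});  // 6:4
// 4:1 covers {0,1,2,3}; its complement 6:4 covers {0,4,8,...,20},
// so together they tile [0,24) exactly once.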
@ -1323,14 +1334,14 @@ zip(Layout<TShape,TStride> const& layoutA,
// their own mode.
//
template <class LShape, class LStride, class IntTuple>
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tile_unzip(Layout<LShape,LStride> const& layout,
IntTuple const& tile)
Tiler const& tiler)
{
return make_layout(zip2_by(layout.shape(), tile),
zip2_by(layout.stride(), tile));
return make_layout(zip2_by(layout.shape(), tiler),
zip2_by(layout.stride(), tiler));
}
//
@ -1389,10 +1400,10 @@ auto
tiled_divide(Layout<LShape,LStride> const& layout,
Tiler const& tiler)
{
auto div = zipped_divide(layout, tiler);
auto result = zipped_divide(layout, tiler);
auto R = rank<1>(div);
return div(_, repeat<R>(_));
auto R1 = rank<1>(result);
return result(_, repeat<R1>(_));
}
// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y)
@ -1403,40 +1414,41 @@ auto
flat_divide(Layout<LShape,LStride> const& layout,
Tiler const& tiler)
{
auto div = zipped_divide(layout, tiler);
auto result = zipped_divide(layout, tiler);
auto R0 = rank<0>(div);
auto R1 = rank<1>(div);
return div(repeat<R0>(_), repeat<R1>(_));
auto R0 = rank<0>(result);
auto R1 = rank<1>(result);
return result(repeat<R0>(_), repeat<R1>(_));
}
//
// Logical product
//
// @post compatible()
template <class LShape, class LStride,
class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<LShape,LStride> const& layout,
logical_product(Layout<LShape,LStride> const& block,
Layout<TShape,TStride> const& tiler)
{
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler));
}
template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<LShape,LStride> const& layout,
logical_product(Layout<LShape,LStride> const& block,
Tiler const& tiler)
{
if constexpr (is_tuple<Tiler>::value) {
static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_product: Too many modes in tiler.");
return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); });
return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); });
} else if constexpr (is_underscore<Tiler>::value) {
return layout;
return block;
} else if constexpr (is_integral<Tiler>::value) {
return logical_product(layout, make_layout(tiler));
return logical_product(block, make_layout(tiler));
}
CUTE_GCC_UNREACHABLE;
@ -1452,10 +1464,10 @@ template <class LShape, class LStride,
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_product(Layout<LShape,LStride> const& layout,
zipped_product(Layout<LShape,LStride> const& block,
Tiler const& tiler)
{
return tile_unzip(logical_product(layout, tiler), tiler);
return tile_unzip(logical_product(block, tiler), tiler);
}
// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y)
@ -1463,69 +1475,107 @@ template <class LShape, class LStride,
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_product(Layout<LShape,LStride> const& layout,
tiled_product(Layout<LShape,LStride> const& block,
Tiler const& tiler)
{
auto div = zipped_product(layout, tiler);
auto result = zipped_product(block, tiler);
auto R = rank<1>(div);
return div(_, repeat<R>(_));
auto R1 = rank<1>(result);
return result(_, repeat<R1>(_));
}
// Attempts to reproduce a layout over a tiler
// That is, think of every element of "tiler" as a "layout"
// and return the layout of the resulting structure
// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y)
template <class LShape, class LStride,
class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_product(Layout<LShape,LStride> const& block,
Tiler const& tiler)
{
auto result = zipped_product(block, tiler);
auto R0 = rank<0>(result);
auto R1 = rank<1>(result);
return result(repeat<R0>(_), repeat<R1>(_));
}
//
// Rank-sensitive products
//
// blocked_product -- Reproduce a block over a tiler.
// Think of every element of "tiler" as a "block"
// and return the layout of the resulting structure.
// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler))
template <class TShape, class TStride,
class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
blocked_product(Layout<TShape,TStride> const& layout,
blocked_product(Layout<TShape,TStride> const& block,
Layout<UShape,UStride> const& tiler)
{
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
auto result = logical_product(append<R>(layout), append<R>(tiler));
auto result = logical_product(append<R>(block), append<R>(tiler));
return coalesce(zip(get<0>(result), get<1>(result)), repeat<R>(Int<1>{}));
return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat<R>(Int<1>{}));
}
// raked_product -- Reproduce a block over a tiler with block-interleaving.
// Think of every element of "tiler" as a "block", interleave those blocks,
// and return the layout of the resulting structure.
// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler))
template <class TShape, class TStride,
class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
raked_product(Layout<TShape,TStride> const& layout,
raked_product(Layout<TShape,TStride> const& block,
Layout<UShape,UStride> const& tiler)
{
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
auto result = logical_product(append<R>(layout), append<R>(tiler));
auto result = logical_product(append<R>(block), append<R>(tiler));
return coalesce(zip(get<1>(result), get<0>(result)), repeat<R>(Int<1>{}));
return coalesce(zip(get<1>(result), get<0>(result)), tuple_repeat<R>(Int<1>{}));
}
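A sketch contrasting the two rank-sensitive products (standard semantics; shapes assumed): replicating a (2,2) block over a (3,4) tiler.

using namespace cute;
auto block = Layout<Shape<_2,_2>>{};
auto tiler = Layout<Shape<_3,_4>>{};
auto b = blocked_product(block, tiler);  // rank-2, roughly ((2,3),(2,4)): blocks stay contiguous
auto r = raked_product(block, tiler);    // rank-2, roughly ((3,2),(4,2)): blocks interleaved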
// tile_to_shape -- Perform a product of a layout so that the result matches a target shape.
// This is similar to blocked_product, but specifies the result shape instead of the
// product shape, which is more convenient in certain circumstances.
// @param block The layout to repeat
// @param trg_shape The target shape of the result
// @param ord_shape The order of the modes of @a trg_shape to tile @a block with.
// Defaults to LayoutLeft (column-major), so @a block will repeat
// across the first mode first, the second mode second, etc.
// E.g. Step<_2,_1,_3> will cause @a block to repeat
// across the second mode first, the first mode second, and the third mode last.
// @pre rank(@a block) <= rank(@a trg_shape)
// @post compatible(@a trg_shape, shape(@a result))
template <class Shape, class Stride,
class TrgShape, class ModeOrder = GenColMajor>
class TrgShape, class ModeOrder = LayoutLeft>
CUTE_HOST_DEVICE constexpr
auto
tile_to_shape(Layout<Shape,Stride> const& layout,
tile_to_shape(Layout<Shape,Stride> const& block,
TrgShape const& trg_shape,
ModeOrder const& ord_shape = {})
{
CUTE_STATIC_ASSERT_V(rank(layout) <= rank(trg_shape), "Rank of layout must be <= rank of target shape.");
CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of block must be <= rank of target shape.");
constexpr int R = rank_v<TrgShape>;
auto padded_layout = append<R>(layout);
auto padded_block = append<R>(block);
auto layout_shape = product_each(padded_layout.shape());
auto target_shape = product_each(trg_shape);
auto block_shape = product_each(shape(padded_block));
auto target_shape = product_each(shape(trg_shape));
// Assert proper division
CUTE_STATIC_ASSERT_V(sum(transform(target_shape, layout_shape, modulus{})) == Int<0>{},
"Layout shape does not divide the target shape.");
if constexpr (is_static<decltype(target_shape)>::value) {
CUTE_STATIC_ASSERT_V(weakly_compatible(block_shape, target_shape),
"tile_to_shape: block shape does not divide the target shape.");
}
auto product_shape = shape_div(target_shape, layout_shape);
auto product_shape = ceil_div(target_shape, block_shape);
return coalesce(blocked_product(padded_layout, make_ordered_layout(product_shape, ord_shape)), product_shape);
return coalesce(blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)), product_shape);
}
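A sketch matching the new divisibility check (assumed semantics): a (2,2) block tiled to a (6,8) target repeats 3x along the first mode and 4x along the second, in LayoutLeft order by default.

using namespace cute;
auto result = tile_to_shape(Layout<Shape<_2,_2>>{}, Shape<_6,_8>{});
// compatible(Shape<_6,_8>{}, shape(result)); product_shape == ceil_div((6,8),(2,2)) == (3,4)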
//
@ -1602,15 +1652,20 @@ CUTE_HOST_DEVICE constexpr
auto
recast_layout(Layout<Shape,Stride> const& layout)
{
if constexpr (sizeof_bits<NewType>::value == sizeof_bits<OldType>::value) {
using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
if constexpr (scale::num == 1 && scale::den == 1) {
return layout;
} else if constexpr (sizeof_bits<NewType>::value > sizeof_bits<OldType>::value) {
static_assert(sizeof_bits<NewType>::value % sizeof_bits<OldType>::value == 0, "NewType must be a multiple of OldType");
return upcast<sizeof_bits<NewType>::value/sizeof_bits<OldType>::value>(layout);
} else if constexpr (sizeof_bits<NewType>::value < sizeof_bits<OldType>::value) {
static_assert(sizeof_bits<OldType>::value % sizeof_bits<NewType>::value == 0, "NewType must be a divisor of OldType");
return downcast<sizeof_bits<OldType>::value/sizeof_bits<NewType>::value>(layout);
}
else if constexpr (scale::num == 1) {
return downcast<scale::den>(layout);
}
else if constexpr (scale::den == 1) {
return upcast<scale::num>(layout);
}
else {
static_assert(dependent_false<scale>, "Recast not supported.");
}
CUTE_GCC_UNREACHABLE;
}
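A sketch of the new ratio-based dispatch (template argument order <OldType, NewType> assumed from the Copy_Atom call site above; observable behavior matches the old size-based branches):

using namespace cute;
auto l32  = Layout<Shape<_4,_8>>{};             // float-granularity layout
auto l16  = recast_layout<float, half_t>(l32);  // scale 1/2 -> downcast<2>
auto back = recast_layout<half_t, float>(l16);  // scale 2/1 -> upcast<2>, == l32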
@ -1693,12 +1748,13 @@ print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) a
}
// Generic 2D Layout to Latex printer -- B&W 8-value color coding
template <class Layout>
template <class LayoutA>
CUTE_HOST_DEVICE
void
print_latex(Layout const& layout) // (m,n) -> idx
print_latex(LayoutA const& layout_a)
{
CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
auto layout = append<2>(layout_a, Layout<_1,_0>{});
char const* latex_header =
"\\documentclass[convert]{standalone}\n"
@ -1727,7 +1783,6 @@ print_latex(Layout const& layout) // (m,n) -> idx
for (int i = 0; i < size<0>(layout); ++i) {
for (int j = 0; j < size<1>(layout); ++j) {
int idx = layout(i,j);
printf("\\node[box,fill=%s] at (%d,%d) {%d};\n",
color_map[idx % 8],
i, j,

View File

@ -37,7 +37,7 @@
/* This implements a ComposedLayout of the form
* LayoutA o Offset o LayoutB
* and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB.
* For example, then the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB).
* For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB).
*
* This ComposedLayout provides similar functionality to Layout including tiling, partitioning,
* coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout.
@ -357,12 +357,11 @@ composition(LayoutA const& layoutA,
return ComposedLayout<LayoutA, Offset, LayoutB>{layoutA, offset, layoutB};
}
template <class A, class O, class B,
class LayoutOrTile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
composition(ComposedLayout<A,O,B> const& a,
LayoutOrTile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b));
}
@ -433,92 +432,101 @@ zip(ComposedLayout<A,O,B> const& a)
// Partitions
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tile_unzip(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_product(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_product(ComposedLayout<A,O,B> const& a,
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b));
}
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_product(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_product(ComposedLayout<A,O,B> const& a,
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b));
}
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
blocked_product(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b));
}
template <class A, class O, class B,
class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
raked_product(ComposedLayout<A,O,B> const& a,
Tile const& b)
Tiler const& b)
{
return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b));
}
@ -585,16 +593,19 @@ CUTE_HOST_DEVICE constexpr
auto
recast_layout(ComposedLayout<A,O,B> const& layout)
{
if constexpr (sizeof(NewType) == sizeof(OldType)) {
using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
if constexpr (scale::num == 1 && scale::den == 1) {
return layout;
} else if constexpr (sizeof(NewType) > sizeof(OldType)) {
static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType");
return upcast<sizeof(NewType)/sizeof(OldType)>(layout);
} else if constexpr (sizeof(NewType) < sizeof(OldType)) {
static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType");
return downcast<sizeof(OldType)/sizeof(NewType)>(layout);
}
else if constexpr (scale::num == 1) {
return downcast<scale::den>(layout);
}
else if constexpr (scale::den == 1) {
return upcast<scale::num>(layout);
}
else {
static_assert(dependent_false<scale>, "Recast not supported.");
}
CUTE_GCC_UNREACHABLE;
}

View File

@ -413,6 +413,19 @@ conditional_return(TrueType const& t, FalseType const& f) {
}
}
template <class Trait>
CUTE_HOST_DEVICE constexpr
auto
static_value()
{
if constexpr (is_std_integral<decltype(Trait::value)>::value) {
return Int<Trait::value>{};
} else {
return Trait::value;
}
CUTE_GCC_UNREACHABLE;
}
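A minimal illustration of static_value (hedged; it assumes the helper sits in the cute namespace alongside its neighbors): when Trait::value is a built-in integral, the value is lifted into a static Int<> so later ratio arithmetic stays compile-time; otherwise Trait::value is returned unchanged.

// Sketch: sizeof_bits<int8_t>::value == 8 is std-integral,
// so this returns cute::Int<8>{} rather than a runtime int.
auto bits = cute::static_value<cute::sizeof_bits<int8_t>>();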
//
// Display utilities
//

View File

@ -65,6 +65,11 @@ class R {
using type = typename conditional<num == 0 || den == 1, C<num>, R<num,den>>::type;
};
template <class T>
struct is_ratio : false_type {};
template <auto n, auto d>
struct is_ratio<R<n,d>> : true_type {};
template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<a,b>::type
@ -72,6 +77,59 @@ ratio(C<a>, C<b>) {
return {};
}
template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
typename R<a*c,b>::type
ratio(C<a>, R<b,c>) {
return {};
}
template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
typename R<b,a*c>::type
ratio(R<b,c>, C<a>) {
return {};
}
template <auto a, auto b, auto c, auto d>
CUTE_HOST_DEVICE constexpr
typename R<a*d,b*c>::type
ratio(R<a,b>, R<c,d>) {
return {};
}
//
// Non-reduced ratio implementations
//
template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
R<a,b>
nratio(C<a>, C<b>) {
return {};
}
template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
R<a*c,b>
nratio(C<a>, R<b,c>) {
return {};
}
template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
R<b,a*c>
nratio(R<b,c>, C<a>) {
return {};
}
template <auto a, auto b, auto c, auto d>
CUTE_HOST_DEVICE constexpr
R<a*d,b*c>
nratio(R<a,b>, R<c,d>) {
return {};
}
template <auto a, auto b, auto x, auto y>
CUTE_HOST_DEVICE constexpr
typename R<a*x,b*y>::type
@ -93,6 +151,13 @@ operator*(C<c>, R<a,b>) {
return {};
}
template <auto c, auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<c*b,a>::type
operator/(C<c>, R<a,b>) {
return {};
}
// Product with dynamic type needs to produce an integer...
template <class C, auto a, auto b,
__CUTE_REQUIRES(cute::is_std_integral<C>::value)>
@ -160,6 +225,23 @@ abs(R<a,b>) {
return {};
}
template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
auto
log_2(R<a,b>) {
static_assert(R<a,b>::num > 0);
static_assert(R<a,b>::den > 0);
return log_2(static_cast<uint32_t>(R<a,b>::num)) - log_2(static_cast<uint32_t>(R<a,b>::den));
}
template <class Trait0, class Trait1>
CUTE_HOST_DEVICE constexpr
auto
trait_ratio(Trait0, Trait1) {
return nratio(static_value<Trait0>(), static_value<Trait1>());
}
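Tying these pieces together (a sketch using only names from this diff): nratio keeps the template arguments unreduced, but R's num/den members still reduce, which is what lets recast_layout test scale::num and scale::den directly.

// Sketch: half (16 bits) vs float (32 bits).
using scale = decltype(cute::trait_ratio(cute::sizeof_bits<cute::half_t>{},
                                         cute::sizeof_bits<float>{}));
static_assert(scale::num == 1 && scale::den == 2); // -> downcast<2> in recast_layout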
//
// Display utilities
//

View File

@ -310,4 +310,17 @@ safe_div(T const& t, U const& u) {
return t / u;
}
/**
* log2 computation
*/
template <class T>
CUTE_HOST_DEVICE constexpr
auto
log_2(T x) {
assert(x > 0);
static_assert(is_unsigned<T>::value, "Only to be used for unsigned integral types.");
return bit_width(x) - 1;
}
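A one-line check of the new helper (sketch): bit_width(64u) is 7, so log_2(64u) is 6; combined with the ratio overload above, log_2 of R<16,128> (whose members reduce to 1/8) is log_2(1) - log_2(8) = -3.

// Sketch: exact powers of two are the intended inputs here.
auto e = cute::log_2(uint32_t(64)); // bit_width(64) - 1 == 6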
} // namespace cute

View File

@ -41,6 +41,7 @@
#include <cute/pointer_base.hpp>
#include <cute/pointer_swizzle.hpp>
#include <cute/layout.hpp>
namespace cute
{

View File

@ -227,7 +227,7 @@ raw_pointer_cast(counting_iterator<T> const& x) {
template <class T>
CUTE_HOST_DEVICE void print(T const* const ptr)
{
printf("ptr[%db](%p)", int(sizeof_bits<T>::value), ptr);
printf("ptr["); print(sizeof_bits<T>::value); printf("b](%p)", ptr);
}
template <class T>

View File

@ -37,7 +37,8 @@
namespace cute
{
/** crd2idx maps a coordinate within <Shape,Stride> to an index
/** crd2idx(c,s,d) maps a coordinate within <Shape,Stride> to an index
*
* This is computed as follows:
* [coord, shape, and stride are all integers => step forward by stride]
* op(c, s, d) => c * d
@ -46,7 +47,6 @@ namespace cute
* [coord, shape, and stride are all tuples => consider each mode independently]
* op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D))
*/
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@ -115,10 +115,6 @@ crd2idx(Coord const& coord,
CUTE_GCC_UNREACHABLE;
}
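A tiny worked instance of the op rules above (sketch; values invented): with coord (1,2), shape (3,4), stride (1,3), each mode contributes c * d, so the index is 1*1 + 2*3 = 7.

// Sketch: per-mode c * d, summed across modes.
auto idx = cute::crd2idx(cute::make_coord(1,2),
                         cute::make_shape(3,4),
                         cute::make_stride(1,3)); // 7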
//
// If we know Stride is default [CompactColMajor], then we can take shortcuts
//
namespace detail {
template <class CTuple, class STuple, int I0, int... Is>
@ -138,26 +134,31 @@ crd2idx_horner(CTuple const& coord,
} // end namespace detail
/** crd2idx(c,s) maps a coordinate within Shape to an index
* via a colexicographical enumeration of coordinates in Shape.
* i = c0 + s0 * (c1 + s1 * (c2 + s2 * ...))
*/
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& coord,
Shape const& shape)
{
static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks");
if constexpr (is_tuple<Shape>::value) {
// Flatten and apply Horner's method
auto flat_coord = flatten(coord);
auto flat_shape = flatten(shape);
return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
} else {
if constexpr (is_integral<Coord>::value) { // Coord is already an index
return coord;
} else if constexpr (is_integral<Shape>::value) {
static_assert(dependent_false<Shape>, "Invalid parameters");
} else { // Make congruent, flatten, and apply Horner's method
static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched Ranks");
auto flat_coord = flatten(coord);
auto flat_shape = flatten(product_like(shape, coord));
return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
}
CUTE_GCC_UNREACHABLE;
}
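The colexicographic formula above is easy to sanity-check (sketch): for shape (3,4,5), coordinate (1,2,3) maps to 1 + 3*(2 + 4*3) = 43.

// Sketch: default column-major enumeration, no strides supplied.
auto i = cute::crd2idx(cute::make_coord(1,2,3), cute::make_shape(3,4,5)); // 43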
/** idx2crd splits an index to a coordinate within <Shape,Stride>.
/** idx2crd(i,s,d) splits an index into a coordinate within <Shape,Stride>.
*
* This is computed as follows:
* [index, shape, and stride are all integers => determine 1D coord]
@ -170,7 +171,6 @@ crd2idx(Coord const& coord,
* NOTE: This only works for compact shape+stride layouts. A more general version would
* apply to all surjective layouts
*/
template <class Index, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@ -207,15 +207,13 @@ idx2crd(Index const& idx,
CUTE_GCC_UNREACHABLE;
}
//
// If we know Stride is default [CompactColMajor], then we can take shortcuts
//
//(idx / 1) % s0
//(idx / s0) % s1
//(idx / (s0 * s1)) % s2
//...
/** idx2crd(i,s) splits an index into a coordinate within Shape
* via a colexicographical enumeration of coordinates in Shape.
* c0 = (idx / 1) % s0
* c1 = (idx / s0) % s1
* c2 = (idx / (s0 * s1)) % s2
* ...
*/
template <class Index, class Shape>
CUTE_HOST_DEVICE constexpr
auto

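And the inverse direction (sketch), following the divisor ladder in the comment above: index 43 in shape (3,4,5) splits as (43/1)%3 = 1, (43/3)%4 = 2, (43/12)%5 = 3, recovering the coordinate from the crd2idx example.

// Sketch: inverse of the Horner example above.
auto c = cute::idx2crd(43, cute::make_shape(3,4,5)); // (1,2,3)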
View File

@ -434,15 +434,20 @@ CUTE_HOST_DEVICE constexpr
auto
recast_layout(Swizzle<B,M,S> const& swizzle)
{
if constexpr (sizeof_bits<NewType>::value == sizeof_bits<OldType>::value) {
using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
if constexpr (scale::num == 1 && scale::den == 1) {
return swizzle;
} else if constexpr (sizeof_bits<NewType>::value > sizeof_bits<OldType>::value) {
static_assert(sizeof_bits<NewType>::value % sizeof_bits<OldType>::value == 0, "NewType must be a multiple of OldType");
return upcast<sizeof_bits<NewType>::value/sizeof_bits<OldType>::value>(swizzle);
} else if constexpr (sizeof_bits<NewType>::value < sizeof_bits<OldType>::value) {
static_assert(sizeof_bits<OldType>::value % sizeof_bits<NewType>::value == 0, "NewType must be a divisor of OldType");
return downcast<sizeof_bits<OldType>::value/sizeof_bits<NewType>::value>(swizzle);
}
else if constexpr (scale::num == 1) {
return downcast<scale::den>(swizzle);
}
else if constexpr (scale::den == 1) {
return upcast<scale::num>(swizzle);
}
else {
static_assert(dependent_false<scale>, "Recast not supported.");
}
CUTE_GCC_UNREACHABLE;
}
//
@ -453,7 +458,7 @@ template <int B, int M, int S, class Offset, class LayoutB, class Shape, class S
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
Layout<Shape,Stride> const& b)
Layout<Shape,Stride> const& b)
{
auto common = max_common_layout(a.layout_b(), b);
auto base = Int<(1 << M)>{};
@ -467,7 +472,7 @@ max_common_layout(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
template <class Shape, class Stride, int B, int M, int S, class Offset, class LayoutB>
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(Layout<Shape,Stride> const& a,
max_common_layout(Layout<Shape,Stride> const& a,
ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& b)
{
return max_common_layout(b, a);
@ -477,7 +482,7 @@ template <int B, int M, int S, class Offset, class LayoutB, class Shape, class S
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
Layout<Shape,Stride> const& b)
Layout<Shape,Stride> const& b)
{
// This assumes that Offset is in the YZ domain of the Swizzle...
return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_b(), b));
@ -486,7 +491,7 @@ max_common_vector(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
template <class Shape, class Stride, int B, int M, int S, class Offset, class LayoutB>
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(Layout<Shape,Stride> const& a,
max_common_vector(Layout<Shape,Stride> const& a,
ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& b)
{
return max_common_vector(b, a);
@ -517,13 +522,13 @@ template <class Shape, class Stride,
int B, int M, int S, class Offset, class LayoutT>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<Shape,Stride> const& block,
ComposedLayout<Swizzle<B,M,S>,Offset,LayoutT> const& tile)
logical_product(Layout<Shape,Stride> const& layout,
ComposedLayout<Swizzle<B,M,S>,Offset,LayoutT> const& tiler)
{
CUTE_STATIC_ASSERT_V(tile.offset() == Int<0>{}, "Require Swizzle offset == 0.");
CUTE_STATIC_ASSERT_V(tiler.offset() == Int<0>{}, "Require Swizzle offset == 0.");
// The new layout -- if swizzle wasn't an issue, this is the result
// our goal is to determine a new swizzle for these strides
auto new_layout = logical_product(block, tile.layout_b());
auto new_layout = logical_product(layout, tiler.layout_b());
// This is accomplished by identifying
// S o L :=: S? o L*
@ -536,8 +541,8 @@ logical_product(Layout<Shape,Stride> const& block,
auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}),
make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{}));
// Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L]
auto layout_only_zy = composition(swizzle_only_zy, tile.layout_b());
// Compose with the tiler to get the swizzle projection, P o L [The Z and Y contributing portions of L]
auto layout_only_zy = composition(swizzle_only_zy, tiler.layout_b());
// Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*)
auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{});
// Get the Z bit and the Y bits
@ -545,8 +550,8 @@ logical_product(Layout<Shape,Stride> const& block,
auto active_Y = swizzle_active_bits & typename Swizzle<B,M,S>::yyy_msk{};
// Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)])
auto new_active_Z = new_layout(Int<0>{}, tile.layout_b()[active_Z]);
auto new_active_Y = new_layout(Int<0>{}, tile.layout_b()[active_Y]);
auto new_active_Z = new_layout(Int<0>{}, tiler.layout_b()[active_Z]);
auto new_active_Y = new_layout(Int<0>{}, tiler.layout_b()[active_Y]);
// Use this new swizzle identifier to construct the new swizzle for new_layout
// (this also makes sure it's a "valid" swizzle that Swizzle can represent)

View File

@ -127,6 +127,18 @@ print(unsigned long long a) {
printf("%llu", a);
}
CUTE_HOST_DEVICE
void
print(float a) {
printf("%f", a);
}
CUTE_HOST_DEVICE
void
print(double a) {
printf("%f", a);
}
template <class... T>
CUTE_HOST_DEVICE
void

View File

@ -236,7 +236,7 @@ public:
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.init.shared.b64 [%1], %0; \n"
"mbarrier.init.shared::cta.b64 [%1], %0; \n"
"}"
:
: "r"(arrive_count), "r"(smem_addr));
@ -256,7 +256,7 @@ public:
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
@ -280,7 +280,7 @@ public:
".reg .pred P1; \n\t"
".reg .pred P2; \n\t"
"setp.eq.u32 P2, %3, 1;\n\t"
"@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
@ -302,7 +302,7 @@ public:
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
@ -342,7 +342,7 @@ public:
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.arrive.shared.b64 _, [%0];\n\t"
"mbarrier.arrive.shared::cta.b64 _, [%0];\n\t"
"}"
:
: "r"(smem_addr));
@ -357,7 +357,7 @@ public:
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.ival.shared.b64 [%0]; \n\t"
"mbarrier.ival.shared::cta.b64 [%0]; \n\t"
"}"
:
: "r"(smem_addr));
@ -418,7 +418,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t"
"mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t"
"}"
:
: "r"(transaction_bytes), "r"(smem_addr));
@ -455,7 +455,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.expect_tx.shared.b64 [%1], %0; \n\t"
"mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t"
"}"
:
: "r"(transaction_bytes), "r"(smem_addr));
@ -563,7 +563,7 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"cp.async.mbarrier.arrive.shared.b64 [%0];\n\t"
"cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t"
"}"
:
: "r"(smem_addr));

View File

@ -77,7 +77,7 @@ struct ClusterLauncher {
constexpr static int MaxClusterSize = 32;
// Check for hardware compatibility
static inline __host__
static inline CUTLASS_HOST
Status check_cluster_dims(dim3 grid, dim3 cluster) {
if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) &&
(grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) {
@ -89,7 +89,7 @@ struct ClusterLauncher {
}
}
static inline __host__
static inline CUTLASS_HOST
Status
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
init(void const* kernel_function)
@ -109,7 +109,7 @@ struct ClusterLauncher {
}
// This is the method we expect to use going forward
static inline __host__
static inline CUTLASS_HOST
Status launch(
dim3 const grid_dims,
dim3 const cluster_dims,
@ -217,7 +217,7 @@ struct ClusterLaunchParams {
/// kernel_ptr, x, y, z);
/// @endcode
template<class ... Args>
__host__ cutlass::Status
CUTLASS_HOST cutlass::Status
launch_kernel_on_cluster(const ClusterLaunchParams& params,
void const* kernel_ptr,
Args&& ... args)

View File

@ -81,23 +81,59 @@ struct CudaHostAdapter {
void *kernel_handles[kMaximumKernelCount];
int32_t kernel_count = 0;
//
// Methods
//
/// Ctor
CudaHostAdapter() = default;
/// Dtor
virtual ~CudaHostAdapter() {}
/// Copy Ctor deleted
CudaHostAdapter(const CudaHostAdapter&) = delete;
/// Copy Ctor
inline CudaHostAdapter(const CudaHostAdapter & rhs):
kernel_count(rhs.kernel_count)
{
CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
kernel_handles[i] = rhs.kernel_handles[i];
}
}
/// Copy Assignment deleted
CudaHostAdapter& operator=(const CudaHostAdapter&) = delete;
/// Copy Assignment
inline CudaHostAdapter& operator=(const CudaHostAdapter & rhs) {
CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
kernel_handles[i] = rhs.kernel_handles[i];
}
kernel_count = rhs.kernel_count;
return *this;
}
/// Move ctor deleted
CudaHostAdapter(CudaHostAdapter&&) = delete;
/// Move assignment deleted
CudaHostAdapter& operator=(CudaHostAdapter&&) = delete;
/// Move ctor
inline CudaHostAdapter(CudaHostAdapter && rhs):
kernel_count(rhs.kernel_count)
{
CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
kernel_handles[i] = rhs.kernel_handles[i];
}
}
/// Move assignment
inline CudaHostAdapter& operator=(CudaHostAdapter && rhs) {
CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
kernel_handles[i] = rhs.kernel_handles[i];
}
kernel_count = rhs.kernel_count;
return *this;
}
/// Ctor
inline CudaHostAdapter(
@ -112,13 +148,19 @@ struct CudaHostAdapter {
}
}
/// Returns true if the CudaHostAdapter is empty (kernel_count == 0)
inline bool empty() const { return !kernel_count; }
/// Returns kernel_count
inline size_t size() const { return static_cast<size_t>(kernel_count); }
/// Queries the occupancy of a kernel
virtual Status query_occupancy(
int32_t *device_sms,
int32_t *sm_occupancy,
int32_t kernel_index,
int32_t thread_count,
int32_t smem_size) = 0;
int32_t smem_size) const = 0;
/// Launches a kernel without using Threadblock Clusters.
virtual Status launch(
@ -127,7 +169,7 @@ struct CudaHostAdapter {
size_t const smem_size,
cudaStream_t cuda_stream,
void** kernel_params,
int32_t kernel_index) = 0;
int32_t kernel_index) const = 0;
/// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters.
virtual Status launch(
@ -137,7 +179,7 @@ struct CudaHostAdapter {
size_t const smem_size,
cudaStream_t cuda_stream,
void** kernel_params,
int32_t kernel_index) = 0;
int32_t kernel_index) const = 0;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -57,6 +57,9 @@
#define CUTLASS_DEVICE inline
#endif
#define CUTLASS_HOST __host__
#define CUTLASS_GLOBAL __global__ static
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>

View File

@ -60,7 +60,7 @@ namespace cutlass {
/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
CUTLASS_GLOBAL
void Kernel(typename Operator::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
@ -76,7 +76,7 @@ void Kernel(typename Operator::Params params) {
/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
CUTLASS_GLOBAL
void Kernel2(typename Operator::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
@ -96,7 +96,7 @@ void Kernel2(typename Operator::Params params) {
/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
CUTLASS_GLOBAL
#ifdef __CUDACC__
// Enclosing this in __CUDACC__ suppresses MSVC warnings.
__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)

View File

@ -58,7 +58,6 @@ struct FusionOperation {
static constexpr int AlignmentScalar = 0;
static constexpr bool IsScaleFactorSupported = false;
static constexpr bool IsPerRowScaleSupported = false;
using ElementBias = void;
static constexpr int AlignmentBias = 0;
static constexpr bool IsPerRowBiasSupported = false;

View File

@ -240,8 +240,10 @@ public:
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
frag_T = convert_t(result_T);
if constexpr (kStoreT) {
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
frag_T = convert_t(result_T);
}
}
/// Applies the operation when is_source_needed() is false
@ -269,8 +271,10 @@ public:
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
frag_T = convert_t(result_T);
if constexpr (kStoreT) {
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
frag_T = convert_t(result_T);
}
}
};

View File

@ -402,7 +402,7 @@ struct OutputTileThreadLayout: DefaultThreadMapTensorOp<
CUTLASS_DEVICE
static auto tid2coord(int thread_idx) {
return make_layout(ThreadShape{})[thread_idx];
return cute::idx2crd(thread_idx, ThreadShape{});
}
template <class TensorInput>

View File

@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"

View File

@ -44,7 +44,6 @@
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
@ -78,7 +77,8 @@ struct CollectiveMma<
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
TransformB_
>
{
//
// Type Aliases
@ -286,7 +286,6 @@ struct CollectiveMma<
copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
}
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count)
{
@ -332,6 +331,7 @@ struct CollectiveMma<
});
}
}
};
@ -352,7 +352,8 @@ template <
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
class TransformB_
>
struct CollectiveMma<
MainloopSm80CpAsync<Stages>,
TileShape_,
@ -368,7 +369,8 @@ struct CollectiveMma<
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
TransformB_
>
{
//
// Type Aliases
@ -627,7 +629,6 @@ struct CollectiveMma<
copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
}
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count)
{
@ -678,6 +679,7 @@ struct CollectiveMma<
});
}
}
};

View File

@ -353,11 +353,9 @@ struct CollectiveMma<
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
@ -433,12 +431,10 @@ struct CollectiveMma<
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
// This helps avoid early exit of blocks in Cluster.
// Waits for all stages to either be released (all
// Consumer UNLOCKs), or if the stage was never used

View File

@ -380,15 +380,10 @@ struct CollectiveMma<
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
using namespace cute;
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
@ -464,14 +459,11 @@ struct CollectiveMma<
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
@ -494,9 +486,7 @@ struct CollectiveMma<
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params)
{
using namespace cute;
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");

View File

@ -680,9 +680,6 @@ public:
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
using namespace cute;
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs");
}
@ -696,11 +693,9 @@ public:
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
}
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
@ -812,12 +807,10 @@ public:
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
@ -841,7 +834,6 @@ public:
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");

View File

@ -111,6 +111,8 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
static constexpr int ThreadCount = CUTE_STATIC_V(size(TiledMma{}));
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
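For instance (sketch, shapes invented): with TileShape = (_128,_128,_64) and an (M,K)-shaped SmemLayoutAtomA of (_64,_16), both divisibility checks pass, since 128 % 64 == 0 and 64 % 16 == 0.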

View File

@ -300,12 +300,9 @@ struct CollectiveMma<
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
using namespace cute;
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
@ -381,12 +378,10 @@ struct CollectiveMma<
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
@ -410,8 +405,6 @@ struct CollectiveMma<
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");

View File

@ -297,15 +297,10 @@ struct CollectiveMma<
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
using namespace cute;
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
@ -382,14 +377,11 @@ struct CollectiveMma<
CUTLASS_DEVICE void
load_tail(
MainloopPipeline pipeline,
PipelineState smem_pipe_write)
{
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % 4;
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (warp_idx_in_warp_group == 0 and lane_predicate) {
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
@ -412,9 +404,7 @@ struct CollectiveMma<
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params)
{
using namespace cute;
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");

View File

@ -75,7 +75,7 @@ template <
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
typename ArchTag_ = arch::Sm80,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
@ -243,7 +243,7 @@ public:
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
return bytes;
@ -271,7 +271,7 @@ public:
args.ref_E.non_const_ref(),
args.epilogue
};
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
@ -324,9 +324,9 @@ public:
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}

View File

@ -339,7 +339,10 @@ public:
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling GemmKernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
run(Params& params,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversal::run()");
dim3 const block = GemmKernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
@ -425,7 +428,9 @@ public:
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr
) {
Status status = initialize(args, workspace, stream);
Status status = initialize(args, workspace, stream, cuda_adapter);
if (Status::kSuccess == status) {
status = run(params_, stream, cuda_adapter);
}
@ -444,14 +449,14 @@ public:
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
return run(params_, stream, cuda_adapter);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
return run(params_, stream, cuda_adapter);
}
};
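Taken together, the adapter-aware overloads read as follows in use (a sketch under assumptions: Gemm, args, workspace, and stream are hypothetical, and MyAdapter is a user-defined subclass of cutlass::CudaHostAdapter; only the parameter order comes from this diff):

// Sketch only -- MyAdapter and the surrounding objects are hypothetical.
Gemm gemm_op;
MyAdapter my_adapter;
cutlass::Status status = gemm_op.initialize(args, workspace, stream, &my_adapter);
if (status == cutlass::Status::kSuccess) {
  status = gemm_op.run(stream, &my_adapter); // re-launch reuses the stored params
}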

View File

@ -70,6 +70,8 @@ class GemmUniversalBase {
public:
using GemmKernel = GemmKernel_;
/// Boolean indicating whether the CudaHostAdapter is enabled
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
@ -99,6 +101,14 @@ public:
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
/// Index of the GEMM Kernel within the CudaHostAdapter
static int32_t const kGemmKernelIndex = 0;
/// Kernel dynamic shared memory allocation requirement
/// Update the kernel function's shared memory configuration for the current device
static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage);
protected:
//
@ -114,9 +124,7 @@ protected:
/// Kernel SM occupancy (in thread blocks)
CUTLASS_THREAD_LOCAL static int sm_occupancy_;
/// Kernel dynamic shared memory allocation requirement
/// Update the kernel function's shared memory configuration for the current device
static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage);
protected:
/// Initialize static thread-local members for the thread's current device,
/// if necessary.
@ -148,12 +156,12 @@ protected:
}
// If requires more than 48KB: configure for extended, dynamic shared memory
if constexpr (smem_size_ >= (48 << 10))
if constexpr (kSharedStorageSize >= (48 << 10))
{
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size_);
kSharedStorageSize);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
@ -165,7 +173,7 @@ protected:
&sm_occupancy_,
Kernel2<GemmKernel>,
GemmKernel::kThreadCount,
smem_size_,
kSharedStorageSize,
cudaOccupancyDisableCachingOverride);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
@ -179,7 +187,7 @@ protected:
"device_ordinal: (" << device_ordinal_ << "), "
"device_sms: (" << device_sms_ << "), "
"sm_occupancy: (" << sm_occupancy_ << ") "
"smem_size: (" << smem_size_ << ") "
"smem_size: (" << kSharedStorageSize << ") "
"GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
return Status::kSuccess;
@ -197,16 +205,58 @@ protected:
/// Initialize params member
Status init_params(Arguments const &args)
Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
// Initialize static device properties, if necessary
Status result = init_device_props();
if (result != Status::kSuccess) {
return result;
int32_t device_sms = 0;
int32_t sm_occupancy = 0;
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
//
// Occupancy query using CudaHostAdapter::query_occupancy().
//
if (cuda_adapter) {
Status status = cuda_adapter->query_occupancy(
&device_sms,
&sm_occupancy,
kGemmKernelIndex,
GemmKernel::kThreadCount,
kSharedStorageSize);
CUTLASS_ASSERT(status == Status::kSuccess);
if (status != Status::kSuccess) {
return status;
}
}
else {
return Status::kErrorInternal;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
// Initialize static device properties, if necessary
Status result = init_device_props();
if (result != Status::kSuccess) {
return result;
}
//
// Use thread-local static members for occupancy query initialized by call to
// `init_device_props()`
//
device_sms = device_sms_;
sm_occupancy = sm_occupancy_;
}
// Initialize params member
params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_);
params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy);
return Status::kSuccess;
}
@ -217,11 +267,11 @@ public:
//---------------------------------------------------------------------------------------------
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args)
static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()");
dim3 grid = get_grid_shape(args);
dim3 grid = get_grid_shape(args, cuda_adapter);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
grid.z <= std::numeric_limits<uint16_t>::max()))
@ -235,13 +285,13 @@ public:
/// Returns the workspace size (in bytes) needed for the problem
/// geometry expressed by these arguments
static size_t get_workspace_size(Arguments const &args)
static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");
// Initialize parameters from args
GemmUniversalBase base;
if (base.init_params(args) != Status::kSuccess) {
if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
return 0;
}
@ -254,13 +304,13 @@ public:
/// Returns the grid extents in thread blocks to launch
static dim3 get_grid_shape(Arguments const &args)
static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");
// Initialize parameters from args
GemmUniversalBase base;
if (base.init_params(args) != Status::kSuccess) {
if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
return dim3(0,0,0);
}
@ -276,17 +326,48 @@ public:
/// Returns the maximum number of active thread blocks per multiprocessor
static int maximum_active_blocks()
static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
return -1;
int32_t device_sms = 0;
int32_t sm_occupancy = 0;
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
Status status = cuda_adapter->query_occupancy(
&device_sms,
&sm_occupancy,
kGemmKernelIndex,
GemmKernel::kThreadCount,
kSharedStorageSize);
CUTLASS_ASSERT(status == Status::kSuccess);
if (status != Status::kSuccess) {
return -1;
}
}
else {
return -1;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
return -1;
}
sm_occupancy = sm_occupancy_;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
return sm_occupancy_;
return sm_occupancy;
}
@ -305,7 +386,7 @@ public:
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize parameters from args
Status result = init_params(args);
Status result = init_params(args, cuda_adapter);
if (result != Status::kSuccess) {
return result;
}
@ -340,13 +421,13 @@ public:
CUTLASS_TRACE_HOST(" "
"grid: (" << grid << "), "
"block: (" << block << "), "
"SMEM: (" << smem_size_ << ")");
"SMEM: (" << kSharedStorageSize << ")");
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
void* kernel_params[] = {&params_};
return cuda_adapter->launch(grid, block, smem_size_, stream, kernel_params, 0);
return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0);
}
else {
return Status::kErrorInternal;
@ -355,7 +436,7 @@ public:
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
Kernel2<GemmKernel><<<grid, block, kSharedStorageSize, stream>>>(params_);
// Query for errors
cudaError_t result = cudaGetLastError();
@ -370,9 +451,9 @@ public:
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
{
return run(stream);
return run(stream, cuda_adapter);
}
@ -383,7 +464,7 @@ public:
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr)
{
Status status = initialize(args, workspace, stream);
Status status = initialize(args, workspace, stream, cuda_adapter);
if (status == Status::kSuccess) {
status = run(stream, cuda_adapter);

View File

@ -195,4 +195,3 @@ struct DefaultSparseGemmWithVisitor<ElementA, LayoutA, kAlignmentA, ElementB, La
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -53,7 +53,7 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma, typename Epilogue, typename ThreadblockSwizzle>
__global__ void GemmPipelined(
CUTLASS_GLOBAL void GemmPipelined(
cutlass::gemm::GemmCoord problem_size,
cutlass::gemm::GemmCoord grid_tiled_shape,
typename Mma::IteratorA::Params params_A,

View File

@ -186,7 +186,7 @@ CUTLASS_DEVICE void GemvBatchedStridedDevice(
}
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
__global__ void GemvBatchedStrided(
CUTLASS_GLOBAL void GemvBatchedStrided(
cutlass::gemm::BatchedGemmCoord problem_size,
ElementAlphaBeta alpha,
ElementAlphaBeta beta,
@ -205,7 +205,7 @@ __global__ void GemvBatchedStrided(
}
template <typename GemvKernel, typename ElementAlphaBeta>
__global__ void GemvBatchedStrided(
CUTLASS_GLOBAL void GemvBatchedStrided(
cutlass::gemm::BatchedGemmCoord problem_size,
ElementAlphaBeta alpha,
typename GemvKernel::IteratorA::TensorRef ref_A,
@ -221,7 +221,7 @@ __global__ void GemvBatchedStrided(
}
template <typename GemvKernel>
__global__ void GemvBatchedStrided(
CUTLASS_GLOBAL void GemvBatchedStrided(
cutlass::gemm::BatchedGemmCoord problem_size,
typename GemvKernel::IteratorA::TensorRef ref_A,
typename GemvKernel::IteratorA::TensorRef::LongIndex lda,

View File

@ -59,7 +59,6 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
@ -77,13 +76,14 @@ public:
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
"SM70 kernel does not support specializing the tile scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
TileScheduler_, ArchTag, TileShape,
cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
static constexpr bool is_valid_tile_scheduler =
cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>;
static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler.");
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
@ -131,6 +131,10 @@ public:
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
(void) workspace;
KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count};
auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{});
return {
args.mode,
args.problem_shape,
@ -148,13 +152,16 @@ public:
static int
get_workspace_size(Arguments const& args) {
return 0;
int workspace_size = 0;
return workspace_size;
}
static
cutlass::Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return Status::kSuccess;
cutlass::Status status = Status::kSuccess;
return status;
}
static dim3

View File

@ -45,7 +45,6 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
@ -74,7 +73,6 @@ public:
using ProblemShape = ProblemShape_;
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;

View File

@ -40,7 +40,6 @@
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/trace.h"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
@ -82,7 +81,6 @@ public:
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
@ -121,7 +119,8 @@ public:
sizeof(typename CollectiveMainloop::SharedStorage),
sizeof(typename CollectiveEpilogue::SharedStorage)));
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::ThreadCount;
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
// Device side arguments

View File

@ -44,7 +44,6 @@
#include "cutlass/trace.h"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
@ -71,7 +70,6 @@ public:
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;

View File

@ -44,7 +44,6 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
@ -71,7 +70,6 @@ public:
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;

View File

@ -45,7 +45,6 @@
#include "cutlass/trace.h"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
@ -72,7 +71,6 @@ public:
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
@ -521,10 +519,10 @@ public:
shared_storage.tensors.epilogue
);
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);

View File

@ -42,7 +42,6 @@
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
@ -69,7 +68,6 @@ public:
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;

View File

@ -44,7 +44,6 @@
#include "cutlass/trace.h"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {

View File

@ -29,165 +29,24 @@
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/static_tile_scheduler.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm_coord.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/cluster_sm90.hpp"
namespace cutlass::gemm::kernel::detail {
///////////////////////////////////////////////////////////////////////////////
// Persistent Thread Block (TB) scheduler
class PersistentTileSchedulerSm90 {
//
// Data members
//
private:
uint64_t current_work_linear_idx_;
uint64_t total_grid_size_;
class PersistentTileSchedulerSm90:
public StaticPersistentTileScheduler<PersistentTileSchedulerSm90> {
using BaseScheduler = StaticPersistentTileScheduler<PersistentTileSchedulerSm90>;
public:
struct WorkTileInfo {
int32_t M_idx = 0;
int32_t N_idx = 0;
int32_t L_idx = 0;
bool is_valid_tile = false;
CUTLASS_HOST_DEVICE
bool
is_valid() const {
return is_valid_tile;
}
CUTLASS_HOST_DEVICE
static WorkTileInfo
invalid_work_tile() {
return {-1, -1, -1, false};
}
CUTLASS_HOST_DEVICE
bool
is_final_split(uint32_t k_tiles_per_output_tile) const {
return true;
}
CUTLASS_HOST_DEVICE
int32_t
reduction_subtile_idx() const {
return -1;
}
};
using StaticPersistentTileScheduler::StaticPersistentTileScheduler;
using Params = PersistentTileSchedulerSm90Params;
using RasterOrder = typename Params::RasterOrder;
using RasterOrderOptions = typename Params::RasterOrderOptions;
struct Arguments {
int max_swizzle_size = 1;
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
};
// Sink scheduler params as a member
Params scheduler_params;
//
// Methods
//
template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
static Params
to_underlying_arguments(
ProblemShapeMNKL problem_shape_mnkl,
TileShape tile_shape,
ClusterShape cluster_shape,
[[maybe_unused]] KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
// We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic
static_assert(cute::is_static<TileShape>::value);
static_assert(cute::is_static<ClusterShape>::value);
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
Params params;
params.initialize(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order
);
return params;
}
CUTLASS_HOST_DEVICE
static bool
can_implement(Arguments const& args) {
return true;
}
CUTLASS_HOST_DEVICE
PersistentTileSchedulerSm90() { };
CUTLASS_DEVICE explicit PersistentTileSchedulerSm90(Params const& params_) : scheduler_params(params_) {
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
if (params_.raster_order_ == RasterOrder::AlongN) {
current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
}
else {
current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y);
}
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work() const {
return get_current_work_for_linear_idx(current_work_linear_idx_);
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) const {
if (linear_idx >= scheduler_params.blocks_per_problem_) {
return WorkTileInfo::invalid_work_tile();
}
// Map the worker's linear index within the CTA-tiled problem shape to the corresponding (M, N, L) indices
uint64_t work_idx_l, remainder;
scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx);
uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder);
auto [work_idx_m, work_idx_n] = get_work_idx_m_and_n(blk_per_grid_dim,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cluster_blk_major_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), true};
}
CUTLASS_DEVICE
void
advance_to_next_work(uint32_t advance_count = 1) {
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}
using Arguments = BaseScheduler::Arguments;
// get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle
static CUTLASS_DEVICE
@ -236,111 +95,6 @@ public:
}
// Computes the linear index within a batch given M and N tile offsets within the batch.
// This essentially inverts the mapping performed in get_work_idx_m_and_n
static CUTLASS_DEVICE
uint64_t
get_linear_idx_from_m_and_n(
int32_t tile_m,
int32_t tile_n,
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
FastDivmodU64 const& divmod_cluster_blk_major,
int32_t log_swizzle_size,
RasterOrder raster_order) {
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
uint64_t minor_work_idx, major_work_idx, cluster_minor_offset;
if (raster_order == RasterOrder::AlongN) {
minor_work_idx = static_cast<uint64_t>(tile_m);
major_work_idx = static_cast<uint64_t>(tile_n);
cluster_minor_offset = cta_m_in_cluster;
}
else {
major_work_idx = static_cast<uint64_t>(tile_m);
minor_work_idx = static_cast<uint64_t>(tile_n);
cluster_minor_offset = cta_n_in_cluster;
}
uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset;
cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset);
divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx);
uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size;
uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1);
uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major;
uint64_t cluster_id = (extra << log_swizzle_size) | offset;
return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset;
}
// Given the inputs, computes the total number of output blocks over which this problem will compute.
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) {
auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape)));
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape)));
return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(problem_shape_mnkl),
to_gemm_coord(cluster_shape),
cta_m, cta_n
);
}
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
ProblemShapeMNKL problem_shape_mnk,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {
auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
truncate_by_problem_size
);
}
// Returns whether the block assigned this work should compute the epilogue for the corresponding
// output tile. For the basic tile scheduler, this is always true.
CUTLASS_HOST_DEVICE
static bool
compute_epilogue(WorkTileInfo const&, Params const&) {
return true;
}
// Performs the reduction across splits for a given output tile. Since this scheduler does
// not split output tiles, no reduction is needed.
template <class FrgTensorC>
CUTLASS_DEVICE
static void
fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {}
// Returns whether the current WorkTileInfo passed in should continue to be used. Since
// this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo
// passed in should not be used after having been processed.
CUTLASS_DEVICE
static bool
continue_current_work(WorkTileInfo&) {
return false;
}
// The basic tile scheduler does not require any additional workspace
template <class ProblemShape, class ElementAccumulator>
static int
@ -355,74 +109,6 @@ public:
return Status::kSuccess;
}
template <class ProblemShape, class TileShape>
CUTLASS_HOST_DEVICE
static int
get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) {
// All work units returned by this scheduler cover the entire K iteration
// space of the output tile assigned to the work unit.
return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape)));
}
CUTLASS_HOST_DEVICE
static uint32_t
get_work_k_tile_start(WorkTileInfo const&) {
// All work units returned by this scheduler start from K tile 0
return 0u;
}
CUTLASS_DEVICE
static bool
need_separate_reduction(Params const& params) {
return false;
}
CUTLASS_DEVICE
bool
is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) {
return false;
}
CUTLASS_DEVICE
uint32_t
epilogue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const {
return 0;
}
template <class FrgTensorC>
CUTLASS_DEVICE
void
separate_reduction(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
}
// Shares the accumulator set with peers in the global workspace
template <class FrgTensorC>
CUTLASS_DEVICE
static void
share(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
}
CUTLASS_DEVICE
static bool
valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) {
return true;
}
CUTLASS_DEVICE
static bool
requires_separate_reduction(Params const& params) {
return false;
}
};
} // namespace cutlass::gemm::kernel::detail
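To make the two raster orders concrete, here is a minimal host-side model of the linearization the scheduler's constructor performs (illustrative only; linearize and its parameters are assumed names, not part of this diff):
#include <cstdint>
enum class RasterOrder { AlongN, AlongM };
// Model of how the constructor maps a CTA's 2D grid position to its
// starting linear work index.
uint64_t linearize(uint64_t block_x, uint64_t block_y,
                   uint64_t grid_x,  uint64_t grid_y,
                   RasterOrder order) {
  return order == RasterOrder::AlongN
           ? block_x + block_y * grid_x   // x varies fastest
           : block_x * grid_y + block_y;  // y varies fastest
}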

View File

@ -94,6 +94,7 @@ struct SparseGemm {
//
// Data members
//
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
@ -125,8 +126,8 @@ struct SparseGemm {
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
output_op(output_op),
semaphore(workspace) {
output_op(output_op) {
semaphore = workspace;
}
};

View File

@ -1,3 +1,4 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause

View File

@ -0,0 +1,453 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/fast_math.h"
#include "cutlass/gemm_coord.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/pipeline/pipeline.hpp"
namespace cutlass::gemm::kernel::detail {
///////////////////////////////////////////////////////////////////////////////
// This class is not intended to be used directly.
// It is a CRTP base class for the concrete tile schedulers.
template<class Subclass>
class StaticPersistentTileScheduler {
//
// Data members
//
private:
uint64_t current_work_linear_idx_;
uint64_t total_grid_size_;
public:
struct WorkTileInfo {
int32_t M_idx = 0;
int32_t N_idx = 0;
int32_t L_idx = 0;
bool is_valid_tile = false;
CUTLASS_HOST_DEVICE
bool
is_valid() const {
return is_valid_tile;
}
CUTLASS_HOST_DEVICE
static WorkTileInfo
invalid_work_tile() {
return {-1, -1, -1, false};
}
CUTLASS_HOST_DEVICE
bool
is_final_split(uint32_t k_tiles_per_output_tile) const {
return true;
}
CUTLASS_HOST_DEVICE
int32_t
reduction_subtile_idx() const {
return -1;
}
};
using Params = PersistentTileSchedulerSm90Params;
using RasterOrder = typename Params::RasterOrder;
using RasterOrderOptions = typename Params::RasterOrderOptions;
public:
struct Arguments {
int max_swizzle_size = 1;
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
};
template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
static Params
to_underlying_arguments(
ProblemShapeMNKL problem_shape_mnkl,
TileShape tile_shape,
ClusterShape cluster_shape,
[[maybe_unused]] KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
// We only need the tile and cluster shape during scheduler setup, so we rely on function template argument deduction (FTAD) to infer them
static_assert(cute::is_static<TileShape>::value);
static_assert(cute::is_static<ClusterShape>::value);
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
Params params;
params.initialize(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order
);
return params;
}
CUTLASS_HOST_DEVICE
static bool
can_implement(Arguments const& args) {
return true;
}
CUTLASS_HOST_DEVICE
StaticPersistentTileScheduler() { }
CUTLASS_DEVICE explicit StaticPersistentTileScheduler(Params const& params_) : scheduler_params(params_) {
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
if (params_.raster_order_ == RasterOrder::AlongN) {
current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
}
else {
current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y);
}
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
}
// Returns the initial work tile info that will be computed over
template <class ClusterShape>
CUTLASS_DEVICE
WorkTileInfo
initial_work_tile_info(ClusterShape cluster_shape) {
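// The cluster shape is unused by the static scheduler: the initial tile
// is determined entirely by this CTA's linear index.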
return get_current_work();
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work() const {
return get_current_work_for_linear_idx(current_work_linear_idx_);
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) const {
if (linear_idx >= scheduler_params.blocks_per_problem_) {
return WorkTileInfo::invalid_work_tile();
}
// Map the worker's linear index within the CTA-tiled problem shape to the corresponding (M, N, L) indices
uint64_t work_idx_l, remainder;
scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx);
uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder);
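// Delegate the (M, N) mapping to the subclass (CRTP) so derived schedulers
// can customize rasterization and swizzling.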
auto [work_idx_m, work_idx_n] = Subclass::get_work_idx_m_and_n(blk_per_grid_dim,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cluster_blk_major_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), true};
}
CUTLASS_DEVICE
void
advance_to_next_work(uint32_t advance_count = 1) {
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}
// Computes the linear index within a batch given M and N tile offsets within the batch.
// This essentially inverts the mapping performed in get_work_idx_m_and_n
static CUTLASS_DEVICE
uint64_t
get_linear_idx_from_m_and_n(
int32_t tile_m,
int32_t tile_n,
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
FastDivmodU64 const& divmod_cluster_blk_major,
int32_t log_swizzle_size,
RasterOrder raster_order) {
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
uint64_t minor_work_idx, major_work_idx, cluster_minor_offset;
if (raster_order == RasterOrder::AlongN) {
minor_work_idx = static_cast<uint64_t>(tile_m);
major_work_idx = static_cast<uint64_t>(tile_n);
cluster_minor_offset = cta_m_in_cluster;
}
else {
major_work_idx = static_cast<uint64_t>(tile_m);
minor_work_idx = static_cast<uint64_t>(tile_n);
cluster_minor_offset = cta_n_in_cluster;
}
uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset;
cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset);
divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx);
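// Undo the swizzled rasterization: the low log_swizzle_size bits of the
// minor cluster index were interleaved with the major index, so peel them
// off, recombine them into the linear cluster id, then add back the
// intra-cluster offsets.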
uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size;
uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1);
uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major;
uint64_t cluster_id = (extra << log_swizzle_size) | offset;
return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset;
}
// Given the inputs, computes the total number of output blocks over which this problem will compute.
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) {
auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape)));
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape)));
return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(problem_shape_mnkl),
to_gemm_coord(cluster_shape),
cta_m, cta_n
);
}
// Kernel helper function to get the next work tile
template <class WorkIdPipeline, class WorkIdPipelineState>
CUTLASS_DEVICE
auto
fetch_next_work(
WorkTileInfo work_tile_info,
WorkIdPipeline& work_id_pipeline,
WorkIdPipelineState work_id_pipe_consumer_state) {
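// The static scheduler derives the next tile from its own linear index;
// the work-id pipeline arguments exist only for interface compatibility
// and are unused here.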
WorkTileInfo new_work_tile_info;
advance_to_next_work();
new_work_tile_info = get_current_work();
// Return true to indicate that the WorkID pipeline state should be advanced
return cute::make_tuple(new_work_tile_info, true);
}
CUTLASS_DEVICE
static auto
work_tile_to_cta_coord(WorkTileInfo work_tile_info) {
// Offset the work tile indices by this CTA's coordinates within the cluster
auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster();
return make_coord(
work_tile_info.M_idx + static_cast<int32_t>(cta_m_in_cluster),
work_tile_info.N_idx + static_cast<int32_t>(cta_n_in_cluster),
_,
work_tile_info.L_idx + static_cast<int32_t>(cta_l_in_cluster)
);
}
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
ProblemShapeMNKL problem_shape_mnk,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {
auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
truncate_by_problem_size
);
}
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
Params const& params,
ProblemShapeMNKL problem_shape_mnk,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info) {
auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
Arguments args{};
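// Rebuild scheduler Arguments from the initialized params. The if constexpr
// guard is defensive: it only assigns when max_swizzle_size is mutable,
// which it always is for this scheduler's Arguments.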
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.log_swizzle_size_;
}
args.raster_order = params.raster_order_ == RasterOrder::AlongN ? RasterOrderOptions::AlongN : RasterOrderOptions::AlongM;
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
args.max_swizzle_size,
args.raster_order,
/* truncate_by_problem_size = */true
);
}
// Convert CTA-level work tile info to cluster-level tile coord
CUTLASS_DEVICE
cute::Coord<int,int,int,int>
tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const {
// The TileScheduler works at CTA level while the kernel works at cluster level, so divide CTA indices by the cluster shape
int m_coord = idx2crd(work_tile_info.M_idx / scheduler_params.cluster_shape_m_,
scheduler_params.problem_tiles_m_);
int n_coord = idx2crd(work_tile_info.N_idx / scheduler_params.cluster_shape_n_,
scheduler_params.problem_tiles_n_);
int l_coord = idx2crd(work_tile_info.L_idx,
scheduler_params.problem_tiles_l_);
return make_coord(m_coord, n_coord, _, l_coord);
}
// Returns whether the block assigned this work should compute the epilogue for the corresponding
// output tile. For the basic tile scheduler, this is always true.
CUTLASS_HOST_DEVICE
static bool
compute_epilogue(WorkTileInfo const&, Params const&) {
return true;
}
CUTLASS_HOST_DEVICE
static bool
compute_epilogue(WorkTileInfo const&) {
return true;
}
// Performs the reduction across splits for a given output tile. Since this scheduler does
// not split output tiles, no reduction is needed.
template <class FrgTensorC>
CUTLASS_DEVICE
static void
fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {}
// Performs the reduction across splits for a given output tile. No fixup is required for
// work units returned by this scheduler.
template <class FrgTensorC>
CUTLASS_DEVICE
void
fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) const { }
// Returns whether the current WorkTileInfo passed in should continue to be used. Since
// this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo
// passed in should not be used after having been processed.
CUTLASS_DEVICE
static bool
continue_current_work(WorkTileInfo&) {
return false;
}
template <class ProblemShape, class TileShape>
CUTLASS_HOST_DEVICE
static int
get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) {
// All work units returned by this scheduler cover the entire K iteration
// space of the output tile assigned to the work unit.
return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape)));
}
CUTLASS_HOST_DEVICE
static uint32_t
get_work_k_tile_start(WorkTileInfo const&) {
// All work units returned by this scheduler start from K tile 0
return 0u;
}
CUTLASS_DEVICE
static bool
need_separate_reduction(Params const& params) {
return false;
}
CUTLASS_DEVICE
bool
is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) {
return false;
}
template <class FrgTensorC>
CUTLASS_DEVICE
void
separate_reduction(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
}
// Shares the accumulator set with peers in the global workspace
template <class FrgTensorC>
CUTLASS_DEVICE
static void
share(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
}
CUTLASS_DEVICE
static bool
valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) {
return true;
}
CUTLASS_DEVICE
static bool
requires_separate_reduction(Params const& params) {
return false;
}
public:
// Sink scheduler params as a member
Params scheduler_params;
};
} // namespace cutlass::gemm::kernel::detail
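For reference, a persistent kernel consumes this scheduler roughly as follows (a minimal sketch under assumed names; consume_tiles and the elided mainloop are illustrative, not part of this commit):
// Hypothetical device-side driver loop over the scheduler's tile space.
template <class Scheduler>
CUTLASS_DEVICE
void consume_tiles(typename Scheduler::Params const& params) {
  Scheduler scheduler(params);
  auto work_tile_info = scheduler.get_current_work();
  while (work_tile_info.is_valid()) {
    // ... run the mainloop and epilogue for (M_idx, N_idx, L_idx) ...
    scheduler.advance_to_next_work();  // stride by the physical grid size
    work_tile_info = scheduler.get_current_work();
  }
}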

View File

@ -87,6 +87,12 @@ struct PersistentTileSchedulerSm90Params {
int32_t log_swizzle_size_ = 0;
RasterOrder raster_order_ = RasterOrder::AlongN;
uint32_t problem_tiles_m_ = 0;
uint32_t problem_tiles_n_ = 0;
uint32_t problem_tiles_l_ = 0;
uint32_t cluster_shape_m_ = 0;
uint32_t cluster_shape_n_ = 0;
// Initializes members. This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
void
@ -127,6 +133,12 @@ struct PersistentTileSchedulerSm90Params {
auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());
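// Cache the cluster-tiled extents and the cluster shape so device code
// (e.g., tile_info_to_coord_mnkl) can convert CTA indices to cluster
// coordinates without recomputing them.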
problem_tiles_m_ = problem_blocks_m / cluster_shape.m();
problem_tiles_n_ = problem_blocks_n / cluster_shape.n();
problem_tiles_l_ = problem_blocks.z;
cluster_shape_m_ = cluster_shape.m();
cluster_shape_n_ = cluster_shape.n();
RasterOrder raster_order = get_rasterization_order(
problem_blocks_m,
problem_blocks_n,

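As a worked example (values assumed for illustration): with problem_blocks = (10, 7, 3), cluster_shape = (2, 1), and log_swizzle_size = 1, the rounded extents are round_up(10, 2*2) = 12 and round_up(7, 2*1) = 8, so problem_tiles_m_ = 12/2 = 6, problem_tiles_n_ = 8/1 = 8, and problem_tiles_l_ = 3.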
View File

@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This defines a "fragment" iterator for visiting the fragments of a warp tile
that participate in one warp-level mma operation.