Updates for 3.4 release. (#1305)
@@ -108,6 +108,28 @@ CUTE_NAMED_UNARY_OP(conjugate, cute::conj);
#undef CUTE_RIGHT_UNARY_OP
#undef CUTE_NAMED_UNARY_OP

+template <int Shift_>
+struct shift_right_const {
+  static constexpr int Shift = Shift_;
+
+  template <class T>
+  CUTE_HOST_DEVICE constexpr
+  decltype(auto) operator()(T&& arg) const {
+    return std::forward<T>(arg) >> Shift;
+  }
+};
+
+template <int Shift_>
+struct shift_left_const {
+  static constexpr int Shift = Shift_;
+
+  template <class T>
+  CUTE_HOST_DEVICE constexpr
+  decltype(auto) operator()(T&& arg) const {
+    return std::forward<T>(arg) << Shift;
+  }
+};
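
A minimal usage sketch (added for illustration, not part of the commit; the header path is an assumption based on the surrounding functional-object macros):

// Sketch: the shift amount travels as a template parameter, so these functors
// can be passed anywhere a unary callable is expected.
#include <cute/algorithm/functional.hpp>   // assumed location of the new functors

static_assert(cute::shift_left_const<3>{}(5)   == 40, "5 << 3");
static_assert(cute::shift_right_const<3>{}(40) == 5,  "40 >> 3");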

/************/
/** Binary **/
/************/

@@ -604,8 +604,7 @@ unwrap(T const& t)
}

//
-// Flatten a hierarchical tuple to a tuple of depth one.
+// Flatten and Unflatten
//

template <class T>
@@ -614,13 +613,15 @@ struct is_flat : true_type {};

template <class... Ts>
struct is_flat<tuple<Ts...>> : bool_constant<(true && ... && (not is_tuple<Ts>::value))> {};
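
Quick illustration (my addition): is_flat is true exactly when no element of the tuple is itself a tuple, which is what enables the "shortcut for perf" fast path below.

static_assert(cute::is_flat<cute::tuple<int, int, int>>::value);
static_assert(not cute::is_flat<cute::tuple<int, cute::tuple<int, int>>>::value);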

// Flatten a hierarchical tuple to a tuple of depth one
+// and wrap non-tuples into a rank-1 tuple.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten_to_tuple(T const& t)
{
  if constexpr (is_tuple<T>::value) {
-    if constexpr (is_flat<T>::value) {
+    if constexpr (is_flat<T>::value) {  // Shortcut for perf
      return t;
    } else {
      return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
@@ -632,13 +633,15 @@ flatten_to_tuple(T const& t)
  CUTE_GCC_UNREACHABLE;
}

// Flatten a hierarchical tuple to a tuple of depth one
+// and leave non-tuple untouched.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten(T const& t)
{
  if constexpr (is_tuple<T>::value) {
-    if constexpr (is_flat<T>::value) {
+    if constexpr (is_flat<T>::value) {  // Shortcut for perf
      return t;
    } else {
      return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
@@ -650,6 +653,43 @@ flatten(T const& t)
  CUTE_GCC_UNREACHABLE;
}

+namespace detail {
+
+template <class FlatTuple, class TargetProfile>
+CUTE_HOST_DEVICE constexpr
+auto
+unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
+{
+  if constexpr (is_tuple<TargetProfile>::value) {
+    return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
+      auto [result, remaining_tuple] = v;
+      auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
+      return cute::make_tuple(append(result, sub_result), sub_tuple);
+    });
+  } else {
+    return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
+  }
+
+  CUTE_GCC_UNREACHABLE;
+}
+
+} // end namespace detail
+
+// Unflatten a flat tuple into a hierarchical tuple
+// @pre  flatten(@a flat_tuple) == @a flat_tuple
+// @pre  rank(flatten(@a target_profile)) == rank(@a flat_tuple)
+// @post congruent(@a result, @a target_profile)
+// @post flatten(@a result) == @a flat_tuple
+template <class FlatTuple, class TargetProfile>
+CUTE_HOST_DEVICE constexpr
+auto
+unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
+{
+  auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
+  CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
+  return unflatten_tuple;
+}
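
A usage sketch (my addition; assumes these live in cute/algorithm/tuple_algorithms.hpp alongside flatten): unflatten inverts flatten_to_tuple with respect to a target profile.

auto t    = cute::make_tuple(1, cute::make_tuple(2, 3), 4);
auto flat = cute::flatten_to_tuple(t);   // (1, 2, 3, 4)
auto back = cute::unflatten(flat, t);    // (1, (2, 3), 4) -- congruent with t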

//
// insert and remove and replace
//

@@ -728,6 +768,18 @@ replace_back(T const& t, X const& x)
+// Make a tuple of Xs of tuple_size N
+//
+
+template <int N, class X>
+CUTE_HOST_DEVICE constexpr
+auto
+tuple_repeat(X const& x)
+{
+  return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
+}
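
Usage sketch (my addition): tuple_repeat always yields a tuple of exactly N copies, even for N == 1, in contrast to repeat<N> below, which produces "repeated Xs of rank N".

auto three = cute::tuple_repeat<3>(cute::Int<1>{});  // (_1, _1, _1)
auto one   = cute::tuple_repeat<1>(7);               // (7) -- still a tuple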

//
// Make repeated Xs of rank N
//

template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
@@ -743,7 +795,7 @@ repeat(X const& x)
}

//
-// Make a tuple of Xs the same profile as tuple
+// Make a tuple of Xs the same profile as tuple T
//

template <class T, class X>
@@ -864,48 +916,6 @@ prepend(T const& a, X const& x)
  CUTE_GCC_UNREACHABLE;
}

-//
-// Unflatten a flat tuple into a hierarchical one
-//   unflatten(x, flatten(x)) == x
-//
-
-namespace detail {
-
-template <class FlatTuple, class TargetProfile>
-CUTE_HOST_DEVICE constexpr
-auto
-unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
-{
-  if constexpr (is_tuple<TargetProfile>::value) {
-    return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) {
-      auto [result, remaining_tuple] = v;
-      auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t);
-      return cute::make_tuple(append(result, sub_result), sub_tuple);
-    });
-  } else {
-    return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple));
-  }
-
-  CUTE_GCC_UNREACHABLE;
-}
-
-} // end namespace detail
-
-// @pre  flatten(@a flat_tuple) == @a flat_tuple
-// @pre  rank(flatten(@a target_profile)) == rank(@a flat_tuple)
-// @post congruent(@a result, @a target_profile)
-// @post flatten(@a result) == @a flat_tuple
-template <class FlatTuple, class TargetProfile>
-CUTE_HOST_DEVICE constexpr
-auto
-unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
-{
-  auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile);
-  CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{});
-  return unflatten_tuple;
-}

//
// Inclusive scan (prefix sum)
//

@@ -63,7 +63,7 @@ initialize_barrier(uint64_t& smem_barrier,                 // 64 bits user-mange
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
-  asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n"
+  asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n"
    :: "r"(smem_int_ptr),
       "r"(thread_count));
#endif
@@ -77,7 +77,7 @@ set_barrier_transaction_bytes(uint64_t& smem_barrier,      // 64 bits user-mange
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
-  asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n"
+  asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
    :: "r"(smem_int_ptr),
       "r"(bytes));
#endif
@@ -95,7 +95,7 @@ wait_barrier(uint64_t& smem_barrier,                       // 64 bits user-mange
      "{\n"
      ".reg .pred                P1;\n"
      "LAB_WAIT:\n"
-     "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
+     "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra.uni DONE;\n"
      "bra.uni LAB_WAIT;\n"
      "DONE:\n"
@@ -116,7 +116,7 @@ arrive_barrier(uint64_t& smem_barrier)                     // 64 bits user-mang
  asm volatile(
    "{\n"
    ".reg .b64 state; \n"
-   "mbarrier.arrive.shared.b64 state, [%0];\n"
+   "mbarrier.arrive.shared::cta.b64 state, [%0];\n"
    "}\n"
    :: "r"(smem_int_ptr));
#endif
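
All four changes swap the generic .shared state space for the explicit .shared::cta qualifier; the barrier semantics are unchanged. For context, a hedged sketch of how these wrappers compose into the usual SM90 arrive/wait lifecycle (illustrative kernel fragment, not from the commit; assumes cute/arch/copy_sm90_desc.hpp):

__global__ void consumer_example(uint32_t bytes_expected) {
  __shared__ uint64_t bar;                           // user-managed mbarrier
  if (threadIdx.x == 0) {
    cute::initialize_barrier(bar, /*thread_count=*/1);
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    // Tell the barrier how many TMA bytes to expect this phase,
    cute::set_barrier_transaction_bytes(bar, bytes_expected);
    // ... then issue TMA copies that complete on `bar` ...
  }
  cute::wait_barrier(bar, /*phase_bit=*/0);          // spins on try_wait.parity
}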

@@ -854,11 +854,12 @@ rs_op_selector()

  // FP32 accumulator
  else if constexpr (is_same_v<ElementC, float>) {
-    static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
-    static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");

    // FP16 inputs
    if constexpr (is_same_v<ElementA, half_t>) {
+      static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
+      static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
@@ -891,6 +892,7 @@ rs_op_selector()
    // BF16 inputs
    else if constexpr (is_same_v<ElementA, bfloat16_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
+      static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
@@ -925,6 +927,7 @@ rs_op_selector()
    else if constexpr (is_same_v<ElementA, tfloat32_t>) {
      static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
      static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
+      static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x8_F32TF32TF32_RS_TN<Args...>{};
@@ -1023,7 +1026,7 @@ rs_op_selector()
        return SM90_64x8x32_F32E4M3E5M2_RS_TN<Args...>{};
      }
      else {
-        static_aRSert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
+        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }


@@ -65,9 +65,9 @@ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>

  using ValType = CopyInternalType;

-  using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
-  using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
-  using ValLayoutRef = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutRef{}));
+  using ValLayoutSrc = decltype(recast_layout<uint1_t, ValType>(BitLayoutSrc{}));
+  using ValLayoutDst = decltype(recast_layout<uint1_t, ValType>(BitLayoutDst{}));
+  using ValLayoutRef = decltype(recast_layout<uint1_t, ValType>(BitLayoutRef{}));

  CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType.");
  CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType.");
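
Hedged illustration (my addition): recast_layout<OldUnit, NewUnit> re-expresses a layout counted in OldUnit elements in NewUnit elements. Starting from uint1_t it matches the old upcast<sizeof_bits<ValType>::value> in common cases, while handling sub-byte value types uniformly.

using BitLayout = cute::Layout<cute::Shape<cute::_32, cute::_128>,
                               cute::Stride<cute::_128, cute::_1>>;
// 128 bits per thread become 16 uint8_t values per thread:
using ValLayout = decltype(cute::recast_layout<cute::uint1_t, cute::uint8_t>(BitLayout{}));
static_assert(cute::size<1>(ValLayout{}) == 16);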

@@ -479,20 +479,24 @@ make_tiled_copy(Copy_Atom<Args...> const& copy_atom,
                ThrLayout const& thr_layout = {},  // (m,n) -> thr_idx
                ValLayout const& val_layout = {})  // (m,n) -> val_idx
{
-  constexpr int R = cute::max(rank_v<ThrLayout>, rank_v<ValLayout>);
-
-  auto thr_layout_mn = append<R>(thr_layout, Layout<_1>{});
-  auto val_layout_mn = append<R>(val_layout, Layout<_1>{});
-
-  // Take the raked_products to compute the Layout_MN
-  auto layout_mn = raked_product(thr_layout_mn, val_layout_mn);
+  // (M,N) -> (thr_idx, val_idx)
+  auto layout_mn = raked_product(thr_layout, val_layout);
+  // (thr_idx, val_idx) -> (M,N)
  auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
-  // print("thr_layout: "); print(thr_layout_mn); print("\n");
-  // print("val_layout: "); print(val_layout_mn); print("\n");
-  // print("layout_mn : "); print(layout_mn); print("\n");
-  // print("layout_tv : "); print(layout_tv); print("\n");
+  // Tiler for extracting relevant elements
+  // (M,N) -> tensor coord
+  auto tiler = product_each(shape(layout_mn));

-  return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn)));
+#if 0
+  print("thr_layout: "); print(thr_layout); print("\n");
+  print("val_layout: "); print(val_layout); print("\n");
+  print("layout_mn : "); print(layout_mn); print("\n");
+  print("layout_tv : "); print(layout_tv); print("\n");
+  print("tiler     : "); print(tiler); print("\n");
+#endif
+
+  return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
}
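
Usage sketch (my addition; the layouts are illustrative): 32 threads arranged 8x4, each owning a 4x1 sliver of values, yielding a 32x4 tile per copy. The tiler is product_each of the raked product's shape, i.e. (_32, _4) here.

auto tiled_copy = cute::make_tiled_copy(
    cute::Copy_Atom<cute::UniversalCopy<float>, float>{},
    cute::Layout<cute::Shape<cute::_8, cute::_4>>{},   // (m,n) -> thr_idx
    cute::Layout<cute::Shape<cute::_4, cute::_1>>{});  // (m,n) -> val_idx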

/** Produce a TiledCopy from thread and value offset maps.
@@ -622,7 +626,7 @@ print(Copy_Atom<Copy_Traits<Args...>, T> const&)
  print("  ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n");
  print("  ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n");
  print("  ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n");
-  print("  ValueType:    %db\n", int(sizeof_bits<typename Atom::ValType>::value));
+  print("  ValueType:    "); print(sizeof_bits<typename Atom::ValType>::value); print("b\n");
}

template <class Atom, class... Args>
@@ -755,6 +759,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS,  // (m,n) -> (tid,vid)  and
#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/copy_traits_sm80.hpp>
#include <cute/atom/copy_traits_sm90.hpp>

+// Config
#if (__CUDACC_VER_MAJOR__ >= 12)
#  define CUTE_COPY_ATOM_TMA_SM90_ENABLED

@@ -673,15 +673,14 @@ fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor,    /
      // Trivial contribution of this gmem mode to this tma mode
      auto ej = unwrap(get<i>(tma_gbasis_stride));
      gmem_prob_shape[i]  = basis_get(ej, gmem_shape);
-      gmem_prob_stride[i] = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
+      gmem_prob_stride[i] = basis_get(ej, gmem_stride);
    } else {
      // Apply a recurrence to each gmem mode that contributes to this tma mode
      for_each(get<i>(tma_gbasis_stride), [&](auto ej) {
        // Problem shape
        uint64_t shape_j = basis_get(ej, gmem_shape);
-        // Problem stride (in bytes)
-        uint64_t stride_j = basis_get(ej, gmem_stride) * sizeof_bits_v<TmaInternalType> / 8;
-
+        uint64_t stride_j = basis_get(ej, gmem_stride);
        uint64_t old_stride = gmem_prob_stride[i];
        gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);

@@ -764,8 +763,14 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor,  // The origin
  assert(gmem_prob_shape[4] <= (uint64_t(1) << 32));          // Size must be max 2^32

  // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element).
-  assert(gmem_prob_stride[0] == sizeof(TmaInternalType) && "Majorness of smem doesn't match majorness of gmem");
+  assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem");
+
+  // convert strides to byte strides
+  for (uint64_t& stride : gmem_prob_stride) {
+    stride = (stride * sizeof_bits_v<TmaInternalType>) / 8;
+  }

+  // Assert the byte strides. Tma Descriptor uses byte strides
  assert((gmem_prob_stride[1]) < (uint64_t(1) << 40));        // Stride must be max 2^40
  assert((gmem_prob_stride[1] & 0b1111) == 0);                // Stride must be multiple of 16B (128b)
  assert((gmem_prob_stride[2]) < (uint64_t(1) << 40));        // Stride must be max 2^40
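
Worked numbers (my addition): strides now stay in elements through the gcd recurrence above and are scaled to bytes only at descriptor creation, which keeps the gcd math in a single unit. For half_t with element stride 1024:

uint64_t elem_stride = 1024;
uint64_t byte_stride = elem_stride * cute::sizeof_bits_v<cute::half_t> / 8;  // 2048
assert(byte_stride % 16 == 0);   // satisfies the 16B-alignment assert above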

@@ -866,8 +871,8 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor,  // The origin
}

#endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)
-  auto recast_ratio = cute::ratio(Int<sizeof_bits<typename GEngine::value_type>::value>{},
-                                  Int<sizeof_bits<           TmaInternalType>::value>{});
+  auto recast_ratio = cute::trait_ratio(sizeof_bits<typename GEngine::value_type>{},
+                                        sizeof_bits<           TmaInternalType>{});

  auto gbasis = make_basis_like(shape(gtensor));

@@ -943,7 +948,7 @@ make_tma_copy_atom(CopyOp,
  // Construct the Copy_Traits
  //

-  constexpr int num_bits_per_tma = decltype(size(tma_gbasis))::value * sizeof_bits_v<TmaInternalType>;
+  constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits<TmaInternalType>::value;
  using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
  using Atom   = Copy_Atom<Traits, typename GEngine::value_type>;

@@ -985,7 +990,7 @@ make_tma_copy_tiled(CopyOp const& copy_op,

  [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map));

-  auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / Int<sizeof_bits_v<typename GEngine::value_type>>{};
+  auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value<sizeof_bits<typename GEngine::value_type>>();

  // smem idx -> smem coord
  auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));

@@ -55,10 +55,10 @@ struct MMA_Atom<MMA_Traits<Args...>>
  using Traits = MMA_Traits<Args...>;

  // Element value types from the MMA_Traits
-  using ValTypeD = typename Traits::ElementDVal;
-  using ValTypeA = typename Traits::ElementAVal;
-  using ValTypeB = typename Traits::ElementBVal;
-  using ValTypeC = typename Traits::ElementCVal;
+  using ValTypeD = typename Traits::ValTypeD;
+  using ValTypeA = typename Traits::ValTypeA;
+  using ValTypeB = typename Traits::ValTypeB;
+  using ValTypeC = typename Traits::ValTypeC;

  // Thr-Val layouts from the MMA_Traits
  using Shape_MNK = typename Traits::Shape_MNK;

@@ -50,14 +50,14 @@ struct supports_output_scaling<X, void_t<decltype(declval<X>().accumulate_)>> {
/**
 * concept MMA_Traits
 * {
- *   using ElementDVal =  // Logical A-value type
- *   using ElementAVal =  // Logical B-value type
- *   using ElementBVal =  // Logical C-value type
- *   using ElementCVal =  // Logical D-value type (NOTE: Not used? Assumed == ElementDVal)
+ *   using ValTypeD =  // Logical D-value type
+ *   using ValTypeA =  // Logical A-value type
+ *   using ValTypeB =  // Logical B-value type
+ *   using ValTypeC =  // Logical C-value type (NOTE: Not used? Assumed == ValTypeD)
 *
- *   using ElementAFrg =  // A-type consumed by MMA (if ommitted, same as ElementAVal)
- *   using ElementBFrg =  // B_type consumed by MMA (if ommitted, same as ElementBVal)
- *   using ElementCFrg =  // C_type consumed by MMA (if ommitted, same as ElementCVal)
+ *   using FrgTypeA =  // A-type consumed by MMA (if omitted, same as ValTypeA)
+ *   using FrgTypeB =  // B-type consumed by MMA (if omitted, same as ValTypeB)
+ *   using FrgTypeC =  // C-type consumed by MMA (if omitted, same as ValTypeC)
 *
 *   using Shape_MNK =  // Logical MxNxK shape of the MMA
 *
@@ -78,10 +78,10 @@ struct MMA_Traits
template <class D, class A, class B, class C>
struct MMA_Traits<UniversalFMA<D,A,B,C>>
{
-  using ElementDVal = D;
-  using ElementAVal = A;
-  using ElementBVal = B;
-  using ElementCVal = C;
+  using ValTypeD = D;
+  using ValTypeA = A;
+  using ValTypeB = B;
+  using ValTypeC = C;

  // Logical shape of the MMA
  using Shape_MNK = Shape<_1,_1,_1>;
@@ -209,19 +209,19 @@ mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits,
namespace detail {

template <class X, class = void>
-struct FrgTypeA_or_Default { using type = typename X::ElementAVal; };
+struct FrgTypeA_or_Default { using type = typename X::ValTypeA; };
template <class X>
-struct FrgTypeA_or_Default<X,void_t<typename X::ElementAFrg>> { using type = typename X::ElementAFrg; };
+struct FrgTypeA_or_Default<X,void_t<typename X::FrgTypeA>> { using type = typename X::FrgTypeA; };

template <class X, class = void>
-struct FrgTypeB_or_Default { using type = typename X::ElementBVal; };
+struct FrgTypeB_or_Default { using type = typename X::ValTypeB; };
template <class X>
-struct FrgTypeB_or_Default<X,void_t<typename X::ElementBFrg>> { using type = typename X::ElementBFrg; };
+struct FrgTypeB_or_Default<X,void_t<typename X::FrgTypeB>> { using type = typename X::FrgTypeB; };

template <class X, class = void>
-struct FrgTypeC_or_Default { using type = typename X::ElementCVal; };
+struct FrgTypeC_or_Default { using type = typename X::ValTypeC; };
template <class X>
-struct FrgTypeC_or_Default<X,void_t<typename X::ElementCFrg>> { using type = typename X::ElementCFrg; };
+struct FrgTypeC_or_Default<X,void_t<typename X::FrgTypeC>> { using type = typename X::FrgTypeC; };

} // end namespace detail
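
Hedged sketch of the renamed trait surface (my addition; MyMMA_Op is hypothetical): what a user-defined specialization looks like after this rename. FrgTypeA/B/C are optional; the detectors above fall back to ValTypeA/B/C when they are absent.

struct MyMMA_Op {};

template <>
struct cute::MMA_Traits<MyMMA_Op>
{
  using ValTypeD = float;            // was ElementDVal
  using ValTypeA = cute::half_t;     // was ElementAVal
  using ValTypeB = cute::half_t;     // was ElementBVal
  using ValTypeC = float;            // was ElementCVal

  using Shape_MNK = cute::Shape<cute::_16, cute::_8, cute::_16>;
  // ThrID / ALayout / BLayout / CLayout omitted here for brevity.
};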

@@ -41,10 +41,10 @@ namespace cute
template <>
struct MMA_Traits<SM61_DP4A>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_1,_1,_4>;
  using ThrID = Layout<_1>;
@@ -58,10 +58,10 @@ struct MMA_Traits<SM61_DP4A>
template <>
struct MMA_Traits<SM61_DP2A>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int16_t;
-  using ElementBVal = int16_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int16_t;
+  using ValTypeB = int16_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_1,_1,_2>;
  using ThrID = Layout<_1>;

@@ -63,10 +63,10 @@ using SM70_8x8_32b = Layout<Shape <Shape <_2, _2,_2>,Shape <_2,_2, _2>>,
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TN>
{
-  using ElementDVal = half_t;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = half_t;
+  using ValTypeD = half_t;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -80,10 +80,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TN>
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NT>
{
-  using ElementDVal = half_t;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = half_t;
+  using ValTypeD = half_t;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -97,10 +97,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NT>
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NN>
{
-  using ElementDVal = half_t;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = half_t;
+  using ValTypeD = half_t;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -114,10 +114,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NN>
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TT>
{
-  using ElementDVal = half_t;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = half_t;
+  using ValTypeD = half_t;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -131,10 +131,10 @@ struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TT>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -148,10 +148,10 @@ struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TN>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -165,10 +165,10 @@ struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NN>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;
@@ -182,10 +182,10 @@ struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NN>
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TT>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = SM70_QuadPair;

@@ -41,10 +41,10 @@ namespace cute
template <>
struct MMA_Traits<SM75_16x8x8_F32F16F16F32_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID = Layout<_32>;
@@ -61,10 +61,10 @@ struct MMA_Traits<SM75_16x8x8_F32F16F16F32_TN>
template <>
struct MMA_Traits<SM75_8x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_8,_8,_16>;
  using ThrID = Layout<_32>;

@@ -66,10 +66,10 @@ using SM80_16x8_Row = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
template <>
struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
-  using ElementDVal = half_t;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = half_t;
+  using ValTypeD = half_t;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;

  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID = Layout<_32>;
@@ -81,10 +81,10 @@ struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
template <>
struct MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
-  using ElementDVal = half_t;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = half_t;
+  using ValTypeD = half_t;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;

  using Shape_MNK = Shape<_16,_8,_16>;
  using ThrID = Layout<_32>;
@@ -103,20 +103,20 @@ template <>
struct MMA_Traits<SM80_16x8x8_F32F16F16F32_TN>
  : MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
};

template <>
struct MMA_Traits<SM80_16x8x16_F32F16F16F32_TN>
  : MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = half_t;
-  using ElementBVal = half_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
};

///////////////////////////////////////////////////////////////////////////////
@@ -127,20 +127,20 @@ template <>
struct MMA_Traits<SM80_16x8x8_F32BF16BF16F32_TN>
  : MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = bfloat16_t;
-  using ElementBVal = bfloat16_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = bfloat16_t;
+  using ValTypeB = bfloat16_t;
+  using ValTypeC = float;
};

template <>
struct MMA_Traits<SM80_16x8x16_F32BF16BF16F32_TN>
  : MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = bfloat16_t;
-  using ElementBVal = bfloat16_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = bfloat16_t;
+  using ValTypeB = bfloat16_t;
+  using ValTypeC = float;
};

///////////////////////////////////////////////////////////////////////////////
@@ -150,10 +150,10 @@ struct MMA_Traits<SM80_16x8x16_F32BF16BF16F32_TN>
template <>
struct MMA_Traits<SM80_16x8x4_F32TF32TF32F32_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = cutlass::tfloat32_t;
-  using ElementBVal = cutlass::tfloat32_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = cutlass::tfloat32_t;
+  using ValTypeB = cutlass::tfloat32_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_16,_8,_4>;
  using ThrID = Layout<_32>;
@@ -166,10 +166,10 @@ struct MMA_Traits<SM80_16x8x4_F32TF32TF32F32_TN>
template <>
struct MMA_Traits<SM80_16x8x8_F32TF32TF32F32_TN>
{
-  using ElementDVal = float;
-  using ElementAVal = cutlass::tfloat32_t;
-  using ElementBVal = cutlass::tfloat32_t;
-  using ElementCVal = float;
+  using ValTypeD = float;
+  using ValTypeA = cutlass::tfloat32_t;
+  using ValTypeB = cutlass::tfloat32_t;
+  using ValTypeC = float;

  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID = Layout<_32>;
@@ -187,10 +187,10 @@ struct MMA_Traits<SM80_16x8x8_F32TF32TF32F32_TN>
template <>
struct MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
-  using ElementDVal = double;
-  using ElementAVal = double;
-  using ElementBVal = double;
-  using ElementCVal = double;
+  using ValTypeD = double;
+  using ValTypeA = double;
+  using ValTypeB = double;
+  using ValTypeC = double;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID = Layout<_32>;
@@ -204,10 +204,10 @@ template <>
struct MMA_Traits<SM80_8x8x4_C64C64C64C64_TN>
  : MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
-  using ElementDVal = complex<double>;
-  using ElementAVal = complex<double>;
-  using ElementBVal = complex<double>;
-  using ElementCVal = complex<double>;
+  using ValTypeD = complex<double>;
+  using ValTypeA = complex<double>;
+  using ValTypeB = complex<double>;
+  using ValTypeC = complex<double>;
};

// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts
@@ -215,10 +215,10 @@ template <>
struct MMA_Traits<SM80_8x8x4_GC64C64C64GC64_TN>
  : MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
-  using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
-  using ElementAVal = complex<double>;
-  using ElementBVal = complex<double>;
-  using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
+  using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
+  using ValTypeA = complex<double>;
+  using ValTypeB = complex<double>;
+  using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
};

///////////////////////////////////////////////////////////////////////////////
@@ -228,10 +228,10 @@ struct MMA_Traits<SM80_8x8x4_GC64C64C64GC64_TN>
template <>
struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_8,_8,_16>;
  using ThrID = Layout<_32>;
@@ -247,10 +247,10 @@ struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN_SATURATE>
template <>
struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_16,_8,_16>;
  using ThrID = Layout<_32>;
@@ -267,10 +267,10 @@ struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN_SATURATE>
template <>
struct MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_16,_8,_32>;
  using ThrID = Layout<_32>;
@@ -293,10 +293,10 @@ template <>
struct MMA_Traits<SM80_8x8x16_S32S8U8S32_TN>
  : MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = uint8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -307,10 +307,10 @@ template <>
struct MMA_Traits<SM80_16x8x16_S32S8U8S32_TN>
  : MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = uint8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -321,10 +321,10 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32S8U8S32_TN>
  : MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = int8_t;
-  using ElementBVal = uint8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -339,10 +339,10 @@ template <>
struct MMA_Traits<SM80_8x8x16_S32U8S8S32_TN>
  : MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = uint8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -353,10 +353,10 @@ template <>
struct MMA_Traits<SM80_16x8x16_S32U8S8S32_TN>
  : MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = uint8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -367,10 +367,10 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32U8S8S32_TN>
  : MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = uint8_t;
-  using ElementBVal = int8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -385,10 +385,10 @@ template <>
struct MMA_Traits<SM80_8x8x16_S32U8U8S32_TN>
  : MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = uint8_t;
-  using ElementBVal = uint8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -399,10 +399,10 @@ template <>
struct MMA_Traits<SM80_16x8x16_S32U8U8S32_TN>
  : MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = uint8_t;
-  using ElementBVal = uint8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -413,10 +413,10 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN>
  : MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = uint8_t;
-  using ElementBVal = uint8_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
};

template <>
@@ -430,10 +430,10 @@ struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN_SATURATE>
template <>
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
{
-  using ElementDVal = int32_t;
-  using ElementAVal = cute::uint1b_t;
-  using ElementBVal = cute::uint1b_t;
-  using ElementCVal = int32_t;
+  using ValTypeD = int32_t;
+  using ValTypeA = cute::uint1b_t;
+  using ValTypeB = cute::uint1b_t;
+  using ValTypeC = int32_t;

  using Shape_MNK = Shape<_16,_8,_256>;
  using ThrID = Layout<_32>;

@@ -44,10 +44,10 @@ namespace cute {
template <>
struct MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
-  using ElementDVal = double;
-  using ElementAVal = double;
-  using ElementBVal = double;
-  using ElementCVal = double;
+  using ValTypeD = double;
+  using ValTypeA = double;
+  using ValTypeB = double;
+  using ValTypeC = double;

  using Shape_MNK = Shape<_16,_8,_4>;
  using ThrID = Layout<_32>;
@@ -62,10 +62,10 @@ struct MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
template <>
struct MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
-  using ElementDVal = double;
-  using ElementAVal = double;
-  using ElementBVal = double;
-  using ElementCVal = double;
+  using ValTypeD = double;
+  using ValTypeA = double;
+  using ValTypeB = double;
+  using ValTypeC = double;

  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID = Layout<_32>;
@@ -80,10 +80,10 @@ struct MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
template <>
struct MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
-  using ElementDVal = double;
-  using ElementAVal = double;
-  using ElementBVal = double;
-  using ElementCVal = double;
+  using ValTypeD = double;
+  using ValTypeA = double;
+  using ValTypeB = double;
+  using ValTypeC = double;

  using Shape_MNK = Shape<_16,_8,_16>;
  using ThrID = Layout<_32>;
@@ -103,30 +103,30 @@ template <>
struct MMA_Traits<SM90_16x8x4_C64C64C64C64_TN>
  : MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
-  using ElementDVal = complex<double>;
-  using ElementAVal = complex<double>;
-  using ElementBVal = complex<double>;
-  using ElementCVal = complex<double>;
+  using ValTypeD = complex<double>;
+  using ValTypeA = complex<double>;
+  using ValTypeB = complex<double>;
+  using ValTypeC = complex<double>;
};

template <>
struct MMA_Traits<SM90_16x8x8_C64C64C64C64_TN>
  : MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
-  using ElementDVal = complex<double>;
-  using ElementAVal = complex<double>;
-  using ElementBVal = complex<double>;
-  using ElementCVal = complex<double>;
+  using ValTypeD = complex<double>;
+  using ValTypeA = complex<double>;
+  using ValTypeB = complex<double>;
+  using ValTypeC = complex<double>;
};

template <>
struct MMA_Traits<SM90_16x8x16_C64C64C64C64_TN>
  : MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
-  using ElementDVal = complex<double>;
-  using ElementAVal = complex<double>;
-  using ElementBVal = complex<double>;
-  using ElementCVal = complex<double>;
+  using ValTypeD = complex<double>;
+  using ValTypeA = complex<double>;
+  using ValTypeB = complex<double>;
+  using ValTypeC = complex<double>;
};

} // end namespace cute

(File diff suppressed because it is too large.)
@@ -479,8 +479,9 @@ weakly_congruent(IntTupleA const& a, IntTupleB const& b)
template <class A, class B>
using is_weakly_congruent = decltype(weakly_congruent(declval<A>(), declval<B>()));

-/** Test if Shape B is compatible with Shape A:
- * Any coordinate into A can also be used as a coordinate into B
+/** Test if Shape A is compatible with Shape B:
+ * the size of A and B are the same, and
+ * any coordinate into A can also be used as a coordinate into B
 * compatible is a partial order on A and B: A <= B
 */
template <class IntTupleA, class IntTupleB>
@@ -509,8 +510,8 @@ compatible(IntTupleA const& a, IntTupleB const& b)
template <class A, class B>
using is_compatible = decltype(compatible(declval<A>(), declval<B>()));

-/** Test if Shape B is weakly compatible with Shape A:
- * Shape B is a multiple of a shape that is compatible with Shape A
+/** Test if Shape A is weakly compatible with Shape B:
+ * there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B)
 * weakly_compatible is a partial order on A and B: A <= B
 */
template <class IntTupleA, class IntTupleB>
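
An illustration of the corrected direction (my addition): compatible(A, B) asks whether every coordinate of A is also a coordinate of B, with equal sizes.

static_assert( cute::compatible(cute::Shape<cute::_8>{},
                                cute::Shape<cute::_2, cute::_4>{}));  // 1D coords refine into (2,4)
static_assert(!cute::compatible(cute::Shape<cute::_2, cute::_4>{},
                                cute::Shape<cute::_8>{}));            // (i,j) is not a coord into 8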

@@ -36,6 +36,8 @@
#include <cute/int_tuple.hpp>
#include <cute/stride.hpp>
+#include <cute/numeric/arithmetic_tuple.hpp>
+#include <cute/numeric/integral_ratio.hpp>
#include <cute/numeric/integral_constant.hpp>

namespace cute
{
@@ -167,16 +169,6 @@ struct Layout
    return operator()(make_coord(c0,c1,cs...));
  }

-  // Map a linear index to a hier ND logical coordinate
-  // NOTE: Dangerous and error-prone
-  template <class Int>
-  CUTE_HOST_DEVICE constexpr
-  auto
-  operator[](Int const& linear_idx) const {
-    static_assert(is_integral<Int>::value);
-    return get_hier_coord(linear_idx);
-  }
-
  //
  // Compose
  //
@@ -305,11 +297,24 @@ struct Layout
#endif
};

+// Equality, return a static or dynamic boolean
+template <class ShapeA, class StrideA,
+          class ShapeB, class StrideB>
+CUTE_HOST_DEVICE constexpr
+auto
+operator==(Layout<ShapeA,StrideA> const& layoutA, Layout<ShapeB,StrideB> const& layoutB)
+{
+  return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride();
+}
+
+template <class Layout>
+struct is_layout : false_type {};
+template <class Shape, class Stride>
+struct is_layout<Layout<Shape,Stride>> : true_type {};
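
Illustration (my addition): equality is a static boolean when both layouts are fully static and a runtime boolean otherwise; is_layout enables Layout-only overloads via SFINAE.

auto a = cute::make_layout(cute::make_shape(cute::Int<4>{}, cute::Int<2>{}));
auto b = cute::make_layout(cute::make_shape(cute::Int<4>{}, cute::Int<2>{}));
CUTE_STATIC_ASSERT_V(a == b);                        // statically true
static_assert(cute::is_layout<decltype(a)>::value);
static_assert(!cute::is_layout<int>::value);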

//
// Layout construction
//

template <class Shape, class Stride,
          __CUTE_REQUIRES((is_tuple<Shape >::value || is_integral<Shape >::value) &&
@@ -446,51 +451,59 @@ make_identity_layout(Shape const& shape)
// Operations to manipulate Layouts like a tuple of pairs
//

+// Return the Is...th sublayout.
+// For Is... = <I0,I1,...,IN>, equivalent to get<IN>(...get<I1>(get<I0>(layout)))
template <size_t... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
get(Layout<Shape,Stride> const& layout)
{
-  // Let the static_asserts in get<I>(shape|stride) catch problems
-  return make_layout(get<Is...>(layout.shape()), get<Is...>(layout.stride()));
+  return make_layout(get<Is...>(layout.shape()),
+                     get<Is...>(layout.stride()));
}

// Return a new layout with only the modes in the range [B,E)
template <int B, int E, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
take(Layout<Shape,Stride> const& layout)
{
-  // Let the static_asserts in take<B,E>(shape|stride) catch problems
-  return make_layout(take<B,E>(layout.shape()), take<B,E>(layout.stride()));
+  static_assert(B < E, "take: empty range error");
+  static_assert(0 <= B && E <= Layout<Shape,Stride>::rank, "take: range out of bounds");
+  return make_layout(take<B,E>(layout.shape()),
+                     take<B,E>(layout.stride()));
}

-//
-// Select layout modes according to an index sequence.
-//
-
-template <int... I, class Shape, class Stride>
+// Return a new layout with only the modes Is... = <I0,I1,...,IN>
+template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
select(Layout<Shape,Stride> const& layout)
{
-  return make_layout(select<I...>(layout.shape()),
-                     select<I...>(layout.stride()));
+  return make_layout(select<Is...>(layout.shape()),
+                     select<Is...>(layout.stride()));
}

// Return a layout with depth at most 1
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
flatten(Layout<Shape,Stride> const& layout)
{
-  return make_layout(flatten(layout.shape()), flatten(layout.stride()));
+  return make_layout(flatten(layout.shape()),
+                     flatten(layout.stride()));
}

+// Return a layout whose profile is congruent to TargetProfile
+// @pre  Input layout is flat, flatten(@a layout) == @a layout
+// @pre  Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile))
+// @post congruent(@a result, @a target_profile)
+template <class Shape, class Stride, class TargetProfile>
+CUTE_HOST_DEVICE constexpr
+auto
+unflatten(Layout<Shape,Stride> const& layout, TargetProfile const& target_profile)
+{
+  return make_layout(unflatten(layout.shape(), target_profile),
+                     unflatten(layout.stride(), target_profile));
+}

@@ -498,7 +511,7 @@ unflatten(Layout<Shape,Stride> const& layout, TargetProfile const& target_profil
// Utilities
//

-// Return the layout of a mode
+// Return the sublayout of mode I...
template <int... Is, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
decltype(auto)
@@ -609,17 +622,6 @@ using cosize_t = decltype(cosize(declval<Layout>()));
template <class Layout>
static constexpr int cosize_v = cosize_t<Layout>::value;

-// Equality
-// Return a static or dynamic boolean
-template <class ShapeA, class StrideA,
-          class ShapeB, class StrideB>
-CUTE_HOST_DEVICE constexpr
-auto
-operator==(Layout<ShapeA,StrideA> const& layoutA, Layout<ShapeB,StrideB> const& layoutB)
-{
-  return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride();
-}
-
// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr

@@ -762,8 +764,11 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride,

} // end namespace detail

-// Combine all the modes that are possible to combine
-// Does not respect the profile of the layout, but does preserve total size
+// "Simplify" the layout by combining modes that are possible to combine
+// Does not respect the shape of the layout, but does preserve total size
+// @post size(@a result) == size(@a layout)
+// @post depth(@a result) <= 1
+// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i)
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@@ -894,7 +899,7 @@ group(Layout<Shape,Stride> const& layout)
// Composition of two layouts: lhs o rhs
// @post compatible(rhs, result)
// @post result(c) = lhs(rhs(c))
-//         for all c in the domain of result
+//         for all c in the domain of rhs
//

namespace detail {
@@ -984,19 +989,19 @@ composition(Layout<LShape,LStride> const& lhs,
  return detail::composition_impl(lhs, rhs.shape(), rhs.stride());
}

-template <class LShape, class LStride, class IntTuple>
+template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
composition(Layout<LShape,LStride> const& lhs,
-            IntTuple const& rhs)
+            Tiler const& rhs)
{
-  if constexpr (is_tuple<IntTuple>::value) {
-    static_assert(tuple_size<IntTuple>::value <= Layout<LShape,LStride>::rank);
+  if constexpr (is_tuple<Tiler>::value) {
+    static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank);
    // Drop any modes of lhs that aren't hit by rhs
-    return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq<tuple_size<IntTuple>::value>{}, seq<>{}, seq<>{});
-  } else if constexpr (is_underscore<IntTuple>::value) {
+    return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq<tuple_size<Tiler>::value>{}, seq<>{}, seq<>{});
+  } else if constexpr (is_underscore<Tiler>::value) {
    return lhs;
-  } else if constexpr (is_integral<IntTuple>::value) {
+  } else if constexpr (is_integral<Tiler>::value) {
    return detail::composition_impl(lhs, rhs, Int<1>{});
  }
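
Illustration (my addition): the renamed Tiler overload accepts tuples mixing layouts, integers, and the underscore, which leaves a mode untouched.

auto L = cute::make_layout(cute::make_shape(cute::Int<8>{}, cute::Int<8>{}));
auto C = cute::composition(L, cute::make_tile(cute::Layout<cute::_4>{}, cute::_));
// mode 0: composed with the 4:1 layout (keep the first 4 elements); mode 1: untouched
CUTE_STATIC_ASSERT_V(cute::size(C) == cute::Int<32>{});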

@@ -1041,19 +1046,25 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi)
    auto [shape, stride, result_shape, result_stride] = init;
    auto min_stride = cute::min(stride);
    auto min_idx    = find(stride, min_stride);
+    auto new_shape  = min_stride / get<i>(result_stride);
+    auto new_stride = get<min_idx>(shape) * min_stride;
+    static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement.");

-    return cute::make_tuple(remove<min_idx>(shape),                                      // Remove the min_idx from shape
-                            remove<min_idx>(stride),                                     // Remove the min_idx from stride
-                            append(result_shape , min_stride / get<i>(result_stride)),   // new shape  = min_stride / last_stride
-                            append(result_stride, get<min_idx>(shape) * min_stride));    // new stride = curr_shape * min_stride
+    return cute::make_tuple(remove<min_idx>(shape),             // Remove the min_idx from shape
+                            remove<min_idx>(stride),            // Remove the min_idx from stride
+                            append(result_shape , new_shape ),  // new shape  = min_stride / last_stride
+                            append(result_stride, new_stride)); // new stride = curr_shape * min_stride
  });

  // Append the last shape mode
-  auto result_shape = append(result_shape_, get<0>(stride_) / get<R-1>(result_stride));   // new shape = min_stride / last_stride
+  auto new_shape    = get<0>(stride_) / get<R-1>(result_stride);
+  static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement.");
+  auto result_shape = append(result_shape_, new_shape);   // new shape = min_stride / last_stride

  // Compute the rest_shape and rest_stride
  auto rest_stride = get<0>(shape_) * get<0>(stride_);
  auto rest_shape  = ceil_div(cosize_hi, rest_stride);

  // Jump into coalesce and append (rest_shape, rest_stride)
  return detail::bw_coalesce<R-1>(result_shape, result_stride, rest_shape, rest_stride);
}
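
Illustration (my addition): complement fills the co-domain gaps of a layout; the new static_assert rejects non-injective (e.g. stride-0) static inputs at compile time instead of silently producing a size-0 mode.

auto layout = cute::Layout<cute::Shape<cute::_4>, cute::Stride<cute::_2>>{};   // 4:2 covers {0,2,4,6}
auto comp   = cute::complement(layout, cute::Int<8>{});                        // 2:1 covers {0,1}
CUTE_STATIC_ASSERT_V(cute::cosize(cute::make_layout(layout, comp)) == cute::Int<8>{});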

@@ -1323,14 +1334,14 @@ zip(Layout<TShape,TStride> const& layoutA,
//   their own mode.
//

-template <class LShape, class LStride, class IntTuple>
+template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tile_unzip(Layout<LShape,LStride> const& layout,
-           IntTuple const& tile)
+           Tiler const& tiler)
{
-  return make_layout(zip2_by(layout.shape(), tile),
-                     zip2_by(layout.stride(), tile));
+  return make_layout(zip2_by(layout.shape(), tiler),
+                     zip2_by(layout.stride(), tiler));
}

//
@@ -1389,10 +1400,10 @@ auto
tiled_divide(Layout<LShape,LStride> const& layout,
             Tiler const& tiler)
{
-  auto div = zipped_divide(layout, tiler);
+  auto result = zipped_divide(layout, tiler);

-  auto R = rank<1>(div);
-  return div(_, repeat<R>(_));
+  auto R1 = rank<1>(result);
+  return result(_, repeat<R1>(_));
}

// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y)
@@ -1403,40 +1414,41 @@ auto
flat_divide(Layout<LShape,LStride> const& layout,
            Tiler const& tiler)
{
-  auto div = zipped_divide(layout, tiler);
+  auto result = zipped_divide(layout, tiler);

-  auto R0 = rank<0>(div);
-  auto R1 = rank<1>(div);
-  return div(repeat<R0>(_), repeat<R1>(_));
+  auto R0 = rank<0>(result);
+  auto R1 = rank<1>(result);
+  return result(repeat<R0>(_), repeat<R1>(_));
}

//
// Logical product
//

// @post compatible()
template <class LShape, class LStride,
          class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
auto
-logical_product(Layout<LShape,LStride> const& layout,
+logical_product(Layout<LShape,LStride> const& block,
                Layout<TShape,TStride> const& tiler)
{
-  return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
+  return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler));
}

template <class LShape, class LStride, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
-logical_product(Layout<LShape,LStride> const& layout,
+logical_product(Layout<LShape,LStride> const& block,
                Tiler const& tiler)
{
  if constexpr (is_tuple<Tiler>::value) {
    static_assert(tuple_size<Tiler>::value <= Layout<LShape,LStride>::rank, "logical_product: Too many modes in tiler.");
-    return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); });
+    return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); });
  } else if constexpr (is_underscore<Tiler>::value) {
-    return layout;
+    return block;
  } else if constexpr (is_integral<Tiler>::value) {
-    return logical_product(layout, make_layout(tiler));
+    return logical_product(block, make_layout(tiler));
  }

  CUTE_GCC_UNREACHABLE;
|
||||
@ -1452,10 +1464,10 @@ template <class LShape, class LStride,
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zipped_product(Layout<LShape,LStride> const& layout,
|
||||
zipped_product(Layout<LShape,LStride> const& block,
|
||||
Tiler const& tiler)
|
||||
{
|
||||
return tile_unzip(logical_product(layout, tiler), tiler);
|
||||
return tile_unzip(logical_product(block, tiler), tiler);
|
||||
}
|
||||
|
||||
// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y)
|
||||
@ -1463,69 +1475,107 @@ template <class LShape, class LStride,
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tiled_product(Layout<LShape,LStride> const& layout,
|
||||
tiled_product(Layout<LShape,LStride> const& block,
|
||||
Tiler const& tiler)
|
||||
{
|
||||
auto div = zipped_product(layout, tiler);
|
||||
auto result = zipped_product(block, tiler);
|
||||
|
||||
auto R = rank<1>(div);
|
||||
return div(_, repeat<R>(_));
|
||||
auto R1 = rank<1>(result);
|
||||
return result(_, repeat<R1>(_));
|
||||
}
|
||||
|
||||
// Attempts to reproduce a layout over a tiler
|
||||
// That is, think of every element of "tiler" as a "layout"
|
||||
// and return the layout of the resulting structure
|
||||
// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y)
|
||||
template <class LShape, class LStride,
|
||||
class Tiler>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flat_product(Layout<LShape,LStride> const& block,
|
||||
Tiler const& tiler)
|
||||
{
|
||||
auto result = zipped_product(block, tiler);
|
||||
|
||||
auto R0 = rank<0>(result);
|
||||
auto R1 = rank<1>(result);
|
||||
return result(repeat<R0>(_), repeat<R1>(_));
|
||||
}
|
||||
|
||||
//
|
||||
// Rank-sensitive products
|
||||
//
|
||||
|
||||
// blocked_product -- Reproduce a block over a tiler.
|
||||
// Think of every element of "tiler" as a "block"
|
||||
// and return the layout of the resulting structure.
|
||||
// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler))
|
||||
template <class TShape, class TStride,
|
||||
class UShape, class UStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
blocked_product(Layout<TShape,TStride> const& layout,
|
||||
blocked_product(Layout<TShape,TStride> const& block,
|
||||
Layout<UShape,UStride> const& tiler)
|
||||
{
|
||||
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
|
||||
|
||||
auto result = logical_product(append<R>(layout), append<R>(tiler));
|
||||
auto result = logical_product(append<R>(block), append<R>(tiler));
|
||||
|
||||
return coalesce(zip(get<0>(result), get<1>(result)), repeat<R>(Int<1>{}));
|
||||
return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat<R>(Int<1>{}));
|
||||
}
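
Editor's note -- a minimal blocked_product sketch (assumed values): tiling a 2x2 block over a 3x4 arrangement of blocks.

  // Sketch: mode-i of the block pairs with mode-i of the tiler and each pair
  // is coalesced, so the result covers a (2*3) x (2*4) domain with whole
  // blocks laid out contiguously.
  auto block  = make_layout(make_shape(Int<2>{}, Int<2>{}));
  auto tiler  = make_layout(make_shape(Int<3>{}, Int<4>{}));
  auto result = blocked_product(block, tiler);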

// raked_product -- Reproduce a block over a tiler with block-interleaving.
// Think of every element of "tiler" as a "block", interleave those blocks,
// and return the layout of the resulting structure.
// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler))
template <class TShape, class TStride,
          class UShape, class UStride>
CUTE_HOST_DEVICE constexpr
auto
raked_product(Layout<TShape,TStride> const& layout,
raked_product(Layout<TShape,TStride> const& block,
              Layout<UShape,UStride> const& tiler)
{
  constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);

  auto result = logical_product(append<R>(layout), append<R>(tiler));
  auto result = logical_product(append<R>(block), append<R>(tiler));

  return coalesce(zip(get<1>(result), get<0>(result)), repeat<R>(Int<1>{}));
  return coalesce(zip(get<1>(result), get<0>(result)), tuple_repeat<R>(Int<1>{}));
}
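
Editor's note -- for contrast, a raked_product sketch with the same assumed inputs: the zip order is reversed, so block elements are interleaved across the tiler rather than kept contiguous.

  // Sketch: same 6 x 8 domain as the blocked_product example above, but with
  // the copies of the block "raked" (interleaved) across the tiler per mode.
  auto block  = make_layout(make_shape(Int<2>{}, Int<2>{}));
  auto tiler  = make_layout(make_shape(Int<3>{}, Int<4>{}));
  auto result = raked_product(block, tiler);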

// tile_to_shape -- Perform a product of a layout so that the result matches a target shape.
// This is similar to blocked_product, but specifies the result shape instead of the
//   product shape, which is more convenient in certain circumstances.
// @param block The layout to repeat
// @param trg_shape The target shape of the result
// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with.
//   Defaults to GenColMajor, so @a layout will repeat
//     across the first mode first, the second mode second, etc
//   E.g. Step<_2,_1,_3> will cause @a layout to repeat
//     across the second mode first, the first mode second, and the third mode last.
// @pre rank(@a block) <= rank(@a trg_shape)
// @post compatible(@a trg_shape, shape(@a result))
template <class Shape, class Stride,
          class TrgShape, class ModeOrder = GenColMajor>
          class TrgShape, class ModeOrder = LayoutLeft>
CUTE_HOST_DEVICE constexpr
auto
tile_to_shape(Layout<Shape,Stride> const& layout,
tile_to_shape(Layout<Shape,Stride> const& block,
              TrgShape const& trg_shape,
              ModeOrder const& ord_shape = {})
{
  CUTE_STATIC_ASSERT_V(rank(layout) <= rank(trg_shape), "Rank of layout must be <= rank of target shape.");
  CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape.");
  constexpr int R = rank_v<TrgShape>;

  auto padded_layout = append<R>(layout);
  auto padded_block = append<R>(block);

  auto layout_shape = product_each(padded_layout.shape());
  auto target_shape = product_each(trg_shape);
  auto block_shape  = product_each(shape(padded_block));
  auto target_shape = product_each(shape(trg_shape));

  // Assert proper division
  CUTE_STATIC_ASSERT_V(sum(transform(target_shape, layout_shape, modulus{})) == Int<0>{},
                       "Layout shape does not divide the target shape.");
  if constexpr (is_static<decltype(target_shape)>::value) {
    CUTE_STATIC_ASSERT_V(weakly_compatible(block_shape, target_shape),
                         "tile_to_shape: block shape does not divide the target shape.");
  }

  auto product_shape = shape_div(target_shape, layout_shape);
  auto product_shape = ceil_div(target_shape, block_shape);

  return coalesce(blocked_product(padded_layout, make_ordered_layout(product_shape, ord_shape)), product_shape);
  return coalesce(blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)), product_shape);
}
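
Editor's note -- a minimal tile_to_shape sketch (assumed values): the caller names the result shape and the grid of repetitions is derived via ceil_div.

  // Sketch: repeat an 8x8 block to fill a 128x64 target; the implied product
  // shape is ceil_div((128,64), (8,8)) = (16,8) blocks, repeated column-major
  // by default (LayoutLeft mode order).
  auto block  = make_layout(make_shape(Int<8>{}, Int<8>{}));
  auto result = tile_to_shape(block, make_shape(Int<128>{}, Int<64>{}));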

//
@ -1602,15 +1652,20 @@ CUTE_HOST_DEVICE constexpr
auto
recast_layout(Layout<Shape,Stride> const& layout)
{
  if constexpr (sizeof_bits<NewType>::value == sizeof_bits<OldType>::value) {
  using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
  if constexpr (scale::num == 1 && scale::den == 1) {
    return layout;
  } else if constexpr (sizeof_bits<NewType>::value > sizeof_bits<OldType>::value) {
    static_assert(sizeof_bits<NewType>::value % sizeof_bits<OldType>::value == 0, "NewType must be a multiple of OldType");
    return upcast<sizeof_bits<NewType>::value/sizeof_bits<OldType>::value>(layout);
  } else if constexpr (sizeof_bits<NewType>::value < sizeof_bits<OldType>::value) {
    static_assert(sizeof_bits<OldType>::value % sizeof_bits<NewType>::value == 0, "NewType must be a divisor of OldType");
    return downcast<sizeof_bits<OldType>::value/sizeof_bits<NewType>::value>(layout);
  }
  else if constexpr (scale::num == 1) {
    return downcast<scale::den>(layout);
  }
  else if constexpr (scale::den == 1) {
    return upcast<scale::num>(layout);
  }
  else {
    static_assert(dependent_false<scale>, "Recast not supported.");
  }

  CUTE_GCC_UNREACHABLE;
}
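
Editor's note -- a minimal recast_layout sketch (assuming the <OldType, NewType> template argument order used at its call sites): the trait ratio of the element widths selects between upcast and downcast.

  // Sketch: half_t (16b) -> uint32_t (32b) gives scale 32/16, which reduces
  // to num == 2, den == 1, i.e. upcast<2>: a layout over 16 half_t elements
  // becomes a layout over 8 uint32_t elements.
  auto layout   = make_layout(Int<16>{});
  auto recasted = recast_layout<half_t, uint32_t>(layout);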

@ -1693,12 +1748,13 @@ print_layout(Layout const& layout, ThrID const& thrid)  // (m,n) -> (tid,vid)  a
}

// Generic 2D Layout to Latex printer -- B&W 8-value color coding
template <class Layout>
template <class LayoutA>
CUTE_HOST_DEVICE
void
print_latex(Layout const& layout)  // (m,n) -> idx
print_latex(LayoutA const& layout_a)
{
  CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{});
  auto layout = append<2>(layout_a, Layout<_1,_0>{});

  char const* latex_header =
      "\\documentclass[convert]{standalone}\n"
@ -1727,7 +1783,6 @@ print_latex(Layout const& layout)  // (m,n) -> idx
  for (int i = 0; i < size<0>(layout); ++i) {
    for (int j = 0; j < size<1>(layout); ++j) {
      int idx = layout(i,j);

      printf("\\node[box,fill=%s] at (%d,%d) {%d};\n",
             color_map[idx % 8],
             i, j,

@ -37,7 +37,7 @@
/* This implements a ComposedLayout of the form
 *   LayoutA o Offset o LayoutB
 * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB.
 * For example, then the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB).
 * For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB).
 *
 * This ComposedLayout provides similar functionality to Layout including tiling, partitioning,
 * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout.
@ -357,12 +357,11 @@ composition(LayoutA const& layoutA,
  return ComposedLayout<LayoutA, Offset, LayoutB>{layoutA, offset, layoutB};
}

template <class A, class O, class B,
          class LayoutOrTile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
composition(ComposedLayout<A,O,B> const& a,
            LayoutOrTile const& b)
            Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b));
}
@ -433,92 +432,101 @@ zip(ComposedLayout<A,O,B> const& a)

// Partitions

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(ComposedLayout<A,O,B> const& a,
               Tile const& b)
               Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tile_unzip(ComposedLayout<A,O,B> const& a,
           Tile const& b)
           Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(ComposedLayout<A,O,B> const& a,
             Tile const& b)
             Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(ComposedLayout<A,O,B> const& a,
              Tile const& b)
              Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(ComposedLayout<A,O,B> const& a,
            Tile const& b)
            Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_product(ComposedLayout<A,O,B> const& a,
                Tile const& b)
                Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_product(ComposedLayout<A,O,B> const& a,
               Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b));
}

template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
tiled_product(ComposedLayout<A,O,B> const& a,
              Tile const& b)
              Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
flat_product(ComposedLayout<A,O,B> const& a,
             Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b));
}

template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
blocked_product(ComposedLayout<A,O,B> const& a,
                Tile const& b)
                Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b));
}

template <class A, class O, class B,
          class Tile>
template <class A, class O, class B, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
raked_product(ComposedLayout<A,O,B> const& a,
              Tile const& b)
              Tiler const& b)
{
  return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b));
}
@ -585,16 +593,19 @@ CUTE_HOST_DEVICE constexpr
auto
recast_layout(ComposedLayout<A,O,B> const& layout)
{
  if constexpr (sizeof(NewType) == sizeof(OldType)) {
  using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
  if constexpr (scale::num == 1 && scale::den == 1) {
    return layout;
  } else if constexpr (sizeof(NewType) > sizeof(OldType)) {
    static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType");
    return upcast<sizeof(NewType)/sizeof(OldType)>(layout);
  } else if constexpr (sizeof(NewType) < sizeof(OldType)) {
    static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType");
    return downcast<sizeof(OldType)/sizeof(NewType)>(layout);
  }

  else if constexpr (scale::num == 1) {
    return downcast<scale::den>(layout);
  }
  else if constexpr (scale::den == 1) {
    return upcast<scale::num>(layout);
  }
  else {
    static_assert(dependent_false<scale>, "Recast not supported.");
  }
  CUTE_GCC_UNREACHABLE;
}


@ -413,6 +413,19 @@ conditional_return(TrueType const& t, FalseType const& f) {
}
}

template <class Trait>
CUTE_HOST_DEVICE constexpr
auto
static_value()
{
  if constexpr (is_std_integral<decltype(Trait::value)>::value) {
    return Int<Trait::value>{};
  } else {
    return Trait::value;
  }
  CUTE_GCC_UNREACHABLE;
}
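
Editor's note -- a minimal static_value sketch (hypothetical trait, for illustration): integral trait values are lifted into Int<> so they stay in the compile-time domain.

  // Sketch: BitsOfHalf is a made-up trait standing in for e.g. sizeof_bits<T>.
  struct BitsOfHalf { static constexpr int value = 16; };
  auto v = static_value<BitsOfHalf>();   // Int<16>{}, usable in static shape math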

//
// Display utilities
//

@ -65,6 +65,11 @@ class R {
  using type = typename conditional<num == 0 || den == 1, C<num>, R<num,den>>::type;
};

template <class T>
struct is_ratio : false_type {};
template <auto n, auto d>
struct is_ratio<R<n,d>> : true_type {};

template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<a,b>::type
@ -72,6 +77,59 @@ ratio(C<a>, C<b>) {
  return {};
}

template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
typename R<a*c,b>::type
ratio(C<a>, R<b,c>) {
  return {};
}

template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
typename R<b,a*c>::type
ratio(R<b,c>, C<a>) {
  return {};
}

template <auto a, auto b, auto c, auto d>
CUTE_HOST_DEVICE constexpr
typename R<a*d,b*c>::type
ratio(R<a,b>, R<c,d>) {
  return {};
}
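
Editor's note -- a minimal ratio sketch: ratio(x, y) computes x/y over compile-time rationals, and R<num,den>::type collapses back to C<num> whenever the denominator reduces to 1.

  auto half  = ratio(C<1>{}, C<2>{});   // R<1,2>
  auto eight = ratio(C<4>{}, half);     // 4 / (1/2) = 8, reduced to C<8>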

//
// Non-reduced ratio implementations
//

template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
R<a,b>
nratio(C<a>, C<b>) {
  return {};
}

template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
R<a*c,b>
nratio(C<a>, R<b,c>) {
  return {};
}

template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
R<b,a*c>
nratio(R<b,c>, C<a>) {
  return {};
}

template <auto a, auto b, auto c, auto d>
CUTE_HOST_DEVICE constexpr
R<a*d,b*c>
nratio(R<a,b>, R<c,d>) {
  return {};
}

template <auto a, auto b, auto x, auto y>
CUTE_HOST_DEVICE constexpr
typename R<a*x,b*y>::type
@ -93,6 +151,13 @@ operator*(C<c>, R<a,b>) {
  return {};
}

template <auto c, auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<c*b,a>::type
operator/(C<c>, R<a,b>) {
  return {};
}

// Product with dynamic type needs to produce an integer...
template <class C, auto a, auto b,
          __CUTE_REQUIRES(cute::is_std_integral<C>::value)>
@ -160,6 +225,23 @@ abs(R<a,b>) {
  return {};
}

template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
auto
log_2(R<a,b>) {
  static_assert(R<a,b>::num > 0);
  static_assert(R<a,b>::den > 0);
  return log_2(static_cast<uint32_t>(R<a,b>::num)) - log_2(static_cast<uint32_t>(R<a,b>::den));
}


template <class Trait0, class Trait1>
CUTE_HOST_DEVICE constexpr
auto
trait_ratio(Trait0, Trait1) {
  return nratio(static_value<Trait0>(), static_value<Trait1>());
}

//
// Display utilities
//

@ -310,4 +310,17 @@ safe_div(T const& t, U const& u) {
  return t / u;
}

/**
 * log2 computation
 */

template <class T>
CUTE_HOST_DEVICE constexpr
auto
log_2(T x) {
  assert(x > 0);
  static_assert(is_unsigned<T>::value, "Only to be used for unsigned integral types.");
  return bit_width(x) - 1;
}
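
Editor's note -- a minimal log_2 sketch: the result is bit_width(x) - 1, i.e. floor(log2(x)); zero is rejected by the assert.

  auto three = log_2(uint32_t(8));   // bit_width(8) == 4, so log_2(8) == 3
  auto two   = log_2(uint32_t(5));   // non-powers of two floor: log_2(5) == 2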

} // namespace cute

@ -41,6 +41,7 @@

#include <cute/pointer_base.hpp>
#include <cute/pointer_swizzle.hpp>
#include <cute/layout.hpp>
namespace cute
{


@ -227,7 +227,7 @@ raw_pointer_cast(counting_iterator<T> const& x) {
template <class T>
CUTE_HOST_DEVICE void print(T const* const ptr)
{
  printf("ptr[%db](%p)", int(sizeof_bits<T>::value), ptr);
  printf("ptr["); print(sizeof_bits<T>::value); printf("b](%p)", ptr);
}

template <class T>

@ -37,7 +37,8 @@
namespace cute
{

/** crd2idx maps a coordinate within <Shape,Stride> to an index
/** crd2idx(c,s,d) maps a coordinate within <Shape,Stride> to an index
 *
 * This is computed as follows:
 * [coord, shape, and stride are all integers => step forward by stride]
 *   op(c, s, d) => c * d
@ -46,7 +47,6 @@ namespace cute
 * [coord, shape, and stride are all tuples => consider each mode independently]
 *   op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D))
 */

template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@ -115,10 +115,6 @@ crd2idx(Coord const& coord,
  CUTE_GCC_UNREACHABLE;
}

//
// If we know Stride is default [CompactColMajor], then we can take shortcuts
//

namespace detail {

template <class CTuple, class STuple, int I0, int... Is>
@ -138,26 +134,31 @@ crd2idx_horner(CTuple const& coord,

} // end namespace detail

/** crd2idx(c,s) maps a coordinate within Shape to an index
 * via a colexicographical enumeration of coordinates in Shape.
 *   i = c0 + s0 * (c1 + s1 * (c2 + s2 * ...))
 */
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& coord,
        Shape const& shape)
{
  static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks");
  if constexpr (is_tuple<Shape>::value) {
    // Flatten and apply Horner's method
    auto flat_coord = flatten(coord);
    auto flat_shape = flatten(shape);
    return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
  } else {
    if constexpr (is_integral<Coord>::value) {        // Coord is already an index
      return coord;
    } else if constexpr (is_integral<Shape>::value) {
      static_assert(dependent_false<Shape>, "Invalid parameters");
    } else {     // Make congruent, flatten, and apply Horner's method
      static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched Ranks");
      auto flat_coord = flatten(coord);
      auto flat_shape = flatten(product_like(shape, coord));
      return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
    }

  CUTE_GCC_UNREACHABLE;
}
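
Editor's note -- a minimal crd2idx sketch: colexicographical (column-major) enumeration of the shape.

  // Sketch: for shape (3,4), idx = c0 + 3 * c1 by Horner's method.
  auto idx = crd2idx(make_coord(1, 2), make_shape(3, 4));   // 1 + 3*2 = 7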

/** idx2crd splits an index to a coordinate within <Shape,Stride>.
/** idx2crd(i,s,d) splits an index into a coordinate within <Shape,Stride>.
 *
 * This is computed as follows:
 * [index, shape, and stride are all integers => determine 1D coord]
@ -170,7 +171,6 @@ crd2idx(Coord const& coord,
 * NOTE: This only works for compact shape+stride layouts. A more general version would
 *       apply to all surjective layouts
 */

template <class Index, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@ -207,15 +207,13 @@ idx2crd(Index const& idx,
  CUTE_GCC_UNREACHABLE;
}

//
// If we know Stride is default [CompactColMajor], then we can take shortcuts
//

//(idx / 1) % s0
//(idx / s0) % s1
//(idx / (s0 * s1)) % s2
//...

/** idx2crd(i,s) splits an index into a coordinate within Shape
 * via a colexicographical enumeration of coordinates in Shape.
 *   c0 = (idx / 1) % s0
 *   c1 = (idx / s0) % s1
 *   c2 = (idx / (s0 * s1)) % s2
 *   ...
 */
template <class Index, class Shape>
CUTE_HOST_DEVICE constexpr
auto
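
Editor's note -- a minimal idx2crd sketch pairing with the crd2idx example above; it inverts the colexicographical enumeration.

  // Sketch: for shape (3,4), c = (idx % 3, idx / 3).
  auto crd = idx2crd(7, make_shape(3, 4));   // (1, 2)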

@ -434,15 +434,20 @@ CUTE_HOST_DEVICE constexpr
auto
recast_layout(Swizzle<B,M,S> const& swizzle)
{
  if constexpr (sizeof_bits<NewType>::value == sizeof_bits<OldType>::value) {
  using scale = decltype(trait_ratio(sizeof_bits<NewType>{}, sizeof_bits<OldType>{}));
  if constexpr (scale::num == 1 && scale::den == 1) {
    return swizzle;
  } else if constexpr (sizeof_bits<NewType>::value > sizeof_bits<OldType>::value) {
    static_assert(sizeof_bits<NewType>::value % sizeof_bits<OldType>::value == 0, "NewType must be a multiple of OldType");
    return upcast<sizeof_bits<NewType>::value/sizeof_bits<OldType>::value>(swizzle);
  } else if constexpr (sizeof_bits<NewType>::value < sizeof_bits<OldType>::value) {
    static_assert(sizeof_bits<OldType>::value % sizeof_bits<NewType>::value == 0, "NewType must be a divisor of OldType");
    return downcast<sizeof_bits<OldType>::value/sizeof_bits<NewType>::value>(swizzle);
  }
  else if constexpr (scale::num == 1) {
    return downcast<scale::den>(swizzle);
  }
  else if constexpr (scale::den == 1) {
    return upcast<scale::num>(swizzle);
  }
  else {
    static_assert(dependent_false<scale>, "Recast not supported.");
  }
  CUTE_GCC_UNREACHABLE;
}

//
@ -453,7 +458,7 @@ template <int B, int M, int S, class Offset, class LayoutB, class Shape, class S
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
                  Layout<Shape,Stride> const& b)
                  Layout<Shape,Stride>                          const& b)
{
  auto common = max_common_layout(a.layout_b(), b);
  auto base = Int<(1 << M)>{};
@ -467,7 +472,7 @@ max_common_layout(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
template <class Shape, class Stride, int B, int M, int S, class Offset, class LayoutB>
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(Layout<Shape,Stride> const& a,
max_common_layout(Layout<Shape,Stride>                          const& a,
                  ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& b)
{
  return max_common_layout(b, a);
@ -477,7 +482,7 @@ template <int B, int M, int S, class Offset, class LayoutB, class Shape, class S
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
                  Layout<Shape,Stride> const& b)
                  Layout<Shape,Stride>                          const& b)
{
  // This assumes that Offset is in the YZ domain of the Swizzle...
  return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_b(), b));
@ -486,7 +491,7 @@ max_common_vector(ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& a,
template <class Shape, class Stride, int B, int M, int S, class Offset, class LayoutB>
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(Layout<Shape,Stride> const& a,
max_common_vector(Layout<Shape,Stride>                          const& a,
                  ComposedLayout<Swizzle<B,M,S>,Offset,LayoutB> const& b)
{
  return max_common_vector(b, a);
@ -517,13 +522,13 @@ template <class Shape, class Stride,
          int B, int M, int S, class Offset, class LayoutT>
CUTE_HOST_DEVICE constexpr
auto
logical_product(Layout<Shape,Stride> const& block,
                ComposedLayout<Swizzle<B,M,S>,Offset,LayoutT> const& tile)
logical_product(Layout<Shape,Stride> const& layout,
                ComposedLayout<Swizzle<B,M,S>,Offset,LayoutT> const& tiler)
{
  CUTE_STATIC_ASSERT_V(tile.offset() == Int<0>{}, "Require Swizzle offset == 0.");
  CUTE_STATIC_ASSERT_V(tiler.offset() == Int<0>{}, "Require Swizzle offset == 0.");
  // The new layout -- if swizzle wasn't an issue, this is the result
  //   our goal is to determine a new swizzle for these strides
  auto new_layout = logical_product(block, tile.layout_b());
  auto new_layout = logical_product(layout, tiler.layout_b());

  // This is accomplished by identifying
  //  S o L  :=:  S? o L*
@ -536,8 +541,8 @@ logical_product(Layout<Shape,Stride> const& block,
  auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}),
                                     make_stride(       Int<0>{}, Int<(1 << M)>{},              Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{}));

  // Compose with the tile to get the swizzle projection, P o L  [The Z and Y contributing portions of L]
  auto layout_only_zy = composition(swizzle_only_zy, tile.layout_b());
  // Compose with the tiler to get the swizzle projection, P o L  [The Z and Y contributing portions of L]
  auto layout_only_zy = composition(swizzle_only_zy, tiler.layout_b());
  // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*)
  auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{});
  // Get the Z bit and the Y bits
@ -545,8 +550,8 @@ logical_product(Layout<Shape,Stride> const& block,
  auto active_Y = swizzle_active_bits & typename Swizzle<B,M,S>::yyy_msk{};

  // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)])
  auto new_active_Z = new_layout(Int<0>{}, tile.layout_b()[active_Z]);
  auto new_active_Y = new_layout(Int<0>{}, tile.layout_b()[active_Y]);
  auto new_active_Z = new_layout(Int<0>{}, tiler.layout_b()[active_Z]);
  auto new_active_Y = new_layout(Int<0>{}, tiler.layout_b()[active_Y]);

  // Use this new swizzle identifier to construct the new swizzle for new_layout
  // (this also makes sure it's a "valid" swizzle that Swizzle can represent)

@ -127,6 +127,18 @@ print(unsigned long long a) {
  printf("%llu", a);
}

CUTE_HOST_DEVICE
void
print(float a) {
  printf("%f", a);
}

CUTE_HOST_DEVICE
void
print(double a) {
  printf("%f", a);
}

template <class... T>
CUTE_HOST_DEVICE
void

@ -236,7 +236,7 @@ public:
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "mbarrier.init.shared.b64 [%1], %0; \n"
        "mbarrier.init.shared::cta.b64 [%1], %0; \n"
        "}"
        :
        : "r"(arrive_count), "r"(smem_addr));
@ -256,7 +256,7 @@ public:
        "{\n\t"
        ".reg .pred P1; \n\t"
        "LAB_WAIT: \n\t"
        "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
        "@P1 bra.uni DONE; \n\t"
        "bra.uni LAB_WAIT; \n\t"
        "DONE: \n\t"
@ -280,7 +280,7 @@ public:
        ".reg .pred P1; \n\t"
        ".reg .pred P2; \n\t"
        "setp.eq.u32 P2, %3, 1;\n\t"
        "@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t"
        "@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
        "selp.b32 %0, 1, 0, P1; \n\t"
        "}"
        : "=r"(waitComplete)
@ -302,7 +302,7 @@ public:
    asm volatile(
        "{\n\t"
        ".reg .pred P1; \n\t"
        "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
        "selp.b32 %0, 1, 0, P1; \n\t"
        "}"
        : "=r"(waitComplete)
@ -342,7 +342,7 @@ public:
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "mbarrier.arrive.shared.b64 _, [%0];\n\t"
        "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t"
        "}"
        :
        : "r"(smem_addr));
@ -357,7 +357,7 @@ public:
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "mbarrier.ival.shared.b64 [%0]; \n\t"
        "mbarrier.ival.shared::cta.b64 [%0]; \n\t"
        "}"
        :
        : "r"(smem_addr));
@ -418,7 +418,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier {
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t"
        "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t"
        "}"
        :
        : "r"(transaction_bytes), "r"(smem_addr));
@ -455,7 +455,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier {
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "mbarrier.expect_tx.shared.b64 [%1], %0; \n\t"
        "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t"
        "}"
        :
        : "r"(transaction_bytes), "r"(smem_addr));
@ -563,7 +563,7 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) {
  uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "{\n\t"
      "cp.async.mbarrier.arrive.shared.b64 [%0];\n\t"
      "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t"
      "}"
      :
      : "r"(smem_addr));

@ -77,7 +77,7 @@ struct ClusterLauncher {
  constexpr static int MaxClusterSize = 32;

  // Check for hardware compatibility
  static inline __host__
  static inline CUTLASS_HOST
  Status check_cluster_dims(dim3 grid, dim3 cluster) {
    if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) &&
        (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) {
@ -89,7 +89,7 @@ struct ClusterLauncher {
    }
  }

  static inline __host__
  static inline CUTLASS_HOST
  Status
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
  init(void const* kernel_function)
@ -109,7 +109,7 @@ struct ClusterLauncher {
  }

  // This is the method we expect to use going forward
  static inline __host__
  static inline CUTLASS_HOST
  Status launch(
      dim3 const grid_dims,
      dim3 const cluster_dims,
@ -217,7 +217,7 @@ struct ClusterLaunchParams {
///     kernel_ptr, x, y, z);
/// @endcode
template<class ... Args>
__host__ cutlass::Status
CUTLASS_HOST cutlass::Status
launch_kernel_on_cluster(const ClusterLaunchParams& params,
                         void const* kernel_ptr,
                         Args&& ... args)

@ -81,23 +81,59 @@ struct CudaHostAdapter {
  void *kernel_handles[kMaximumKernelCount];
  int32_t kernel_count = 0;

  //
  // Methods
  //

  /// Ctor
  CudaHostAdapter() = default;

  /// Dtor
  virtual ~CudaHostAdapter() {}

  /// Copy Ctor deleted
  CudaHostAdapter(const CudaHostAdapter&) = delete;
  /// Copy Ctor
  inline CudaHostAdapter(const CudaHostAdapter & rhs):
    kernel_count(rhs.kernel_count)
  {
    CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
    for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
      kernel_handles[i] = rhs.kernel_handles[i];
    }
  }

  /// Copy Assignment deleted
  CudaHostAdapter& operator=(const CudaHostAdapter&) = delete;
  /// Copy Assignment
  inline CudaHostAdapter& operator=(const CudaHostAdapter & rhs) {

  /// Move ctor deleted
  CudaHostAdapter(CudaHostAdapter&&) = delete;
    CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
    for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
      kernel_handles[i] = rhs.kernel_handles[i];
    }
    kernel_count = rhs.kernel_count;
    return *this;
  }

  /// Move assignment deleted
  CudaHostAdapter& operator=(CudaHostAdapter&&) = delete;
  /// Move ctor
  inline CudaHostAdapter(CudaHostAdapter && rhs):
    kernel_count(rhs.kernel_count)
  {
    CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
    for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
      kernel_handles[i] = rhs.kernel_handles[i];
    }
  }

  /// Move assignment
  inline CudaHostAdapter& operator=(CudaHostAdapter && rhs) {

    CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount);
    for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) {
      kernel_handles[i] = rhs.kernel_handles[i];
    }

    kernel_count = rhs.kernel_count;

    return *this;
  }

  /// Ctor
  inline CudaHostAdapter(
@ -112,13 +148,19 @@ struct CudaHostAdapter {
    }
  }

  /// Returns true if the CudaHostAdapter is empty (kernel_count == 0)
  inline bool empty() const { return !kernel_count; }

  /// Returns kernel_count
  inline size_t size() const { return static_cast<size_t>(kernel_count); }

  /// Queries the occupancy of a kernel
  virtual Status query_occupancy(
    int32_t *device_sms,
    int32_t *sm_occupancy,
    int32_t kernel_index,
    int32_t thread_count,
    int32_t smem_size) = 0;
    int32_t smem_size) const = 0;

  /// Launches a kernel without using Threadblock Clusters.
  virtual Status launch(
@ -127,7 +169,7 @@ struct CudaHostAdapter {
    size_t const smem_size,
    cudaStream_t cuda_stream,
    void** kernel_params,
    int32_t kernel_index) = 0;
    int32_t kernel_index) const = 0;

  /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters.
  virtual Status launch(
@ -137,7 +179,7 @@ struct CudaHostAdapter {
    size_t const smem_size,
    cudaStream_t cuda_stream,
    void** kernel_params,
    int32_t kernel_index) = 0;
    int32_t kernel_index) const = 0;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -57,6 +57,9 @@
#define CUTLASS_DEVICE inline
#endif

#define CUTLASS_HOST __host__
#define CUTLASS_GLOBAL __global__ static

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>

@ -60,7 +60,7 @@ namespace cutlass {

/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
CUTLASS_GLOBAL
void Kernel(typename Operator::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];
@ -76,7 +76,7 @@ void Kernel(typename Operator::Params params) {

/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
CUTLASS_GLOBAL
void Kernel2(typename Operator::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];
@ -96,7 +96,7 @@ void Kernel2(typename Operator::Params params) {

/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
CUTLASS_GLOBAL
#ifdef __CUDACC__
// Enclosing this in __CUDACC__ suppresses MSVC warnings.
__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)

@ -58,7 +58,6 @@ struct FusionOperation {
  static constexpr int AlignmentScalar = 0;
  static constexpr bool IsScaleFactorSupported = false;
  static constexpr bool IsPerRowScaleSupported = false;

  using ElementBias = void;
  static constexpr int AlignmentBias = 0;
  static constexpr bool IsPerRowBiasSupported = false;

@ -240,8 +240,10 @@ public:
    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
    frag_T = convert_t(result_T);
    if constexpr (kStoreT) {
      NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
      frag_T = convert_t(result_T);
    }
  }

  /// Applies the operation when is_source_needed() is false
@ -269,8 +271,10 @@ public:
    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
    frag_T = convert_t(result_T);
    if constexpr (kStoreT) {
      NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
      frag_T = convert_t(result_T);
    }
  }
};


@ -402,7 +402,7 @@ struct OutputTileThreadLayout: DefaultThreadMapTensorOp<

  CUTLASS_DEVICE
  static auto tid2coord(int thread_idx) {
    return make_layout(ThreadShape{})[thread_idx];
    return cute::idx2crd(thread_idx, ThreadShape{});
  }
|
||||
|
||||
template <class TensorInput>
|
||||
|
||||
@ -1,3 +1,34 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@ -44,7 +44,6 @@
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
@ -78,7 +77,8 @@ struct CollectiveMma<
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
TransformB_
|
||||
>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
@ -286,7 +286,6 @@ struct CollectiveMma<
|
||||
copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count)
|
||||
{
|
||||
@ -332,6 +331,7 @@ struct CollectiveMma<
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -352,7 +352,8 @@ template <
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
class TransformB_
|
||||
>
|
||||
struct CollectiveMma<
|
||||
MainloopSm80CpAsync<Stages>,
|
||||
TileShape_,
|
||||
@ -368,7 +369,8 @@ struct CollectiveMma<
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
TransformB_
|
||||
>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
@ -627,7 +629,6 @@ struct CollectiveMma<
|
||||
copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count)
|
||||
{
|
||||
@ -678,6 +679,7 @@ struct CollectiveMma<
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -353,11 +353,9 @@ struct CollectiveMma<
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
@ -433,12 +431,10 @@ struct CollectiveMma<
|
||||
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
// This helps avoid early exit of blocks in Cluster.
|
||||
// Waits for all stages to either be released (all
|
||||
// Consumer UNLOCKs), or if the stage was never used
|
||||
|
||||
@ -380,15 +380,10 @@ struct CollectiveMma<
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
|
||||
using namespace cute;
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
|
||||
@ -464,14 +459,11 @@ struct CollectiveMma<
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
|
||||
{
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
@ -494,9 +486,7 @@ struct CollectiveMma<
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
using namespace cute;
|
||||
Params const& mainloop_params) {
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
|
||||
@ -680,9 +680,6 @@ public:
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
|
||||
static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs");
|
||||
}
|
||||
@ -696,11 +693,9 @@ public:
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
|
||||
}
|
||||
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
|
||||
@ -812,12 +807,10 @@ public:
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
@ -841,7 +834,6 @@ public:
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
using namespace cute;
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
|
||||
@ -111,6 +111,8 @@ struct CollectiveMma<
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
static constexpr int ThreadCount = CUTE_STATIC_V(size(TiledMma{}));
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
@ -300,12 +300,9 @@ struct CollectiveMma<
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
using namespace cute;
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
@ -381,12 +378,10 @@ struct CollectiveMma<
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
@ -410,8 +405,6 @@ struct CollectiveMma<
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
using namespace cute;
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
|
||||
@ -297,15 +297,10 @@ struct CollectiveMma<
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
|
||||
using namespace cute;
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
@ -382,14 +377,11 @@ struct CollectiveMma<
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write)
|
||||
{
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % 4;
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (warp_idx_in_warp_group == 0 and lane_predicate) {
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
@ -412,9 +404,7 @@ struct CollectiveMma<
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
using namespace cute;
|
||||
Params const& mainloop_params) {
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
|
||||
@ -75,7 +75,7 @@ template <
|
||||
/// Operator class tag
|
||||
typename OperatorClass_ = arch::OpClassSimt,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_ = arch::Sm70,
|
||||
typename ArchTag_ = arch::Sm80,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
@ -243,7 +243,7 @@ public:

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;

    return bytes;

@ -271,7 +271,7 @@ public:
      args.ref_E.non_const_ref(),
      args.epilogue
    };

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,

@ -324,9 +324,9 @@ public:
      Arguments const &args,
      void *workspace = nullptr,
      cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

@ -339,7 +339,10 @@ public:
  /// Primary run() entry point API that is static, allowing users to create and manage their own params.
  /// The supplied params struct must be constructed by calling GemmKernel::to_underlying_arguments()
  static Status
  run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
  run(Params& params,
      cudaStream_t stream = nullptr,
      CudaHostAdapter *cuda_adapter = nullptr) {

    CUTLASS_TRACE_HOST("GemmUniversal::run()");
    dim3 const block = GemmKernel::get_block_shape();
    dim3 const grid = get_grid_shape(params);

@ -425,7 +428,9 @@ public:
      cudaStream_t stream = nullptr,
      CudaHostAdapter *cuda_adapter = nullptr
  ) {
    Status status = initialize(args, workspace, stream);

    Status status = initialize(args, workspace, stream, cuda_adapter);

    if (Status::kSuccess == status) {
      status = run(params_, stream, cuda_adapter);
    }

@ -444,14 +449,14 @@ public:

  /// Overload that allows a user to re-launch the same kernel without updating the internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
    return run(params_, stream, cuda_adapter);
  }

  /// Overload that allows a user to re-launch the same kernel without updating the internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
    return run(params_, stream, cuda_adapter);
  }
};
@ -70,6 +70,8 @@ class GemmUniversalBase {
public:

  using GemmKernel = GemmKernel_;

  /// Boolean indicating whether the CudaHostAdapter is enabled
  static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;

  using ThreadblockShape = typename GemmKernel::Mma::Shape;

@ -99,6 +101,14 @@ public:
  /// Argument structure
  using Arguments = typename GemmKernel::Arguments;

  /// Index of the GEMM Kernel within the CudaHostAdapter
  static int32_t const kGemmKernelIndex = 0;

  /// Kernel dynamic shared memory allocation requirement
  /// Update the kernel function's shared memory configuration for the current device
  static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage);

protected:

  //

@ -114,9 +124,7 @@ protected:
  /// Kernel SM occupancy (in thread blocks)
  CUTLASS_THREAD_LOCAL static int sm_occupancy_;

  /// Kernel dynamic shared memory allocation requirement
  /// Update the kernel function's shared memory configuration for the current device
  static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage);
protected:

  /// Initialize static thread-local members for the thread's current device,
  /// if necessary.

@ -148,12 +156,12 @@ protected:
  }

  // If requires more than 48KB: configure for extended, dynamic shared memory
  if constexpr (smem_size_ >= (48 << 10))
  if constexpr (kSharedStorageSize >= (48 << 10))
  {
    cudart_result = cudaFuncSetAttribute(
      Kernel2<GemmKernel>,
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      smem_size_);
      kSharedStorageSize);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;

@ -165,7 +173,7 @@ protected:
    &sm_occupancy_,
    Kernel2<GemmKernel>,
    GemmKernel::kThreadCount,
    smem_size_,
    kSharedStorageSize,
    cudaOccupancyDisableCachingOverride);
  if (cudart_result != cudaSuccess) {
    CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));

@ -179,7 +187,7 @@ protected:
    "device_ordinal: (" << device_ordinal_ << "), "
    "device_sms: (" << device_sms_ << "), "
    "sm_occupancy: (" << sm_occupancy_ << ") "
    "smem_size: (" << smem_size_ << ") "
    "smem_size: (" << kSharedStorageSize << ") "
    "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");

  return Status::kSuccess;
@ -197,16 +205,58 @@ protected:

  /// Initialize params member
  Status init_params(Arguments const &args)
  Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    // Initialize static device properties, if necessary
    Status result = init_device_props();
    if (result != Status::kSuccess) {
      return result;
    int32_t device_sms = 0;
    int32_t sm_occupancy = 0;

    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);

      //
      // Occupancy query using CudaHostAdapter::query_occupancy().
      //

      if (cuda_adapter) {

        Status status = cuda_adapter->query_occupancy(
          &device_sms,
          &sm_occupancy,
          kGemmKernelIndex,
          GemmKernel::kThreadCount,
          kSharedStorageSize);

        CUTLASS_ASSERT(status == Status::kSuccess);

        if (status != Status::kSuccess) {
          return status;
        }
      }
      else {
        return Status::kErrorInternal;
      }
    }
    else {
      CUTLASS_ASSERT(cuda_adapter == nullptr);

      // Initialize static device properties, if necessary
      Status result = init_device_props();

      if (result != Status::kSuccess) {
        return result;
      }

      //
      // Use thread-local static members for occupancy query initialized by call to
      // `init_device_props()`
      //

      device_sms = device_sms_;
      sm_occupancy = sm_occupancy_;
    }

    // Initialize params member
    params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_);
    params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy);
    return Status::kSuccess;
  }
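With the adapter branch above, host code must thread the same CudaHostAdapter pointer through every query so occupancy comes from CudaHostAdapter::query_occupancy() rather than the thread-local statics filled in by init_device_props(). A minimal host-side sketch, assuming a hypothetical kernel-backed type `MyGemm` derived from GemmUniversalBase and an adapter constructed per that class's own API (neither is spelled out in this diff):

  CudaHostAdapter adapter;             // assumption: constructed per the adapter's own API
  typename MyGemm::Arguments args;     // assumption: filled with problem-specific arguments
  MyGemm gemm;

  size_t workspace_bytes = MyGemm::get_workspace_size(args, &adapter);
  void *workspace = nullptr;           // allocate workspace_bytes (e.g. cudaMalloc) if non-zero

  if (gemm.initialize(args, workspace, /*stream=*/nullptr, &adapter) == Status::kSuccess) {
    gemm.run(/*stream=*/nullptr, &adapter);  // launches via cuda_adapter->launch(...)
  }

When CUTLASS is built without CUTLASS_ENABLE_CUDA_HOST_ADAPTER, the adapter argument must stay nullptr and the pre-existing occupancy path is used unchanged.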
@ -217,11 +267,11 @@ public:
  //---------------------------------------------------------------------------------------------

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args)
  static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()");

    dim3 grid = get_grid_shape(args);
    dim3 grid = get_grid_shape(args, cuda_adapter);

    if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
          grid.z <= std::numeric_limits<uint16_t>::max()))

@ -235,13 +285,13 @@ public:

  /// Returns the workspace size (in bytes) needed for the problem
  /// geometry expressed by these arguments
  static size_t get_workspace_size(Arguments const &args)
  static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");

    // Initialize parameters from args
    GemmUniversalBase base;
    if (base.init_params(args) != Status::kSuccess) {
    if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
      return 0;
    }

@ -254,13 +304,13 @@ public:

  /// Returns the grid extents in thread blocks to launch
  static dim3 get_grid_shape(Arguments const &args)
  static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");

    // Initialize parameters from args
    GemmUniversalBase base;
    if (base.init_params(args) != Status::kSuccess) {
    if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
      return dim3(0,0,0);
    }
@ -276,17 +326,48 @@ public:

  /// Returns the maximum number of active thread blocks per multiprocessor
  static int maximum_active_blocks()
  static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");

    // Initialize static device properties, if necessary
    if (init_device_props() != Status::kSuccess) {
      return -1;
    int32_t device_sms = 0;
    int32_t sm_occupancy = 0;

    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);

      if (cuda_adapter) {

        Status status = cuda_adapter->query_occupancy(
          &device_sms,
          &sm_occupancy,
          kGemmKernelIndex,
          GemmKernel::kThreadCount,
          kSharedStorageSize);

        CUTLASS_ASSERT(status == Status::kSuccess);

        if (status != Status::kSuccess) {
          return -1;
        }
      }
      else {
        return -1;
      }
    }
    else {
      CUTLASS_ASSERT(cuda_adapter == nullptr);
      // Initialize static device properties, if necessary
      if (init_device_props() != Status::kSuccess) {
        return -1;
      }

      sm_occupancy = sm_occupancy_;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << sm_occupancy_);
    return sm_occupancy_;
    return sm_occupancy;
  }
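The same adapter threading applies to the occupancy query above; a short usage sketch (same hypothetical `MyGemm` and `adapter` as in the earlier sketch):

  // Returns -1 on failure. With the host adapter enabled, the result comes from
  // CudaHostAdapter::query_occupancy() instead of the CUDA occupancy API.
  int max_blocks = MyGemm::maximum_active_blocks(&adapter);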
@ -305,7 +386,7 @@ public:
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize parameters from args
    Status result = init_params(args);
    Status result = init_params(args, cuda_adapter);
    if (result != Status::kSuccess) {
      return result;
    }

@ -340,13 +421,13 @@ public:
    CUTLASS_TRACE_HOST("  "
      "grid: (" << grid << "), "
      "block: (" << block << "), "
      "SMEM: (" << smem_size_ << ")");
      "SMEM: (" << kSharedStorageSize << ")");

    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);
      if (cuda_adapter) {
        void* kernel_params[] = {&params_};
        return cuda_adapter->launch(grid, block, smem_size_, stream, kernel_params, 0);
        return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0);
      }
      else {
        return Status::kErrorInternal;

@ -355,7 +436,7 @@ public:
    else {
      CUTLASS_ASSERT(cuda_adapter == nullptr);

      Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
      Kernel2<GemmKernel><<<grid, block, kSharedStorageSize, stream>>>(params_);

      // Query for errors
      cudaError_t result = cudaGetLastError();

@ -370,9 +451,9 @@ public:

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr)
  Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
  {
    return run(stream);
    return run(stream, cuda_adapter);
  }

@ -383,7 +464,7 @@ public:
      cudaStream_t stream = nullptr,
      CudaHostAdapter *cuda_adapter = nullptr)
  {
    Status status = initialize(args, workspace, stream);
    Status status = initialize(args, workspace, stream, cuda_adapter);

    if (status == Status::kSuccess) {
      status = run(stream, cuda_adapter);

@ -195,4 +195,3 @@ struct DefaultSparseGemmWithVisitor<ElementA, LayoutA, kAlignmentA, ElementB, La
} // namespace kernel
} // namespace gemm
} // namespace cutlass
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma, typename Epilogue, typename ThreadblockSwizzle>
|
||||
__global__ void GemmPipelined(
|
||||
CUTLASS_GLOBAL void GemmPipelined(
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape,
|
||||
typename Mma::IteratorA::Params params_A,
|
||||
|
||||
@ -186,7 +186,7 @@ CUTLASS_DEVICE void GemvBatchedStridedDevice(
|
||||
}
|
||||
|
||||
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
|
||||
__global__ void GemvBatchedStrided(
|
||||
CUTLASS_GLOBAL void GemvBatchedStrided(
|
||||
cutlass::gemm::BatchedGemmCoord problem_size,
|
||||
ElementAlphaBeta alpha,
|
||||
ElementAlphaBeta beta,
|
||||
@ -205,7 +205,7 @@ __global__ void GemvBatchedStrided(
|
||||
}
|
||||
|
||||
template <typename GemvKernel, typename ElementAlphaBeta>
|
||||
__global__ void GemvBatchedStrided(
|
||||
CUTLASS_GLOBAL void GemvBatchedStrided(
|
||||
cutlass::gemm::BatchedGemmCoord problem_size,
|
||||
ElementAlphaBeta alpha,
|
||||
typename GemvKernel::IteratorA::TensorRef ref_A,
|
||||
@ -221,7 +221,7 @@ __global__ void GemvBatchedStrided(
|
||||
}
|
||||
|
||||
template <typename GemvKernel>
|
||||
__global__ void GemvBatchedStrided(
|
||||
CUTLASS_GLOBAL void GemvBatchedStrided(
|
||||
cutlass::gemm::BatchedGemmCoord problem_size,
|
||||
typename GemvKernel::IteratorA::TensorRef ref_A,
|
||||
typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
|
||||
|
||||
@ -59,7 +59,6 @@ public:
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
|
||||
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
@ -77,13 +76,14 @@ public:
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
|
||||
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
|
||||
"SM70 kernel does not support specializing the tile scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler = typename detail::TileSchedulerSelector<
|
||||
TileScheduler_, ArchTag, TileShape,
|
||||
cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
static constexpr bool is_valid_tile_scheduler =
|
||||
cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>;
|
||||
static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler.");
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
@ -131,6 +131,10 @@ public:
|
||||
Params
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count};
|
||||
auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{});
|
||||
|
||||
return {
|
||||
args.mode,
|
||||
args.problem_shape,
|
||||
@ -148,13 +152,16 @@ public:
|
||||
|
||||
static int
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
int workspace_size = 0;
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static
|
||||
cutlass::Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return Status::kSuccess;
|
||||
cutlass::Status status = Status::kSuccess;
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
static dim3
|
||||
|
||||
@ -45,7 +45,6 @@
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
@ -74,7 +73,6 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
|
||||
@ -40,7 +40,6 @@
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -82,7 +81,6 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
@ -121,7 +119,8 @@ public:
|
||||
sizeof(typename CollectiveMainloop::SharedStorage),
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage)));
|
||||
|
||||
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
|
||||
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::ThreadCount;
|
||||
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
// Device side arguments
|
||||
|
||||
@ -44,7 +44,6 @@
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
@ -71,7 +70,6 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
|
||||
@ -44,7 +44,6 @@
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
@ -71,7 +70,6 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
|
||||
@ -45,7 +45,6 @@
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
@ -72,7 +71,6 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
@ -521,10 +519,10 @@ public:
|
||||
shared_storage.tensors.epilogue
|
||||
);
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
|
||||
@ -42,7 +42,6 @@
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
@ -69,7 +68,6 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
|
||||
@ -44,7 +44,6 @@
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
@ -29,165 +29,24 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
#include "cutlass/gemm/kernel/static_tile_scheduler.hpp"
|
||||
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm_coord.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
|
||||
namespace cutlass::gemm::kernel::detail {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Persistent Thread Block (TB) scheduler
|
||||
class PersistentTileSchedulerSm90 {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
private:
|
||||
uint64_t current_work_linear_idx_;
|
||||
uint64_t total_grid_size_;
|
||||
class PersistentTileSchedulerSm90:
|
||||
public StaticPersistentTileScheduler<PersistentTileSchedulerSm90> {
|
||||
|
||||
using BaseScheduler = StaticPersistentTileScheduler<PersistentTileSchedulerSm90>;
|
||||
public:
|
||||
struct WorkTileInfo {
|
||||
int32_t M_idx = 0;
|
||||
int32_t N_idx = 0;
|
||||
int32_t L_idx = 0;
|
||||
bool is_valid_tile = false;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool
|
||||
is_valid() const {
|
||||
return is_valid_tile;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static WorkTileInfo
|
||||
invalid_work_tile() {
|
||||
return {-1, -1, -1, false};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool
|
||||
is_final_split(uint32_t k_tiles_per_output_tile) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t
|
||||
reduction_subtile_idx() const {
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
using StaticPersistentTileScheduler::StaticPersistentTileScheduler;
|
||||
using Params = PersistentTileSchedulerSm90Params;
|
||||
using RasterOrder = typename Params::RasterOrder;
|
||||
using RasterOrderOptions = typename Params::RasterOrderOptions;
|
||||
struct Arguments {
|
||||
int max_swizzle_size = 1;
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
|
||||
};
|
||||
|
||||
// Sink scheduler params as a member
|
||||
Params scheduler_params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
|
||||
static Params
|
||||
to_underlying_arguments(
|
||||
ProblemShapeMNKL problem_shape_mnkl,
|
||||
TileShape tile_shape,
|
||||
ClusterShape cluster_shape,
|
||||
[[maybe_unused]] KernelHardwareInfo const& hw_info,
|
||||
Arguments const& arguments,
|
||||
[[maybe_unused]] void* workspace=nullptr,
|
||||
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
|
||||
|
||||
// We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic
|
||||
static_assert(cute::is_static<TileShape>::value);
|
||||
static_assert(cute::is_static<ClusterShape>::value);
|
||||
|
||||
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
|
||||
|
||||
Params params;
|
||||
params.initialize(
|
||||
problem_blocks,
|
||||
to_gemm_coord(cluster_shape),
|
||||
hw_info,
|
||||
arguments.max_swizzle_size,
|
||||
arguments.raster_order
|
||||
);
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
PersistentTileSchedulerSm90() { };
|
||||
|
||||
CUTLASS_DEVICE explicit PersistentTileSchedulerSm90(Params const& params_) : scheduler_params(params_) {
|
||||
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
|
||||
// like blockIdx and gridDim, with __CUDA_ARCH__.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
if (params_.raster_order_ == RasterOrder::AlongN) {
|
||||
current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
|
||||
}
|
||||
else {
|
||||
current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y);
|
||||
}
|
||||
|
||||
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
|
||||
#else
|
||||
CUTLASS_ASSERT(false && "This line should never be reached");
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_current_work() const {
|
||||
return get_current_work_for_linear_idx(current_work_linear_idx_);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_current_work_for_linear_idx(uint64_t linear_idx) const {
|
||||
if (linear_idx >= scheduler_params.blocks_per_problem_) {
|
||||
return WorkTileInfo::invalid_work_tile();
|
||||
}
|
||||
|
||||
// Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices
|
||||
uint64_t work_idx_l, remainder;
|
||||
scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx);
|
||||
|
||||
uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder);
|
||||
|
||||
auto [work_idx_m, work_idx_n] = get_work_idx_m_and_n(blk_per_grid_dim,
|
||||
scheduler_params.divmod_cluster_shape_major_,
|
||||
scheduler_params.divmod_cluster_shape_minor_,
|
||||
scheduler_params.divmod_cluster_blk_major_,
|
||||
scheduler_params.log_swizzle_size_,
|
||||
scheduler_params.raster_order_);
|
||||
|
||||
return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), true};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
advance_to_next_work(uint32_t advance_count = 1) {
|
||||
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
|
||||
}
|
||||
using Arguments = BaseScheduler::Arguments;
|
||||
|
||||
// get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle
|
||||
static CUTLASS_DEVICE
|
||||
@ -236,111 +95,6 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Computes the linear index within a batch given M and N tile offsets within the batch.
|
||||
// This essentially inverts the mapping performed in get_work_idx_m_and_n
|
||||
static CUTLASS_DEVICE
|
||||
uint64_t
|
||||
get_linear_idx_from_m_and_n(
|
||||
int32_t tile_m,
|
||||
int32_t tile_n,
|
||||
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
|
||||
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
|
||||
FastDivmodU64 const& divmod_cluster_blk_major,
|
||||
int32_t log_swizzle_size,
|
||||
RasterOrder raster_order) {
|
||||
|
||||
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
|
||||
|
||||
uint64_t minor_work_idx, major_work_idx, cluster_minor_offset;
|
||||
if (raster_order == RasterOrder::AlongN) {
|
||||
minor_work_idx = static_cast<uint64_t>(tile_m);
|
||||
major_work_idx = static_cast<uint64_t>(tile_n);
|
||||
cluster_minor_offset = cta_m_in_cluster;
|
||||
}
|
||||
else {
|
||||
major_work_idx = static_cast<uint64_t>(tile_m);
|
||||
minor_work_idx = static_cast<uint64_t>(tile_n);
|
||||
cluster_minor_offset = cta_n_in_cluster;
|
||||
}
|
||||
|
||||
uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset;
|
||||
cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset);
|
||||
divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx);
|
||||
|
||||
uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size;
|
||||
uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1);
|
||||
|
||||
uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major;
|
||||
|
||||
uint64_t cluster_id = (extra << log_swizzle_size) | offset;
|
||||
return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset;
|
||||
}
|
||||
|
||||
// Given the inputs, computes the total number of output blocks this problem will compute over
|
||||
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
|
||||
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
|
||||
CUTLASS_HOST_DEVICE static
|
||||
dim3
|
||||
get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) {
|
||||
auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape)));
|
||||
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape)));
|
||||
|
||||
return Params::get_tiled_cta_shape_mnl(
|
||||
to_gemm_coord(problem_shape_mnkl),
|
||||
to_gemm_coord(cluster_shape),
|
||||
cta_m, cta_n
|
||||
);
|
||||
}
|
||||
|
||||
// Given the inputs, computes the physical grid we should launch.
|
||||
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
|
||||
CUTLASS_HOST_DEVICE static
|
||||
dim3
|
||||
get_grid_shape(
|
||||
ProblemShapeMNKL problem_shape_mnk,
|
||||
BlockShape cta_shape,
|
||||
ClusterShape cluster_shape,
|
||||
KernelHardwareInfo hw_info,
|
||||
Arguments arguments,
|
||||
bool truncate_by_problem_size=true) {
|
||||
|
||||
auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
|
||||
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
|
||||
|
||||
return Params::get_grid_shape(
|
||||
problem_blocks,
|
||||
to_gemm_coord(cluster_shape),
|
||||
hw_info,
|
||||
arguments.max_swizzle_size,
|
||||
arguments.raster_order,
|
||||
/* truncate_by_problem_size = */true
|
||||
);
|
||||
}
|
||||
|
||||
// Returns whether the block assigned this work should compute the epilogue for the corresponding
|
||||
// output tile. For the basic tile scheduler, this is always true.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static bool
|
||||
compute_epilogue(WorkTileInfo const&, Params const&) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Performs the reduction across splits for a given output tile. Since this scheduler does
|
||||
// not split output tiles, no reduction is needed.
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {}
|
||||
|
||||
// Returns whether the current WorkTileInfo passed in should continue to be used. Since
|
||||
// this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo
|
||||
// passed in should not be used after having been processed.
|
||||
CUTLASS_DEVICE
|
||||
static bool
|
||||
continue_current_work(WorkTileInfo&) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The basic tile scheduler does not require any additional workspace
|
||||
template <class ProblemShape, class ElementAccumulator>
|
||||
static int
|
||||
@ -355,74 +109,6 @@ public:
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
template <class ProblemShape, class TileShape>
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int
|
||||
get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) {
|
||||
// All work units returned by this scheduler cover the entire K iteration
|
||||
// space of the output tile assigned to the work unit.
|
||||
return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape)));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static uint32_t
|
||||
get_work_k_tile_start(WorkTileInfo const&) {
|
||||
// All work units returned by this scheduler start from K tile 0
|
||||
return 0u;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static bool
|
||||
need_separate_reduction(Params const& params) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool
|
||||
is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
uint32_t
|
||||
epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
separate_reduction(
|
||||
Params const& params,
|
||||
WorkTileInfo const& work_tile_info,
|
||||
FrgTensorC& accumulators,
|
||||
uint32_t num_barriers,
|
||||
uint32_t barrier_idx) {
|
||||
}
|
||||
|
||||
// Shares the accumulator set with peers in the global workspace
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
share(
|
||||
Params const& params,
|
||||
WorkTileInfo const& work_tile_info,
|
||||
FrgTensorC& accumulators,
|
||||
uint32_t num_barriers,
|
||||
uint32_t barrier_idx) {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static bool
|
||||
valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) {
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static bool
|
||||
requires_separate_reduction(Params const& params) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::kernel::detail
|
||||
}
|
||||
|
||||
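The refactor above turns PersistentTileSchedulerSm90 into a thin subclass of StaticPersistentTileScheduler, which owns all of the boilerplate and dispatches to static hooks on the derived type. The mechanism is ordinary CRTP; a minimal self-contained sketch of the same pattern, with illustrative names that are not CUTLASS APIs:

  #include <cstdint>
  #include <utility>

  // Base implements the generic walk; Derived supplies the index mapping,
  // mirroring how the CUTLASS base calls Subclass::get_work_idx_m_and_n(...).
  template <class Derived>
  struct StaticSchedulerBase {
    std::uint64_t linear_idx = 0;

    std::pair<int, int> current_tile() const {
      return Derived::map_linear_to_mn(linear_idx);  // static dispatch, no virtuals
    }
  };

  struct RowMajorScheduler : StaticSchedulerBase<RowMajorScheduler> {
    static constexpr int tiles_n = 8;  // illustrative problem width in tiles
    static std::pair<int, int> map_linear_to_mn(std::uint64_t idx) {
      return { int(idx / tiles_n), int(idx % tiles_n) };
    }
  };

Because the dispatch resolves at compile time, the shared base adds no virtual-call overhead in device code, which is why the scheduler hierarchy uses CRTP rather than virtual functions.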
@ -94,6 +94,7 @@ struct SparseGemm {
  //
  // Data members
  //

  typename Epilogue::OutputTileIterator::Params params_C;
  typename Epilogue::OutputTileIterator::TensorRef ref_C;
  typename Epilogue::OutputTileIterator::Params params_D;

@ -125,8 +126,8 @@ struct SparseGemm {
    ref_C(ref_C),
    params_D(ref_D.layout()),
    ref_D(ref_D),
    output_op(output_op),
    semaphore(workspace) {
    output_op(output_op) {
    semaphore = workspace;
  }
};

@ -1,3 +1,4 @@

/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
include/cutlass/gemm/kernel/static_tile_scheduler.hpp (new file, 453 additions)
@ -0,0 +1,453 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/fast_math.h"
#include "cutlass/gemm_coord.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/pipeline/pipeline.hpp"

namespace cutlass::gemm::kernel::detail {

///////////////////////////////////////////////////////////////////////////////
// Users are not supposed to use this class directly.
// This is a CRTP base class for the actual tile schedulers.
template<class Subclass>
class StaticPersistentTileScheduler {
  //
  // Data members
  //

private:
  uint64_t current_work_linear_idx_;
  uint64_t total_grid_size_;

public:
  struct WorkTileInfo {
    int32_t M_idx = 0;
    int32_t N_idx = 0;
    int32_t L_idx = 0;
    bool is_valid_tile = false;

    CUTLASS_HOST_DEVICE
    bool
    is_valid() const {
      return is_valid_tile;
    }

    CUTLASS_HOST_DEVICE
    static WorkTileInfo
    invalid_work_tile() {
      return {-1, -1, -1, false};
    }

    CUTLASS_HOST_DEVICE
    bool
    is_final_split(uint32_t k_tiles_per_output_tile) const {
      return true;
    }

    CUTLASS_HOST_DEVICE
    int32_t
    reduction_subtile_idx() const {
      return -1;
    }
  };

  using Params = PersistentTileSchedulerSm90Params;
  using RasterOrder = typename Params::RasterOrder;
  using RasterOrderOptions = typename Params::RasterOrderOptions;
public:
  struct Arguments {
    int max_swizzle_size = 1;
    RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
  };

  template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
  static Params
  to_underlying_arguments(
    ProblemShapeMNKL problem_shape_mnkl,
    TileShape tile_shape,
    ClusterShape cluster_shape,
    [[maybe_unused]] KernelHardwareInfo const& hw_info,
    Arguments const& arguments,
    [[maybe_unused]] void* workspace=nullptr,
    [[maybe_unused]] const uint32_t epilogue_subtile = 1) {

    // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic
    static_assert(cute::is_static<TileShape>::value);
    static_assert(cute::is_static<ClusterShape>::value);

    dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);

    Params params;
    params.initialize(
      problem_blocks,
      to_gemm_coord(cluster_shape),
      hw_info,
      arguments.max_swizzle_size,
      arguments.raster_order
    );

    return params;
  }

  CUTLASS_HOST_DEVICE
  static bool
  can_implement(Arguments const& args) {
    return true;
  }

  CUTLASS_HOST_DEVICE
  StaticPersistentTileScheduler() { }

  CUTLASS_DEVICE explicit StaticPersistentTileScheduler(Params const& params_) : scheduler_params(params_) {
    // MSVC requires protecting use of CUDA-specific nonstandard syntax,
    // like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
    if (params_.raster_order_ == RasterOrder::AlongN) {
      current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
    }
    else {
      current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y);
    }

    total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
#else
    CUTLASS_ASSERT(false && "This line should never be reached");
#endif
  }

  // Returns the initial work tile info that will be computed over
  template <class ClusterShape>
  CUTLASS_DEVICE
  WorkTileInfo
  initial_work_tile_info(ClusterShape cluster_shape) {
    return get_current_work();
  }

  CUTLASS_DEVICE
  WorkTileInfo
  get_current_work() const {
    return get_current_work_for_linear_idx(current_work_linear_idx_);
  }

  CUTLASS_DEVICE
  WorkTileInfo
  get_current_work_for_linear_idx(uint64_t linear_idx) const {
    if (linear_idx >= scheduler_params.blocks_per_problem_) {
      return WorkTileInfo::invalid_work_tile();
    }

    // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices
    uint64_t work_idx_l, remainder;
    scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx);

    uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder);

    auto [work_idx_m, work_idx_n] = Subclass::get_work_idx_m_and_n(blk_per_grid_dim,
                                                                   scheduler_params.divmod_cluster_shape_major_,
                                                                   scheduler_params.divmod_cluster_shape_minor_,
                                                                   scheduler_params.divmod_cluster_blk_major_,
                                                                   scheduler_params.log_swizzle_size_,
                                                                   scheduler_params.raster_order_);

    return {work_idx_m, work_idx_n, static_cast<int32_t>(work_idx_l), true};
  }

  CUTLASS_DEVICE
  void
  advance_to_next_work(uint32_t advance_count = 1) {
    current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
  }

  // Computes the linear index within a batch given M and N tile offsets within the batch.
  // This essentially inverts the mapping performed in get_work_idx_m_and_n
  static CUTLASS_DEVICE
  uint64_t
  get_linear_idx_from_m_and_n(
    int32_t tile_m,
    int32_t tile_n,
    FastDivmodU64Pow2 const& divmod_cluster_shape_major,
    FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
    FastDivmodU64 const& divmod_cluster_blk_major,
    int32_t log_swizzle_size,
    RasterOrder raster_order) {

    auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();

    uint64_t minor_work_idx, major_work_idx, cluster_minor_offset;
    if (raster_order == RasterOrder::AlongN) {
      minor_work_idx = static_cast<uint64_t>(tile_m);
      major_work_idx = static_cast<uint64_t>(tile_n);
      cluster_minor_offset = cta_m_in_cluster;
    }
    else {
      major_work_idx = static_cast<uint64_t>(tile_m);
      minor_work_idx = static_cast<uint64_t>(tile_n);
      cluster_minor_offset = cta_n_in_cluster;
    }

    uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset;
    cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset);
    divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx);

    uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size;
    uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1);

    uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major;

    uint64_t cluster_id = (extra << log_swizzle_size) | offset;
    return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset;
  }

  // Given the inputs, computes the total number of output blocks over which this problem will compute.
  // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
  CUTLASS_HOST_DEVICE static
  dim3
  get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) {
    auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape)));
    auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape)));

    return Params::get_tiled_cta_shape_mnl(
      to_gemm_coord(problem_shape_mnkl),
      to_gemm_coord(cluster_shape),
      cta_m, cta_n
    );
  }

  // Kernel helper function to get next work ID
  template <class WorkIdPipeline, class WorkIdPipelineState>
  CUTLASS_DEVICE
  auto
  fetch_next_work(
    WorkTileInfo work_tile_info,
    WorkIdPipeline& work_id_pipeline,
    WorkIdPipelineState work_id_pipe_consumer_state) {
    WorkTileInfo new_work_tile_info;
    advance_to_next_work();
    new_work_tile_info = get_current_work();

    // Return true to indicate that the WorkID pipeline state should be advanced
    return cute::make_tuple(new_work_tile_info, true);
  }

  CUTLASS_DEVICE
  static auto
  work_tile_to_cta_coord(WorkTileInfo work_tile_info) {
    // Get every cta coord in three dimensions of the cluster
    auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = cute::block_id_in_cluster();
    return make_coord(
      work_tile_info.M_idx + static_cast<int32_t>(cta_m_in_cluster),
      work_tile_info.N_idx + static_cast<int32_t>(cta_n_in_cluster),
      _,
      work_tile_info.L_idx + static_cast<int32_t>(cta_l_in_cluster)
    );
  }

  // Given the inputs, computes the physical grid we should launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
  CUTLASS_HOST_DEVICE static
  dim3
  get_grid_shape(
    ProblemShapeMNKL problem_shape_mnk,
    BlockShape cta_shape,
    ClusterShape cluster_shape,
    KernelHardwareInfo hw_info,
    Arguments arguments,
    bool truncate_by_problem_size=true) {

    auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
    dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);

    return Params::get_grid_shape(
      problem_blocks,
      to_gemm_coord(cluster_shape),
      hw_info,
      arguments.max_swizzle_size,
      arguments.raster_order,
      /* truncate_by_problem_size = */true
    );
  }

  // Given the inputs, computes the physical grid we should launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
  CUTLASS_HOST_DEVICE static
  dim3
  get_grid_shape(
    Params const& params,
    ProblemShapeMNKL problem_shape_mnk,
    BlockShape cta_shape,
    ClusterShape cluster_shape,
    KernelHardwareInfo hw_info) {

    auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
    dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);

    Arguments args{};
    if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
      args.max_swizzle_size = 1 << params.log_swizzle_size_;
    }
    args.raster_order = params.raster_order_ == RasterOrder::AlongN ? RasterOrderOptions::AlongN : RasterOrderOptions::AlongM;

    return Params::get_grid_shape(
      problem_blocks,
      to_gemm_coord(cluster_shape),
      hw_info,
      args.max_swizzle_size,
      args.raster_order,
      /* truncate_by_problem_size = */true
    );
  }

  // Convert CTA-level work tile info to cluster-level tile coord
  CUTLASS_DEVICE
  cute::Coord<int,int,int,int>
  tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const {
    // TileScheduler works at CTA-level, kernel works at cluster-level
    int m_coord = idx2crd(work_tile_info.M_idx / scheduler_params.cluster_shape_m_,
                          scheduler_params.problem_tiles_m_);
    int n_coord = idx2crd(work_tile_info.N_idx / scheduler_params.cluster_shape_n_,
                          scheduler_params.problem_tiles_n_);
    int l_coord = idx2crd(work_tile_info.L_idx,
                          scheduler_params.problem_tiles_l_);
    return make_coord(m_coord, n_coord, _, l_coord);
  }

  // Returns whether the block assigned this work should compute the epilogue for the corresponding
  // output tile. For the basic tile scheduler, this is always true.
  CUTLASS_HOST_DEVICE
  static bool
  compute_epilogue(WorkTileInfo const&, Params const&) {
    return true;
  }

  CUTLASS_HOST_DEVICE
  static bool
  compute_epilogue(WorkTileInfo const&) {
    return true;
  }

  // Performs the reduction across splits for a given output tile. Since this scheduler does
  // not split output tiles, no reduction is needed.
  template <class FrgTensorC>
  CUTLASS_DEVICE
  static void
  fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {}

  // Performs the reduction across splits for a given output tile. No fixup is required for
  // work units returned by this scheduler.
  template <class FrgTensorC>
  CUTLASS_DEVICE
  void
  fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) const { }

  // Returns whether the current WorkTileInfo passed in should continue to be used. Since
  // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo
  // passed in should not be used after having been processed.
  CUTLASS_DEVICE
  static bool
  continue_current_work(WorkTileInfo&) {
    return false;
  }

  template <class ProblemShape, class TileShape>
  CUTLASS_HOST_DEVICE
  static int
  get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) {
    // All work units returned by this scheduler cover the entire K iteration
    // space of the output tile assigned to the work unit.
    return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape)));
  }

  CUTLASS_HOST_DEVICE
  static uint32_t
  get_work_k_tile_start(WorkTileInfo const&) {
    // All work units returned by this scheduler start from K tile 0
    return 0u;
  }

  CUTLASS_DEVICE
  static bool
  need_separate_reduction(Params const& params) {
    return false;
  }

  CUTLASS_DEVICE
  bool
  is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) {
    return false;
  }

  template <class FrgTensorC>
  CUTLASS_DEVICE
  void
  separate_reduction(
    Params const& params,
    WorkTileInfo const& work_tile_info,
    FrgTensorC& accumulators,
    uint32_t num_barriers,
    uint32_t barrier_idx) {
  }

  // Shares the accumulator set with peers in the global workspace
  template <class FrgTensorC>
  CUTLASS_DEVICE
  static void
  share(
    Params const& params,
    WorkTileInfo const& work_tile_info,
    FrgTensorC& accumulators,
    uint32_t num_barriers,
    uint32_t barrier_idx) {
  }

  CUTLASS_DEVICE
  static bool
  valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) {
    return true;
  }

  CUTLASS_DEVICE
  static bool
  requires_separate_reduction(Params const& params) {
    return false;
  }
public:
  // Sink scheduler params as a member
  Params scheduler_params;
};

} // namespace cutlass::gemm::kernel::detail
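The swizzled index packing in get_linear_idx_from_m_and_n above, and the inverse performed on the get_work_idx_m_and_n side, are easiest to follow with concrete numbers. A host-only arithmetic sketch, using plain integers in place of the FastDivmod helpers and taking the in-cluster offsets as zero (both simplifying assumptions):

  #include <cassert>
  #include <cstdint>

  int main() {
    const std::uint64_t cluster_blk_major = 4;  // stand-in for divmod_cluster_blk_major.divisor
    const int log_swizzle_size = 1;

    // Pack: split the minor cluster index into a swizzle offset plus remaining
    // bits, then interleave with the major index, as the packing side does.
    std::uint64_t cluster_idx_minor = 5, cluster_idx_major = 3;
    std::uint64_t div    = cluster_idx_minor >> log_swizzle_size;              // 2
    std::uint64_t off    = cluster_idx_minor & ((1u << log_swizzle_size) - 1); // 1
    std::uint64_t extra  = div * cluster_blk_major + cluster_idx_major;        // 11
    std::uint64_t cluster_id = (extra << log_swizzle_size) | off;              // 23

    // Unpack: the inverse mapping applied when decoding a linear work index.
    std::uint64_t off2   = cluster_id & ((1u << log_swizzle_size) - 1);        // 1
    std::uint64_t extra2 = cluster_id >> log_swizzle_size;                     // 11
    std::uint64_t major2 = extra2 % cluster_blk_major;                         // 3
    std::uint64_t minor2 = ((extra2 / cluster_blk_major) << log_swizzle_size) | off2; // 5
    assert(major2 == cluster_idx_major && minor2 == cluster_idx_minor);
    return 0;
  }

Round-tripping (minor = 5, major = 3) through pack and unpack recovers the original pair, confirming that the swizzle only permutes cluster ids without dropping any.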
@ -87,6 +87,12 @@ struct PersistentTileSchedulerSm90Params {
  int32_t log_swizzle_size_ = 0;
  RasterOrder raster_order_ = RasterOrder::AlongN;

  uint32_t problem_tiles_m_ = 0;
  uint32_t problem_tiles_n_ = 0;
  uint32_t problem_tiles_l_ = 0;
  uint32_t cluster_shape_m_ = 0;
  uint32_t cluster_shape_n_ = 0;

  // Initializes members. This variant of the method should only be used when
  // problem_shape and tile_shape contain modes of only rank 1.
  void

@ -127,6 +133,12 @@ struct PersistentTileSchedulerSm90Params {
  auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
  auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());

  problem_tiles_m_ = problem_blocks_m / cluster_shape.m();
  problem_tiles_n_ = problem_blocks_n / cluster_shape.n();
  problem_tiles_l_ = problem_blocks.z;
  cluster_shape_m_ = cluster_shape.m();
  cluster_shape_n_ = cluster_shape.n();

  RasterOrder raster_order = get_rasterization_order(
    problem_blocks_m,
    problem_blocks_n,
@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief This defines a "fragment" iterator for visiting the fragments of a warp tile
      that participate in one warp-level mma operation.