v3.8.0 update (#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Yujia Zhai
2025-02-06 18:33:40 -08:00
committed by GitHub
parent affd1b693d
commit 833f6990e0
168 changed files with 24945 additions and 3436 deletions

View File

@@ -521,15 +521,6 @@ make_cotiled_copy(Copy_Atom<Args...> const& copy_atom,
// Check validity
CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)),
"The memory pointed to by AtomTVLayout does not exist in the DataLayout.");
#if 0
if (thread0()) {
print("data_layout : "); print(data_layout); print("\n");
print("atom_tv_layout : "); print(atom_tv_layout); print("\n");
print("layout_tv_data : "); print(layout_tv_data); print("\n");
}
#endif
//
// Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them
//
@@ -552,15 +543,6 @@ make_cotiled_copy(Copy_Atom<Args...> const& copy_atom,
// (tid,vid) -> tile_coord
auto layout_tv = composition(left_inverse(tile2data), layout_tv_data);
#if 0
if (thread0()) {
print("tiler : "); print(tiler); print("\n");
print("tile2data : "); print(tile2data); print("\n");
print("layout_tv : "); print(layout_tv); print("\n");
}
#endif
return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
}
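The pivotal step in make_cotiled_copy is composing the left inverse of the tile layout with the thread-value layout, which re-expresses (tid,vid) coordinates in tile coordinates. A minimal host-side sketch of that pattern, using hypothetical 2x4 tile and 4-thread/2-value layouts (not the library's own test code):

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  // Hypothetical tile2data: tile_coord (2,4) -> data offset, row-major
  auto tile2data      = make_layout(make_shape(Int<2>{}, Int<4>{}),
                                    make_stride(Int<4>{}, Int<1>{}));
  // Hypothetical layout_tv_data: (tid,vid) for 4 threads x 2 values -> data offset
  auto layout_tv_data = make_layout(make_shape(Int<4>{}, Int<2>{}),
                                    make_stride(Int<1>{}, Int<4>{}));
  // Same composition as above: (tid,vid) -> tile_coord
  auto layout_tv = composition(left_inverse(tile2data), layout_tv_data);
  print(layout_tv); print("\n");
  return 0;
}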

View File

@@ -394,15 +394,6 @@ make_tmem_warp_partitioner(Tensor<TEngine,TLayout> const& tmem)
// wid -> tmem_coord
auto layout_t_tmem = composition(inv_tmem_layout, atom_t_layout);
#if 0
if (thread0()) {
print("input : "); print(tmem.data()); print(" o "); print(tmem_layout); print("\n");
print("atom_t_layout : "); print(atom_t_layout); print("\n");
print("layout_tv_tmem : "); print(layout_tv_tmem); print("\n");
}
#endif
//
// Tiler -- Find the active elements in the TMEM tensor and generate a tiler to extract them
//
@@ -425,15 +416,6 @@ make_tmem_warp_partitioner(Tensor<TEngine,TLayout> const& tmem)
// wid -> tile_coord
auto layout_tv = composition(left_inverse(tile2tmem), layout_t_tmem);
#if 0
if (thread0()) {
print("tiler : "); print(tiler); print("\n");
print("tile2tmem : "); print(tile2tmem); print("\n");
print("layout_tv : "); print(layout_tv); print("\n");
}
#endif
return make_tiler_impl(layout_tv, tiler);
}

View File

@@ -1374,19 +1374,6 @@ tma_partition(Copy_Atom<Args...> const& copy_atom,
// Transform tile mode and coalesce
Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape<Shape<_1,_1>>{}); // ((TMA,TMA_Iter), Rest...)
Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape<Shape<_1,_1>>{}); // ((TMA,TMA_Iter), Rest...)
#if 0
if (thread0()) {
print("cta_coord : "); print(cta_coord); print("\n");
print("cta_layout : "); print(cta_layout); print("\n");
print("gtensor : "); print(gtensor); print("\n");
print("stensor : "); print(stensor); print("\n");
print("layout_V : "); print(layout_V); print("\n");
print("gtensor_v : "); print(gtensor_v); print("\n");
print("stensor_v : "); print(stensor_v); print("\n");
}
#endif
// Offset inside the TMA-mode for the multicast
auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout));
auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{}));
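The multicast offset arithmetic above splits the TMA mode evenly across the CTAs participating in the multicast; each CTA starts its slice at rank * slice_size. A worked example with hypothetical sizes (4 CTAs, 128-element TMA mode):

int cta_rank         = 3;                // cta_layout(cta_coord) for this CTA
int slice            = 128 / 4;          // size(tma_layout_v) / cosize(cta_layout) = 32
int multicast_offset = cta_rank * slice; // this CTA's slice starts 96 elements in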

View File

@@ -157,7 +157,6 @@ struct MMA_Atom<MMA_Traits<MMAOperation, Args...>>
|| (sizeof_bits_v<typename remove_cvref_t<ATensor>::value_type> == 8 &&
(sizeof_bits_v<ValTypeA> == 8 || sizeof_bits_v<ValTypeA> == 6 || sizeof_bits_v<ValTypeA> == 4))
, "Expecting ValTypeA type");
return make_tensor<FrgTypeA>(static_cast<ATensor&&>(atensor));
} else {

View File

@@ -59,7 +59,6 @@ namespace UMMA {
// Common layouts for UMMA Shared Memory //
//////////////////////////////////////////////////
// TODO: Extend for remaining sm100 new layouts
using cute::GMMA::Layout_MN_INTER_Atom;
using cute::GMMA::Layout_MN_SW32_Atom;
using cute::GMMA::Layout_MN_SW64_Atom;
@@ -275,19 +274,6 @@ make_umma_desc(Tensor<TEngine,TLayout> const& tensor)
} else {
static_assert(MajorMode != UMMA::Major::MN && MajorMode != UMMA::Major::K, "Unrecognized MajorMode!");
}
#if 0
// DEBUG and SANITY
assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation
assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later
if (thread0()) {
print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n");
print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n");
//print(" desc canonical layout: "); print(canonical_layout); print("\n");
print(desc);
}
#endif
return desc;
}
@@ -514,7 +500,7 @@ struct tmem_frg : tmem_frg_base
"UMMA_2SM only accepts Interleaved or Duplicated");
static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, "UMMA_2SM M-mode size should be 32 or 64 or 128.");
if constexpr (M_MMA == 32) // TODO: Implement Duplicated mode for M_MMA = 32
if constexpr (M_MMA == 32)
{
static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved, "Only TmemAllocMode::Interleaved is supported for UMMA_2SM M_MMA=32");
// The "1x4" layout atom: (M,N) -> tmem_addr
@@ -1013,7 +999,7 @@ struct MMA_Traits<SM100_MMA_TF32_SS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_SS supports 32bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1077,7 +1063,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_SS supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1142,7 +1128,7 @@ struct MMA_Traits<SM100_MMA_TF32_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_TS supports 32bit types");
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1208,7 +1194,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_TS supports 16bit types");
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1261,6 +1247,155 @@ struct MMA_Traits<SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_SS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_SS_SCALED supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_1sm<c_type>;
// Logical shape-K is always 256 bits; transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
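// e.g. for the 16-bit f16/bf16 operands accepted above: K = 256 / 16 = 16 elements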
static constexpr uint32_t ScalingFactor = ScaleC;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_SS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
template <uint32_t NewScaleC>
CUTE_HOST_DEVICE constexpr
MMA_Traits<SM100_MMA_F16BF16_SS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
NewScaleC, a_neg, b_neg>>
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
return {accumulate, idesc_};
}
};
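The with() overload defined above rebinds the accumulate flag and the compile-time scale while reusing the existing instruction descriptor. A hedged usage sketch; the traits variable is hypothetical and assumed to be an instance of the specialization above:

// Hypothetical: request overwrite-accumulate semantics with a new scale of 2
auto scaled_traits = traits.with(UMMA::ScaleOut::Zero,
                                 cute::integral_constant<uint32_t, 2>{});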
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_TS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg, c_sat>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_TS_SCALED supports 16bit types");
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_1sm<c_type, int32_t, UMMA::TmemAllocMode::NonInterleaved>;
// Logical shape-K is always 256 bits; transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
static constexpr uint32_t ScalingFactor = ScaleC;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint32_t tmem_a = raw_pointer_cast(A.data());
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_TS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
template <uint32_t NewScaleC>
CUTE_HOST_DEVICE constexpr
MMA_Traits<SM100_MMA_F16BF16_TS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
NewScaleC, a_neg, b_neg, c_sat>>
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
return {accumulate, idesc_};
}
};
template <class a_type, class b_type, class c_type,
int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
@@ -1273,7 +1408,7 @@ struct MMA_Traits<SM100_MMA_TF32_2x1SM_SS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_2x1SM_SS supports 32bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1338,7 +1473,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1404,7 +1539,7 @@ struct MMA_Traits<SM100_MMA_TF32_2x1SM_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_2x1SM_TS supports 32bit types");
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1470,7 +1605,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types");
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1523,6 +1658,152 @@ struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS<a_type, b_type, c_type,
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
// Size of instruction's K extent is always 256 bits; convert to units of elements
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
constexpr static uint32_t ScalingFactor = ScaleC;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
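// e.g. with hypothetical M = 128: ThrID _2 indexes the two peer CTAs, each CTA
// addresses its own M/2 = 64 rows, and the Int<M/2> stride on mode 0 places the
// second CTA's half immediately after the first.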
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_2x1SM_SS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
template <uint32_t NewScaleC>
CUTE_HOST_DEVICE constexpr
MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
NewScaleC, a_neg, b_neg>>
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
return {accumulate, idesc_};
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg, c_sat>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types");
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
// Size of instruction's K extent is always 256 bits; convert to units of elements
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
constexpr static uint32_t ScalingFactor = ScaleC;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_tmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint32_t tmem_a = raw_pointer_cast(A.data());
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_2x1SM_TS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
ScaleC, a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
template <uint32_t NewScaleC>
CUTE_HOST_DEVICE constexpr
MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_SCALED<a_type, b_type, c_type,
M, N, a_major, b_major,
NewScaleC, a_neg, b_neg, c_sat>>
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
return {accumulate, idesc_};
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::Saturate c_sat>
@@ -1534,7 +1815,7 @@ struct MMA_Traits<SM100_MMA_S8_SS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_SS supports 8bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1599,7 +1880,7 @@ struct MMA_Traits<SM100_MMA_S8_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_TS supports 8bit types");
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1663,7 +1944,7 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_SS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_2x1SM_SS supports 8bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1728,7 +2009,7 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_2x1SM_TS supports 8bit types");
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -1795,16 +2076,18 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types");
static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) ||
(M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),
"SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\
or a multiple of 16 between 16 and 256 for M=128.");
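// Worked checks against the asserts above (hypothetical shapes):
//   M = 64,  N = 40 -> OK   (40 % 8 == 0 and 8 <= 40 <= 256)
//   M = 128, N = 40 -> fail (40 % 16 != 0)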
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_1sm<c_type>;
static_assert(sizeof_bits_v<ValTypeA> <= sizeof_bits_v<uint8_t> &&
sizeof_bits_v<ValTypeB> <= sizeof_bits_v<uint8_t>);
static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 8 between 8 and 256.");
// Logical shape-K is always 256 bits; transform to units of elements
constexpr static int K = 32;
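// f8/f6/f4 operands travel in 8-bit containers here, so K = 256 / 8 = 32 elements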
@@ -1863,7 +2146,7 @@ struct MMA_Traits<SM100_MMA_MXF8F6F4_SS<a_type, b_type, c_type, sf_type,
using ValTypeC = c_type;
using ValTypeSFA = sf_type;
using ValTypeSFB = sf_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_MXF8F6F4 supports types with leq 8bit types");
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_MXF8F6F4_SS supports types with leq 8bit types");
// Logical shape-K is always 256 bits; transform to units of elements
constexpr static int K = 32;
@@ -1953,7 +2236,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_TS supports types with leq 8bit types");
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -2023,8 +2306,10 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types");
static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256.");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
@@ -2034,9 +2319,6 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
// Size of instruction's K extent is always 256 bits; convert to units of elements
constexpr static int K = 32;
static_assert(M == 128 || M == 256, "MMA_F8F6F4 M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA_F8F6F4 N-mode size should be a multiple of 16 between 16 and 256.");
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
@@ -2090,7 +2372,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_TS<a_type, b_type, c_type,
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_TS supports types with leq 8bit types");
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -2159,7 +2441,7 @@ struct MMA_Traits<SM100_MMA_MXF8F6F4_2x1SM_SS<a_type, b_type, c_type, sf_type,
using ValTypeC = c_type;
using ValTypeSFA = sf_type;
using ValTypeSFB = sf_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS supports types with leq 8bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
@@ -2252,7 +2534,7 @@ struct MMA_Traits<SM100_MMA_MXF4_SS<a_type, b_type, c_type, sf_type,
using ValTypeC = c_type;
using ValTypeSFA = sf_type;
using ValTypeSFB = sf_type;
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4 supports 4bit types");
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4_SS supports 4bit types");
// Logical shape-K is always 256 bits; transform to units of elements
constexpr static int K = 64;
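// 4-bit operands: K = 256 / 4 = 64 elements per instruction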
@@ -2345,7 +2627,7 @@ struct MMA_Traits<SM100_MMA_MXF4_2x1SM_SS<a_type, b_type, c_type, sf_type,
using ValTypeC = c_type;
using ValTypeSFA = sf_type;
using ValTypeSFB = sf_type;
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4 supports 4bit types");
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4_2x1SM_SS supports 4bit types");
// Logical shape-K is always 256 bits; transform to units of elements
constexpr static int K = 64;

View File

@@ -295,19 +295,6 @@ make_gmma_desc(Tensor<TEngine,TLayout> const& tensor)
} else {
static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!");
}
#if 0
// DEBUG and SANITY
assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation
assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later
if (thread0()) {
print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n");
print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n");
//print(" desc canonical layout: "); print(canonical_layout); print("\n");
print(desc);
}
#endif
return desc;
}