|
|
|
|
@ -59,7 +59,6 @@ namespace UMMA {
|
|
|
|
|
// Common layouts for UMMA Shared Memory //
|
|
|
|
|
//////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
// TODO: Extend for remaining sm100 new layouts
|
|
|
|
|
using cute::GMMA::Layout_MN_INTER_Atom;
|
|
|
|
|
using cute::GMMA::Layout_MN_SW32_Atom;
|
|
|
|
|
using cute::GMMA::Layout_MN_SW64_Atom;
|
|
|
|
|
@ -275,19 +274,6 @@ make_umma_desc(Tensor<TEngine,TLayout> const& tensor)
|
|
|
|
|
} else {
|
|
|
|
|
static_assert(MajorMode != UMMA::Major::MN && MajorMode != UMMA::Major::K, "Unrecognized MajorMode!");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
// DEBUG and SANITY
|
|
|
|
|
assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation
|
|
|
|
|
assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later
|
|
|
|
|
if (thread0()) {
|
|
|
|
|
print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n");
|
|
|
|
|
print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n");
|
|
|
|
|
//print(" desc canonical layout: "); print(canonical_layout); print("\n");
|
|
|
|
|
print(desc);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
return desc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -514,7 +500,7 @@ struct tmem_frg : tmem_frg_base
|
|
|
|
|
"UMMA_2SM only accepts Interleaved or Duplicated");
|
|
|
|
|
static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, "UMMA_2SM M-mode size should be 32 or 64 or 128.");
|
|
|
|
|
|
|
|
|
|
if constexpr (M_MMA == 32) // TODO: Implement Duplicated mode for M_MMA = 32
|
|
|
|
|
if constexpr (M_MMA == 32)
|
|
|
|
|
{
|
|
|
|
|
static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved, "Only TmemAllocMode::Interleaved is supported for UMMA_2SM M_MMA=32");
|
|
|
|
|
// The "1x4" layout atom: (M,N) -> tmem_addr
|
|
|
|
|
@ -1013,7 +999,7 @@ struct MMA_Traits<SM100_MMA_TF32_SS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_SS supports 32bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1077,7 +1063,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_SS supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1142,7 +1128,7 @@ struct MMA_Traits<SM100_MMA_TF32_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_TS supports 32bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1208,7 +1194,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_TS supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1261,6 +1247,155 @@ struct MMA_Traits<SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class a_type, class b_type, class c_type,
|
|
|
|
|
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
|
|
|
|
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
|
|
|
|
|
struct MMA_Traits<SM100_MMA_F16BF16_SS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg>>
|
|
|
|
|
{
|
|
|
|
|
using ValTypeD = c_type;
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_SS_SCALED supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
using FrgTypeC = UMMA::tmem_frg_1sm<c_type>;
|
|
|
|
|
|
|
|
|
|
// Logical shape-K is always 256bits, transform to units of elements
|
|
|
|
|
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
|
|
|
|
|
|
|
|
|
|
static constexpr uint32_t ScalingFactor = ScaleC;
|
|
|
|
|
|
|
|
|
|
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
|
|
|
|
|
using ThrID = Layout<_1>;
|
|
|
|
|
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
|
|
|
|
|
Stride<_0,Stride< _1,Int<M>>>>;
|
|
|
|
|
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
|
|
|
|
|
Stride<_0,Stride< _1,Int<N>>>>;
|
|
|
|
|
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
|
|
|
|
|
Stride<_0,Stride< _1,Int<M>>>>;
|
|
|
|
|
|
|
|
|
|
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
|
|
|
|
|
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
|
|
|
|
|
|
|
|
|
|
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
|
|
|
|
|
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
|
|
|
|
|
|
|
|
|
|
template <class TD, class DLayout,
|
|
|
|
|
class TA, class ALayout,
|
|
|
|
|
class TB, class BLayout,
|
|
|
|
|
class TC, class CLayout>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr friend
|
|
|
|
|
void
|
|
|
|
|
mma_unpack(MMA_Traits const& traits,
|
|
|
|
|
Tensor<TD, DLayout> & D,
|
|
|
|
|
Tensor<TA, ALayout> const& A,
|
|
|
|
|
Tensor<TB, BLayout> const& B,
|
|
|
|
|
Tensor<TC, CLayout> const& C)
|
|
|
|
|
{
|
|
|
|
|
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
|
|
|
|
|
uint64_t desc_a = A[0];
|
|
|
|
|
uint64_t desc_b = B[0];
|
|
|
|
|
uint32_t tmem_c = raw_pointer_cast(D.data());
|
|
|
|
|
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
|
|
|
|
|
|
|
|
|
|
SM100_MMA_F16BF16_SS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <uint32_t NewScaleC>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
|
|
|
MMA_Traits<SM100_MMA_F16BF16_SS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
NewScaleC, a_neg, b_neg>>
|
|
|
|
|
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
|
|
|
|
|
return {accumulate, idesc_};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class a_type, class b_type, class c_type,
|
|
|
|
|
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
|
|
|
|
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, UMMA::Saturate c_sat>
|
|
|
|
|
struct MMA_Traits<SM100_MMA_F16BF16_TS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg, c_sat>>
|
|
|
|
|
{
|
|
|
|
|
using ValTypeD = c_type;
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_TS_SCALED supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
using FrgTypeC = UMMA::tmem_frg_1sm<c_type, int32_t, UMMA::TmemAllocMode::NonInterleaved>;
|
|
|
|
|
|
|
|
|
|
// Logical shape-K is always 256 bits; transform to units of elements
|
|
|
|
|
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
|
|
|
|
|
|
|
|
|
|
static constexpr uint32_t ScalingFactor = ScaleC;
|
|
|
|
|
|
|
|
|
|
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
|
|
|
|
|
using ThrID = Layout<_1>;
|
|
|
|
|
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
|
|
|
|
|
Stride<_0,Stride< _1,Int<M>>>>;
|
|
|
|
|
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
|
|
|
|
|
Stride<_0,Stride< _1,Int<N>>>>;
|
|
|
|
|
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
|
|
|
|
|
Stride<_0,Stride< _1,Int<M>>>>;
|
|
|
|
|
|
|
|
|
|
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
|
|
|
|
|
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
|
|
|
|
|
|
|
|
|
|
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
|
|
|
|
|
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();
|
|
|
|
|
|
|
|
|
|
template <class TD, class DLayout,
|
|
|
|
|
class TA, class ALayout,
|
|
|
|
|
class TB, class BLayout,
|
|
|
|
|
class TC, class CLayout>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr friend
|
|
|
|
|
void
|
|
|
|
|
mma_unpack(MMA_Traits const& traits,
|
|
|
|
|
Tensor<TD, DLayout> & D,
|
|
|
|
|
Tensor<TA, ALayout> const& A,
|
|
|
|
|
Tensor<TB, BLayout> const& B,
|
|
|
|
|
Tensor<TC, CLayout> const& C)
|
|
|
|
|
{
|
|
|
|
|
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
|
|
|
|
|
uint32_t tmem_a = raw_pointer_cast(A.data());
|
|
|
|
|
uint64_t desc_b = B[0];
|
|
|
|
|
uint32_t tmem_c = raw_pointer_cast(D.data());
|
|
|
|
|
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
|
|
|
|
|
|
|
|
|
|
SM100_MMA_F16BF16_TS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <uint32_t NewScaleC>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
|
|
|
MMA_Traits<SM100_MMA_F16BF16_TS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
NewScaleC, a_neg, b_neg, c_sat>>
|
|
|
|
|
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
|
|
|
|
|
return {accumulate, idesc_};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class a_type, class b_type, class c_type,
|
|
|
|
|
int M, int N,
|
|
|
|
|
UMMA::Major a_major, UMMA::Major b_major,
|
|
|
|
|
@ -1273,7 +1408,7 @@ struct MMA_Traits<SM100_MMA_TF32_2x1SM_SS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_2x1SM_SS supports 32bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1338,7 +1473,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1404,7 +1539,7 @@ struct MMA_Traits<SM100_MMA_TF32_2x1SM_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32 supports 32bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 32, "SM100_MMA_TF32_2x1SM_TS supports 32bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1470,7 +1605,7 @@ struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16 supports 16bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1523,6 +1658,152 @@ struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS<a_type, b_type, c_type,
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class a_type, class b_type, class c_type,
|
|
|
|
|
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
|
|
|
|
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
|
|
|
|
|
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg>>
|
|
|
|
|
{
|
|
|
|
|
using ValTypeD = c_type;
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
|
|
|
|
|
|
|
|
|
|
// Size of instructions's K extent is always 256bits, convert to units of element
|
|
|
|
|
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
|
|
|
|
|
constexpr static uint32_t ScalingFactor = ScaleC;
|
|
|
|
|
|
|
|
|
|
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
|
|
|
|
|
using ThrID = Layout<_2>;
|
|
|
|
|
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
|
|
|
|
|
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
|
|
|
|
|
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
|
|
|
|
|
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
|
|
|
|
|
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
|
|
|
|
|
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
|
|
|
|
|
|
|
|
|
|
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
|
|
|
|
|
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
|
|
|
|
|
|
|
|
|
|
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
|
|
|
|
|
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
|
|
|
|
|
|
|
|
|
|
template <class TD, class DLayout,
|
|
|
|
|
class TA, class ALayout,
|
|
|
|
|
class TB, class BLayout,
|
|
|
|
|
class TC, class CLayout>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr friend
|
|
|
|
|
void
|
|
|
|
|
mma_unpack(MMA_Traits const& traits,
|
|
|
|
|
Tensor<TD, DLayout> & D,
|
|
|
|
|
Tensor<TA, ALayout> const& A,
|
|
|
|
|
Tensor<TB, BLayout> const& B,
|
|
|
|
|
Tensor<TC, CLayout> const& C)
|
|
|
|
|
{
|
|
|
|
|
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
|
|
|
|
|
uint64_t desc_a = A[0];
|
|
|
|
|
uint64_t desc_b = B[0];
|
|
|
|
|
uint32_t tmem_c = raw_pointer_cast(D.data());
|
|
|
|
|
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
|
|
|
|
|
|
|
|
|
|
SM100_MMA_F16BF16_2x1SM_SS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <uint32_t NewScaleC>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
|
|
|
MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
NewScaleC, a_neg, b_neg>>
|
|
|
|
|
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
|
|
|
|
|
return {accumulate, idesc_};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class a_type, class b_type, class c_type,
|
|
|
|
|
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
|
|
|
|
uint32_t ScaleC, UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, UMMA::Saturate c_sat>
|
|
|
|
|
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg, c_sat>>
|
|
|
|
|
{
|
|
|
|
|
using ValTypeD = c_type;
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
|
|
|
|
|
|
|
|
|
|
// Size of instructions' K extent is always 256 bits; convert to units of element
|
|
|
|
|
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
|
|
|
|
|
constexpr static uint32_t ScalingFactor = ScaleC;
|
|
|
|
|
|
|
|
|
|
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
|
|
|
|
|
using ThrID = Layout<_2>;
|
|
|
|
|
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
|
|
|
|
|
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
|
|
|
|
|
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
|
|
|
|
|
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
|
|
|
|
|
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
|
|
|
|
|
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
|
|
|
|
|
|
|
|
|
|
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
|
|
|
|
|
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
|
|
|
|
|
|
|
|
|
|
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
|
|
|
|
|
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();
|
|
|
|
|
|
|
|
|
|
template <class TD, class DLayout,
|
|
|
|
|
class TA, class ALayout,
|
|
|
|
|
class TB, class BLayout,
|
|
|
|
|
class TC, class CLayout>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr friend
|
|
|
|
|
void
|
|
|
|
|
mma_unpack(MMA_Traits const& traits,
|
|
|
|
|
Tensor<TD, DLayout> & D,
|
|
|
|
|
Tensor<TA, ALayout> const& A,
|
|
|
|
|
Tensor<TB, BLayout> const& B,
|
|
|
|
|
Tensor<TC, CLayout> const& C)
|
|
|
|
|
{
|
|
|
|
|
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
static_assert(is_tmem<TA>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
|
|
|
|
|
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
|
|
|
|
|
|
|
|
|
|
uint64_t tmem_a = raw_pointer_cast(A.data());
|
|
|
|
|
uint64_t desc_b = B[0];
|
|
|
|
|
uint32_t tmem_c = raw_pointer_cast(D.data());
|
|
|
|
|
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
|
|
|
|
|
|
|
|
|
|
SM100_MMA_F16BF16_2x1SM_TS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
ScaleC, a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <uint32_t NewScaleC>
|
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
|
|
|
MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_SCALED<a_type, b_type, c_type,
|
|
|
|
|
M, N, a_major, b_major,
|
|
|
|
|
NewScaleC, a_neg, b_neg, c_sat>>
|
|
|
|
|
with(UMMA::ScaleOut accumulate, cute::integral_constant<uint32_t, NewScaleC> scaleC) const {
|
|
|
|
|
return {accumulate, idesc_};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class a_type, class b_type, class c_type,
|
|
|
|
|
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
|
|
|
|
UMMA::Saturate c_sat>
|
|
|
|
|
@ -1534,7 +1815,7 @@ struct MMA_Traits<SM100_MMA_S8_SS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_SS supports 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1599,7 +1880,7 @@ struct MMA_Traits<SM100_MMA_S8_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_TS supports 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1663,7 +1944,7 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_SS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_2x1SM_SS supports 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1728,7 +2009,7 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8 supports 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 8, "SM100_MMA_S8_2x1SM_TS supports 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -1795,16 +2076,18 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
|
|
|
|
|
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types");
|
|
|
|
|
static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
|
|
|
|
|
static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) ||
|
|
|
|
|
(M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),
|
|
|
|
|
"SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\
|
|
|
|
|
or a multiple of 16 between 16 and 256 for M=128.");
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
using FrgTypeC = UMMA::tmem_frg_1sm<c_type>;
|
|
|
|
|
|
|
|
|
|
static_assert(sizeof_bits_v<ValTypeA> <= sizeof_bits_v<uint8_t> &&
|
|
|
|
|
sizeof_bits_v<ValTypeB> <= sizeof_bits_v<uint8_t>);
|
|
|
|
|
static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
|
|
|
|
|
static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 8 between 8 and 256.");
|
|
|
|
|
|
|
|
|
|
// Logical shape-K is always 256bits, transform to units of elements
|
|
|
|
|
constexpr static int K = 32;
|
|
|
|
|
@ -1863,7 +2146,7 @@ struct MMA_Traits<SM100_MMA_MXF8F6F4_SS<a_type, b_type, c_type, sf_type,
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
using ValTypeSFA = sf_type;
|
|
|
|
|
using ValTypeSFB = sf_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_MXF8F6F4 supports types with leq 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_MXF8F6F4_SS supports types with leq 8bit types");
|
|
|
|
|
|
|
|
|
|
// Logical shape-K is always 256bits, transform to units of elements
|
|
|
|
|
constexpr static int K = 32;
|
|
|
|
|
@ -1953,7 +2236,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_TS supports types with leq 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -2023,8 +2306,10 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
|
|
|
|
|
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types");
|
|
|
|
|
// Valid M extents for the 2-CTA (2x1SM) instruction; the message states the same values the condition checks.
static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
|
|
|
|
|
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256.");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
|
|
|
|
|
@ -2034,9 +2319,6 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
|
|
|
|
|
// Size of the instruction's K extent is always 256 bits; convert to units of elements
|
|
|
|
|
constexpr static int K = 32;
|
|
|
|
|
|
|
|
|
|
static_assert(M == 128 || M == 256, "MMA_F8F6F4 M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
|
|
|
|
|
static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA_F8F6F4 N-mode size should be a multiple of 16 between 16 and 256.");
|
|
|
|
|
|
|
|
|
|
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
|
|
|
|
|
using ThrID = Layout<_2>;
|
|
|
|
|
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
|
|
|
|
|
@ -2090,7 +2372,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_TS<a_type, b_type, c_type,
|
|
|
|
|
using ValTypeA = a_type;
|
|
|
|
|
using ValTypeB = b_type;
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_TS supports types with leq 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -2159,7 +2441,7 @@ struct MMA_Traits<SM100_MMA_MXF8F6F4_2x1SM_SS<a_type, b_type, c_type, sf_type,
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
using ValTypeSFA = sf_type;
|
|
|
|
|
using ValTypeSFB = sf_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS supports types with leq 8bit types");
|
|
|
|
|
|
|
|
|
|
using FrgTypeA = UMMA::smem_desc<a_major>;
|
|
|
|
|
using FrgTypeB = UMMA::smem_desc<b_major>;
|
|
|
|
|
@ -2252,7 +2534,7 @@ struct MMA_Traits<SM100_MMA_MXF4_SS<a_type, b_type, c_type, sf_type,
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
using ValTypeSFA = sf_type;
|
|
|
|
|
using ValTypeSFB = sf_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4 supports 4bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4_SS supports 4bit types");
|
|
|
|
|
|
|
|
|
|
// Logical shape-K is always 256bits, transform to units of elements
|
|
|
|
|
constexpr static int K = 64;
|
|
|
|
|
@ -2345,7 +2627,7 @@ struct MMA_Traits<SM100_MMA_MXF4_2x1SM_SS<a_type, b_type, c_type, sf_type,
|
|
|
|
|
using ValTypeC = c_type;
|
|
|
|
|
using ValTypeSFA = sf_type;
|
|
|
|
|
using ValTypeSFB = sf_type;
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4 supports 4bit types");
|
|
|
|
|
static_assert(cute::sizeof_bits_v<a_type> == 4 && cute::sizeof_bits_v<b_type> == 4, "SM100_MMA_MXF4_2x1SM_SS supports 4bit types");
|
|
|
|
|
|
|
|
|
|
// Logical shape-K is always 256bits, transform to units of elements
|
|
|
|
|
constexpr static int K = 64;
|
|
|
|
|
|