support fp16 accumulator for sm89 fp8 mma (#2378)

* add support for sm89 in cute and the unit tests

* support fp16 accumulator for sm89 fp8 mma

* format code
This commit is contained in:
kf-zhang
2025-07-31 10:12:08 +08:00
committed by GitHub
parent a39cf6b511
commit 26b7450023
3 changed files with 230 additions and 6 deletions

View File

@ -177,4 +177,119 @@ struct SM89_16x8x32_F32E5M2E4M3F32_TN
}
};
// MMA 16x8x32 TN with f16 accumulation: D(f16) = A(e4m3) * B(e4m3) + C(f16).
//
// Thin wrapper around the PTX instruction
//   mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16
// All operands are exchanged as raw 32-bit registers: 2 for D, 4 for A,
// 2 for B and 2 for C. How elements are packed inside each register is
// defined by the PTX instruction, not by this wrapper.
struct SM89_16x8x32_F16E4M3E4M3F16_TN
{
  using DRegisters = uint32_t[2];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[2];

  // Issues a single mma.sync.
  //   d0, d1  : output registers (written)
  //   a0..a3  : A operand registers (read)
  //   b0, b1  : B operand registers (read)
  //   c0, c1  : accumulator input registers (read)
  // When CUTE_ARCH_MMA_F16_SM89_ENABLED is not defined the call is routed to
  // CUTE_INVALID_CONTROL_PATH instead of emitting the instruction.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1,
      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1)
  {
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
    // `volatile` keeps the compiler from deleting or relocating the statement:
    // mma.sync is a warp-synchronous instruction, so it must stay exactly
    // where the source places it.
    asm volatile(
      "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(d0), "=r"(d1)
      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
         "r"(b0),  "r"(b1),
         "r"(c0),  "r"(c1)
    );
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
#endif
  }
};
// MMA 16x8x32 TN with f16 accumulation: D(f16) = A(e4m3) * B(e5m2) + C(f16).
//
// Thin wrapper around the PTX instruction
//   mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16
// All operands are exchanged as raw 32-bit registers: 2 for D, 4 for A,
// 2 for B and 2 for C. How elements are packed inside each register is
// defined by the PTX instruction, not by this wrapper.
struct SM89_16x8x32_F16E4M3E5M2F16_TN
{
  using DRegisters = uint32_t[2];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[2];

  // Issues a single mma.sync.
  //   d0, d1  : output registers (written)
  //   a0..a3  : A operand registers (read)
  //   b0, b1  : B operand registers (read)
  //   c0, c1  : accumulator input registers (read)
  // When CUTE_ARCH_MMA_F16_SM89_ENABLED is not defined the call is routed to
  // CUTE_INVALID_CONTROL_PATH instead of emitting the instruction.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1,
      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1)
  {
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
    // `volatile` keeps the compiler from deleting or relocating the statement:
    // mma.sync is a warp-synchronous instruction, so it must stay exactly
    // where the source places it.
    asm volatile(
      "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(d0), "=r"(d1)
      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
         "r"(b0),  "r"(b1),
         "r"(c0),  "r"(c1)
    );
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
#endif
  }
};
// MMA 16x8x32 TN with f16 accumulation: D(f16) = A(e5m2) * B(e4m3) + C(f16).
//
// Thin wrapper around the PTX instruction
//   mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16
// All operands are exchanged as raw 32-bit registers: 2 for D, 4 for A,
// 2 for B and 2 for C. How elements are packed inside each register is
// defined by the PTX instruction, not by this wrapper.
struct SM89_16x8x32_F16E5M2E4M3F16_TN
{
  using DRegisters = uint32_t[2];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[2];

  // Issues a single mma.sync.
  //   d0, d1  : output registers (written)
  //   a0..a3  : A operand registers (read)
  //   b0, b1  : B operand registers (read)
  //   c0, c1  : accumulator input registers (read)
  // When CUTE_ARCH_MMA_F16_SM89_ENABLED is not defined the call is routed to
  // CUTE_INVALID_CONTROL_PATH instead of emitting the instruction.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1,
      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1)
  {
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
    // `volatile` keeps the compiler from deleting or relocating the statement:
    // mma.sync is a warp-synchronous instruction, so it must stay exactly
    // where the source places it.
    asm volatile(
      "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(d0), "=r"(d1)
      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
         "r"(b0),  "r"(b1),
         "r"(c0),  "r"(c1)
    );
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
#endif
  }
};
// MMA 16x8x32 TN with f16 accumulation: D(f16) = A(e5m2) * B(e5m2) + C(f16).
//
// Thin wrapper around the PTX instruction
//   mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16
// All operands are exchanged as raw 32-bit registers: 2 for D, 4 for A,
// 2 for B and 2 for C. How elements are packed inside each register is
// defined by the PTX instruction, not by this wrapper.
struct SM89_16x8x32_F16E5M2E5M2F16_TN
{
  using DRegisters = uint32_t[2];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[2];

  // Issues a single mma.sync.
  //   d0, d1  : output registers (written)
  //   a0..a3  : A operand registers (read)
  //   b0, b1  : B operand registers (read)
  //   c0, c1  : accumulator input registers (read)
  // When CUTE_ARCH_MMA_F16_SM89_ENABLED is not defined the call is routed to
  // CUTE_INVALID_CONTROL_PATH instead of emitting the instruction.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1,
      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1)
  {
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
    // `volatile` keeps the compiler from deleting or relocating the statement:
    // mma.sync is a warp-synchronous instruction, so it must stay exactly
    // where the source places it.
    asm volatile(
      "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(d0), "=r"(d1)
      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
         "r"(b0),  "r"(b1),
         "r"(c0),  "r"(c1)
    );
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
#endif
  }
};
} // namespace cute

View File

@ -67,8 +67,8 @@ struct MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
};
template <>
struct MMA_Traits<SM89_16x8x32_F32E4M3E5M2F32_TN> :
MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
struct MMA_Traits<SM89_16x8x32_F32E4M3E5M2F32_TN>
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
using ValTypeD = float;
using ValTypeA = float_e4m3_t;
using ValTypeB = float_e5m2_t;
@ -76,8 +76,8 @@ MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
};
template <>
struct MMA_Traits<SM89_16x8x32_F32E5M2E5M2F32_TN> :
MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
struct MMA_Traits<SM89_16x8x32_F32E5M2E5M2F32_TN>
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
using ValTypeD = float;
using ValTypeA = float_e5m2_t;
using ValTypeB = float_e5m2_t;
@ -85,12 +85,49 @@ MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
};
template <>
struct MMA_Traits<SM89_16x8x32_F32E5M2E4M3F32_TN> :
MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
struct MMA_Traits<SM89_16x8x32_F32E5M2E4M3F32_TN>
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
using ValTypeD = float;
using ValTypeA = float_e5m2_t;
using ValTypeB = float_e4m3_t;
using ValTypeC = float;
};
// Traits for the SM89 FP8 MMA atom with A = e4m3, B = e4m3 and
// half-precision C/D. Only the four value types are overridden here; every
// other member is inherited from MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>.
template <>
struct MMA_Traits<SM89_16x8x32_F16E4M3E4M3F16_TN>
     : MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>
{
  // Multiplicand element types.
  using ValTypeA = cutlass::float_e4m3_t;
  using ValTypeB = cutlass::float_e4m3_t;

  // Accumulator input and result element types.
  using ValTypeC = cutlass::half_t;
  using ValTypeD = cutlass::half_t;
};
// Traits for the SM89 FP8 MMA atom with A = e4m3, B = e5m2 and
// half-precision C/D. Only the four value types are overridden here; every
// other member is inherited from MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>.
template <>
struct MMA_Traits<SM89_16x8x32_F16E4M3E5M2F16_TN>
     : MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>
{
  // Multiplicand element types.
  using ValTypeA = cutlass::float_e4m3_t;
  using ValTypeB = cutlass::float_e5m2_t;

  // Accumulator input and result element types.
  using ValTypeC = cutlass::half_t;
  using ValTypeD = cutlass::half_t;
};
// Traits for the SM89 FP8 MMA atom with A = e5m2, B = e5m2 and
// half-precision C/D. Only the four value types are overridden here; every
// other member is inherited from MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>.
template <>
struct MMA_Traits<SM89_16x8x32_F16E5M2E5M2F16_TN>
     : MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>
{
  // Multiplicand element types.
  using ValTypeA = cutlass::float_e5m2_t;
  using ValTypeB = cutlass::float_e5m2_t;

  // Accumulator input and result element types.
  using ValTypeC = cutlass::half_t;
  using ValTypeD = cutlass::half_t;
};
// Traits for the SM89 FP8 MMA atom with A = e5m2, B = e4m3 and
// half-precision C/D. Only the four value types are overridden here; every
// other member is inherited from MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>.
template <>
struct MMA_Traits<SM89_16x8x32_F16E5M2E4M3F16_TN>
     : MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN>
{
  // Multiplicand element types.
  using ValTypeA = cutlass::float_e5m2_t;
  using ValTypeB = cutlass::float_e4m3_t;

  // Accumulator input and result element types.
  using ValTypeC = cutlass::half_t;
  using ValTypeD = cutlass::half_t;
};
} // end namespace cute

View File

@ -608,3 +608,75 @@ TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f32_MMA) {
test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
}
// Cooperative GEMM over a 64x64x64 problem using the SM89 FP8 atom with
// e4m3 x e4m3 inputs and a half-precision accumulator, tiled as a 2x2x1
// arrangement of atoms.
TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f16_MMA) {
  // Launch configuration handed to the shared test driver.
  constexpr uint32_t kThreadsPerBlock = 128;
  constexpr int      kMaxVecBits      = 128;

  // Element types for A, B and the accumulator/output.
  using ElemA = cutlass::float_e4m3_t;
  using ElemB = cutlass::float_e4m3_t;
  using ElemC = cute::half_t;

  using ProblemShape = Shape<_64, _64, _64>;
  using AtomLayout   = Layout<Shape<_2, _2, _1>>;
  using Mma          = TiledMMA<MMA_Atom<SM89_16x8x32_F16E4M3E4M3F16_TN>, AtomLayout>;

  ProblemShape shape_mnk{};
  Mma          tiled_mma{};

  test_cooperative_gemm_col_major_layout<kThreadsPerBlock, kMaxVecBits, ElemA, ElemB, ElemC>(shape_mnk, tiled_mma);
}
// Cooperative GEMM over a 64x64x64 problem using the SM89 FP8 atom with
// e4m3 x e5m2 inputs and a half-precision accumulator, tiled as a 2x2x1
// arrangement of atoms.
TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e5m2f16_MMA) {
  // Launch configuration handed to the shared test driver.
  constexpr uint32_t kThreadsPerBlock = 128;
  constexpr int      kMaxVecBits      = 128;

  // Element types for A, B and the accumulator/output.
  using ElemA = cutlass::float_e4m3_t;
  using ElemB = cutlass::float_e5m2_t;
  using ElemC = cute::half_t;

  using ProblemShape = Shape<_64, _64, _64>;
  using AtomLayout   = Layout<Shape<_2, _2, _1>>;
  using Mma          = TiledMMA<MMA_Atom<SM89_16x8x32_F16E4M3E5M2F16_TN>, AtomLayout>;

  ProblemShape shape_mnk{};
  Mma          tiled_mma{};

  test_cooperative_gemm_col_major_layout<kThreadsPerBlock, kMaxVecBits, ElemA, ElemB, ElemC>(shape_mnk, tiled_mma);
}
// Cooperative GEMM over a 64x64x64 problem using the SM89 FP8 atom with
// e5m2 x e4m3 inputs and a half-precision accumulator, tiled as a 2x2x1
// arrangement of atoms.
TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e4m3f16_MMA) {
  // Launch configuration handed to the shared test driver.
  constexpr uint32_t kThreadsPerBlock = 128;
  constexpr int      kMaxVecBits      = 128;

  // Element types for A, B and the accumulator/output.
  using ElemA = cutlass::float_e5m2_t;
  using ElemB = cutlass::float_e4m3_t;
  using ElemC = cute::half_t;

  using ProblemShape = Shape<_64, _64, _64>;
  using AtomLayout   = Layout<Shape<_2, _2, _1>>;
  using Mma          = TiledMMA<MMA_Atom<SM89_16x8x32_F16E5M2E4M3F16_TN>, AtomLayout>;

  ProblemShape shape_mnk{};
  Mma          tiled_mma{};

  test_cooperative_gemm_col_major_layout<kThreadsPerBlock, kMaxVecBits, ElemA, ElemB, ElemC>(shape_mnk, tiled_mma);
}
// Cooperative GEMM over a 64x64x64 problem using the SM89 FP8 atom with
// e5m2 x e5m2 inputs and a half-precision accumulator, tiled as a 2x2x1
// arrangement of atoms.
TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f16_MMA) {
  // Launch configuration handed to the shared test driver.
  constexpr uint32_t kThreadsPerBlock = 128;
  constexpr int      kMaxVecBits      = 128;

  // Element types for A, B and the accumulator/output.
  using ElemA = cutlass::float_e5m2_t;
  using ElemB = cutlass::float_e5m2_t;
  using ElemC = cute::half_t;

  using ProblemShape = Shape<_64, _64, _64>;
  using AtomLayout   = Layout<Shape<_2, _2, _1>>;
  using Mma          = TiledMMA<MMA_Atom<SM89_16x8x32_F16E5M2E5M2F16_TN>, AtomLayout>;

  ProblemShape shape_mnk{};
  Mma          tiled_mma{};

  test_cooperative_gemm_col_major_layout<kThreadsPerBlock, kMaxVecBits, ElemA, ElemB, ElemC>(shape_mnk, tiled_mma);
}