Support fp16 accumulator for sm89 fp8 mma (#2378)
* add support for sm89 in cute and the unit tests * support fp16 accumulator for sm89 fp8 mma * format code
This commit is contained in:
@ -177,4 +177,119 @@ struct SM89_16x8x32_F32E5M2E4M3F32_TN
|
||||
}
|
||||
};
|
||||
|
||||
struct SM89_16x8x32_F16E4M3E4M3F16_TN
|
||||
{
|
||||
using DRegisters = uint32_t[2];
|
||||
using ARegisters = uint32_t[4];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1,
|
||||
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
|
||||
asm(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(d0), "=r"(d1)
|
||||
:
|
||||
"r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1)
|
||||
);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM89_16x8x32_F16E4M3E5M2F16_TN
|
||||
{
|
||||
using DRegisters = uint32_t[2];
|
||||
using ARegisters = uint32_t[4];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1,
|
||||
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
|
||||
asm(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(d0), "=r"(d1)
|
||||
:
|
||||
"r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1)
|
||||
);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM89_16x8x32_F16E5M2E4M3F16_TN
|
||||
{
|
||||
using DRegisters = uint32_t[2];
|
||||
using ARegisters = uint32_t[4];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1,
|
||||
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
|
||||
asm(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(d0), "=r"(d1)
|
||||
:
|
||||
"r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1)
|
||||
);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM89_16x8x32_F16E5M2E5M2F16_TN
|
||||
{
|
||||
using DRegisters = uint32_t[2];
|
||||
using ARegisters = uint32_t[4];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1,
|
||||
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED)
|
||||
asm(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(d0), "=r"(d1)
|
||||
:
|
||||
"r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1)
|
||||
);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace cute
|
||||
|
||||
@ -67,8 +67,8 @@ struct MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F32E4M3E5M2F32_TN> :
|
||||
MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
struct MMA_Traits<SM89_16x8x32_F32E4M3E5M2F32_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = float_e4m3_t;
|
||||
using ValTypeB = float_e5m2_t;
|
||||
@ -76,8 +76,8 @@ MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F32E5M2E5M2F32_TN> :
|
||||
MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
struct MMA_Traits<SM89_16x8x32_F32E5M2E5M2F32_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = float_e5m2_t;
|
||||
using ValTypeB = float_e5m2_t;
|
||||
@ -85,12 +85,49 @@ MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F32E5M2E4M3F32_TN> :
|
||||
MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
struct MMA_Traits<SM89_16x8x32_F32E5M2E4M3F32_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = float_e5m2_t;
|
||||
using ValTypeB = float_e4m3_t;
|
||||
using ValTypeC = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F16E4M3E4M3F16_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = cutlass::half_t;
|
||||
using ValTypeA = cutlass::float_e4m3_t;
|
||||
using ValTypeB = cutlass::float_e4m3_t;
|
||||
using ValTypeC = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F16E4M3E5M2F16_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = cutlass::half_t;
|
||||
using ValTypeA = cutlass::float_e4m3_t;
|
||||
using ValTypeB = cutlass::float_e5m2_t;
|
||||
using ValTypeC = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F16E5M2E5M2F16_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = cutlass::half_t;
|
||||
using ValTypeA = cutlass::float_e5m2_t;
|
||||
using ValTypeB = cutlass::float_e5m2_t;
|
||||
using ValTypeC = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F16E5M2E4M3F16_TN>
|
||||
: MMA_Traits<SM89_16x8x32_F32E4M3E4M3F32_TN> {
|
||||
using ValTypeD = cutlass::half_t;
|
||||
using ValTypeA = cutlass::float_e5m2_t;
|
||||
using ValTypeB = cutlass::float_e4m3_t;
|
||||
using ValTypeC = cutlass::half_t;
|
||||
};
|
||||
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -608,3 +608,75 @@ TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f32_MMA) {
|
||||
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
// Cooperative GEMM, col-major layouts: A = e4m3, B = e4m3, C = f16.
// Uses the SM89 F16-accumulator atom arranged in a 2x2x1 layout on a
// 64x64x64 problem with a 128-thread block.
TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f16_MMA) {
  constexpr uint32_t thread_block_size = 128;
  constexpr int MaxVecBits = 128;

  // Element types of the three operands.
  using TA = cutlass::float_e4m3_t;
  using TB = cutlass::float_e4m3_t;
  using TC = cute::half_t;

  // MMA atom under test and how it is replicated.
  using Atom       = MMA_Atom<SM89_16x8x32_F16E4M3E4M3F16_TN>;
  using AtomLayout = Layout<Shape<_2, _2, _1>>;

  auto shape_mnk = Shape<_64, _64, _64>{};
  auto tiled_mma = TiledMMA<Atom, AtomLayout>{};

  test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
}
|
||||
|
||||
// Cooperative GEMM, col-major layouts: A = e4m3, B = e5m2, C = f16.
// Uses the SM89 F16-accumulator atom arranged in a 2x2x1 layout on a
// 64x64x64 problem with a 128-thread block.
TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e5m2f16_MMA) {
  constexpr uint32_t thread_block_size = 128;
  constexpr int MaxVecBits = 128;

  // Element types of the three operands.
  using TA = cutlass::float_e4m3_t;
  using TB = cutlass::float_e5m2_t;
  using TC = cute::half_t;

  // MMA atom under test and how it is replicated.
  using Atom       = MMA_Atom<SM89_16x8x32_F16E4M3E5M2F16_TN>;
  using AtomLayout = Layout<Shape<_2, _2, _1>>;

  auto shape_mnk = Shape<_64, _64, _64>{};
  auto tiled_mma = TiledMMA<Atom, AtomLayout>{};

  test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
}
|
||||
|
||||
// Cooperative GEMM, col-major layouts: A = e5m2, B = e4m3, C = f16.
// Uses the SM89 F16-accumulator atom arranged in a 2x2x1 layout on a
// 64x64x64 problem with a 128-thread block.
TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e4m3f16_MMA) {
  constexpr uint32_t thread_block_size = 128;
  constexpr int MaxVecBits = 128;

  // Element types of the three operands.
  using TA = cutlass::float_e5m2_t;
  using TB = cutlass::float_e4m3_t;
  using TC = cute::half_t;

  // MMA atom under test and how it is replicated.
  using Atom       = MMA_Atom<SM89_16x8x32_F16E5M2E4M3F16_TN>;
  using AtomLayout = Layout<Shape<_2, _2, _1>>;

  auto shape_mnk = Shape<_64, _64, _64>{};
  auto tiled_mma = TiledMMA<Atom, AtomLayout>{};

  test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
}
|
||||
|
||||
// Cooperative GEMM, col-major layouts: A = e5m2, B = e5m2, C = f16.
// Uses the SM89 F16-accumulator atom arranged in a 2x2x1 layout on a
// 64x64x64 problem with a 128-thread block.
TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f16_MMA) {
  constexpr uint32_t thread_block_size = 128;
  constexpr int MaxVecBits = 128;

  // Element types of the three operands.
  using TA = cutlass::float_e5m2_t;
  using TB = cutlass::float_e5m2_t;
  using TC = cute::half_t;

  // MMA atom under test and how it is replicated.
  using Atom       = MMA_Atom<SM89_16x8x32_F16E5M2E5M2F16_TN>;
  using AtomLayout = Layout<Shape<_2, _2, _1>>;

  auto shape_mnk = Shape<_64, _64, _64>{};
  auto tiled_mma = TiledMMA<Atom, AtomLayout>{};

  test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
}
|
||||
|
||||
Reference in New Issue
Block a user