// Patch overview (free-text preamble placed before the first "diff --git" header).
//
// Files touched:
//   1. include/cute/arch/mma_sm89.hpp
//      Adds four structs: SM89_16x8x32_F16E4M3E4M3F16_TN, SM89_16x8x32_F16E4M3E5M2F16_TN,
//      SM89_16x8x32_F16E5M2E4M3F16_TN and SM89_16x8x32_F16E5M2E5M2F16_TN.
//      Each declares DRegisters/CRegisters as uint32_t[2], ARegisters as uint32_t[4] and
//      BRegisters as uint32_t[2], and exposes a static fma() that issues the inline PTX
//      "mma.sync.aligned.m16n8k32.row.col.f16.<A>.<B>.f16" with <A>/<B> in {e4m3, e5m2}.
//      The asm is compiled only when CUTE_ARCH_MMA_F16_SM89_ENABLED is defined; otherwise
//      fma() expands CUTE_INVALID_CONTROL_PATH with a message naming the struct.
//      NOTE(review): no hunk in this patch defines CUTE_ARCH_MMA_F16_SM89_ENABLED -- confirm it
//      is defined elsewhere, otherwise every new fma() takes the invalid-control-path branch.
//   2. include/cute/atom/mma_traits_sm89.hpp
//      Re-wraps three existing MMA_Traits specializations (the ':' of the base clause moves
//      to the start of the following line) and adds four specializations whose ValTypeD and
//      ValTypeC are cutlass::half_t and whose ValTypeA/ValTypeB are, in order,
//      (e4m3, e4m3), (e4m3, e5m2), (e5m2, e5m2), (e5m2, e4m3).
//      NOTE(review): the added specializations spell the types with a cutlass:: prefix while
//      the existing ones shown as context use unqualified float_e4m3_t / float_e5m2_t --
//      presumably equivalent aliases; confirm and consider matching the file's style.
//   3. test/unit/cute/ampere/cooperative_gemm.cu
//      Adds four TEST(SM89_CuTe_Ada, CooperativeGemm_<A><B>f16_MMA) cases. Each sets
//      thread_block_size = 128, MaxVecBits = 128, shape_mnk = Shape<_64, _64, _64>{}, builds a
//      TiledMMA and calls test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma).
//
// NOTE(review): in this copy the template argument lists look stripped
// ("struct MMA_Traits : MMA_Traits {", "MMA_Atom,", "Layout>"), and the test locals
// thread_block_size / MaxVecBits / TA / TB / TC are never referenced in the visible text --
// presumably the explicit template arguments of the call were lost too. Restore them from the
// original change before relying on this patch.
// NOTE(review): the hunks' line breaks are collapsed onto a few long lines here, so the hunk
// line counts in the "@@" headers cannot be checked against this text as-is.
diff --git a/include/cute/arch/mma_sm89.hpp b/include/cute/arch/mma_sm89.hpp index 85d7bb64..e810ce42 100644 --- a/include/cute/arch/mma_sm89.hpp +++ b/include/cute/arch/mma_sm89.hpp @@ -177,4 +177,119 @@ struct SM89_16x8x32_F32E5M2E4M3F32_TN } }; +struct SM89_16x8x32_F16E4M3E4M3F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E4M3E5M2F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E5M2E4M3F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters 
= uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E5M2E5M2F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; } // namespace cute diff --git a/include/cute/atom/mma_traits_sm89.hpp b/include/cute/atom/mma_traits_sm89.hpp index 35ad436e..d438fd3c 100644 --- a/include/cute/atom/mma_traits_sm89.hpp +++ b/include/cute/atom/mma_traits_sm89.hpp @@ -67,8 +67,8 @@ struct MMA_Traits { }; template <> -struct MMA_Traits : -MMA_Traits { +struct MMA_Traits + : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e4m3_t; using ValTypeB = float_e5m2_t; 
@@ -76,8 +76,8 @@ MMA_Traits { }; template <> -struct MMA_Traits : -MMA_Traits { +struct MMA_Traits + : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e5m2_t; using ValTypeB = float_e5m2_t; @@ -85,12 +85,49 @@ MMA_Traits { }; template <> -struct MMA_Traits : -MMA_Traits { +struct MMA_Traits + : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e5m2_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; }; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + + } // end namespace cute diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index 48d6e36b..a956500c 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -608,3 +608,75 @@ TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f32_MMA) { test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } + +TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f16_MMA) { + using TA = cutlass::float_e4m3_t; + using TB = cutlass::float_e4m3_t; + using TC = cute::half_t; + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 128; + + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, 
+ Layout> + >{}; + + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); +} + +TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e5m2f16_MMA) { + using TA = cutlass::float_e4m3_t; + using TB = cutlass::float_e5m2_t; + using TC = cute::half_t; + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 128; + + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); +} + +TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e4m3f16_MMA) { + using TA = cutlass::float_e5m2_t; + using TB = cutlass::float_e4m3_t; + using TC = cute::half_t; + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 128; + + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); +} + +TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f16_MMA) { + using TA = cutlass::float_e5m2_t; + using TB = cutlass::float_e5m2_t; + using TC = cute::half_t; + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 128; + + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); +}