add support for sm89 in cute and the unit tests (#2177)
* add support for sm89 in cute and the unit tests * rebase v3.9 and format code * minor fix --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -536,3 +536,75 @@ TEST(SM80_CuTe_Ampere, CooperativeGemmLDSMx2) {
|
||||
SM75_U32x4_LDSM_N{},
|
||||
SM75_U32x2_LDSM_N{});
|
||||
}
|
||||
|
||||
TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f32_MMA) {
|
||||
using TA = cutlass::float_e4m3_t;
|
||||
using TB = cutlass::float_e4m3_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 128;
|
||||
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e5m2f32_MMA) {
|
||||
using TA = cutlass::float_e4m3_t;
|
||||
using TB = cutlass::float_e5m2_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 128;
|
||||
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM89_16x8x32_F32E4M3E5M2F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e4m3f32_MMA) {
|
||||
using TA = cutlass::float_e5m2_t;
|
||||
using TB = cutlass::float_e4m3_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 128;
|
||||
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM89_16x8x32_F32E5M2E4M3F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f32_MMA) {
|
||||
using TA = cutlass::float_e5m2_t;
|
||||
using TB = cutlass::float_e5m2_t;
|
||||
using TC = float;
|
||||
|
||||
constexpr uint32_t thread_block_size = 128;
|
||||
constexpr int MaxVecBits = 128;
|
||||
|
||||
auto shape_mnk = Shape<_64, _64, _64>{};
|
||||
auto tiled_mma =
|
||||
TiledMMA<
|
||||
MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>,
|
||||
Layout<Shape<_2, _2, _1>>
|
||||
>{};
|
||||
|
||||
test_cooperative_gemm_col_major_layout<thread_block_size, MaxVecBits, TA, TB, TC>(shape_mnk, tiled_mma);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user