v4.2.1 update. (#2666)

This commit is contained in:
Junkai-Wu
2025-09-24 01:25:43 +08:00
committed by GitHub
parent 2b8dff1f90
commit 7a6d4ee099
12 changed files with 163 additions and 65 deletions

View File

@ -69,6 +69,7 @@ template<cute::UMMA::Major SFAMajor,
int ScaleGranularityN,
int ScaleGranularityK,
bool Is2SM,
bool NoSmemEpilogue,
class LayoutA,
class LayoutB,
class LayoutCD,
@ -77,8 +78,10 @@ template<cute::UMMA::Major SFAMajor,
bool groupwise_test(
Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>, C<Is2SM>,
LayoutA, LayoutB, LayoutCD,
MmaTileShape, ClusterShape) {
MmaTileShape, ClusterShape,
C<NoSmemEpilogue>) {
using Epilogue1SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>;
using Epilogue2SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized2Sm>;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, SFAMajor, SFBMajor>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
@ -90,7 +93,7 @@ bool groupwise_test(
float, float,
cutlass::float_e4m3_t, LayoutCD, 16,
cutlass::float_e4m3_t, LayoutCD, 16,
conditional_t<Is2SM, cutlass::epilogue::TmaWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>
conditional_t<Is2SM, Epilogue2SM, Epilogue1SM>
>::CollectiveOp;
using CollectiveMainloop =
@ -259,11 +262,26 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_128,_128,_128>{},
Shape<_1,_1,_1>{});
Shape<_1,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
}
// Runs groupwise_test with SFA major = MN and SFB major = K, a 2x2x32 (M x N x K)
// scale granularity, a 128x128x128 MMA tile and a 1x1x1 cluster shape.
// The trailing true_type tag sets groupwise_test's NoSmemEpilogue parameter, while
// the false_type tag selects the non-2SM (Is2SM = false) configuration.
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128x128x128_1x1x1_2x2x32_scale_direct_store) {
  // Scale granularity tags, in the order groupwise_test expects them: M, N, K.
  auto const scale_granularity_m = Int<2>{};
  auto const scale_granularity_n = Int<2>{};
  auto const scale_granularity_k = Int<32>{};

  // Compile-time switches: Is2SM and NoSmemEpilogue.
  auto const is_2sm           = false_type{};
  auto const no_smem_epilogue = true_type{};

  // Layout tags for the A, B and C/D operands.
  auto const layout_a  = cutlass::layout::RowMajor{};
  auto const layout_b  = cutlass::layout::ColumnMajor{};
  auto const layout_cd = cutlass::layout::RowMajor{};

  // MMA tile shape and cluster shape.
  auto const mma_tile_shape = Shape<_128,_128,_128>{};
  auto const cluster_shape  = Shape<_1,_1,_1>{};

  // Keep the result in a local: the comma in the explicit template argument list
  // would otherwise be parsed as a macro-argument separator inside EXPECT_TRUE.
  bool const passed = groupwise_test<UMMA::Major::MN, UMMA::Major::K>(
      scale_granularity_m, scale_granularity_n, scale_granularity_k, is_2sm,
      layout_a, layout_b,
      layout_cd,
      mma_tile_shape,
      cluster_shape,
      no_smem_epilogue);
  EXPECT_TRUE(passed);
}
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x4x32_scale) {
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::MN>(
@ -271,7 +289,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_256,_128,_128>{},
Shape<_2,_1,_1>{});
Shape<_2,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
@ -284,7 +303,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_128,_128,_128>{},
Shape<_1,_1,_1>{});
Shape<_1,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
@ -297,7 +317,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_256,_128,_128>{},
Shape<_2,_1,_1>{});
Shape<_2,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);
@ -311,7 +332,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
cutlass::layout::RowMajor{},
Shape<_256,_128,_128>{},
Shape<_2,_1,_1>{});
Shape<_2,_1,_1>{},
false_type{});
EXPECT_TRUE(passed);