v4.2.1 update. (#2666)
This commit is contained in:
@ -69,6 +69,7 @@ template<cute::UMMA::Major SFAMajor,
|
||||
int ScaleGranularityN,
|
||||
int ScaleGranularityK,
|
||||
bool Is2SM,
|
||||
bool NoSmemEpilogue,
|
||||
class LayoutA,
|
||||
class LayoutB,
|
||||
class LayoutCD,
|
||||
@ -77,8 +78,10 @@ template<cute::UMMA::Major SFAMajor,
|
||||
bool groupwise_test(
|
||||
Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>, C<Is2SM>,
|
||||
LayoutA, LayoutB, LayoutCD,
|
||||
MmaTileShape, ClusterShape) {
|
||||
|
||||
MmaTileShape, ClusterShape,
|
||||
C<NoSmemEpilogue>) {
|
||||
using Epilogue1SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>;
|
||||
using Epilogue2SM = conditional_t<NoSmemEpilogue, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized2Sm>;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, SFAMajor, SFBMajor>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
@ -90,7 +93,7 @@ bool groupwise_test(
|
||||
float, float,
|
||||
cutlass::float_e4m3_t, LayoutCD, 16,
|
||||
cutlass::float_e4m3_t, LayoutCD, 16,
|
||||
conditional_t<Is2SM, cutlass::epilogue::TmaWarpSpecialized2Sm, cutlass::epilogue::TmaWarpSpecialized1Sm>
|
||||
conditional_t<Is2SM, Epilogue2SM, Epilogue1SM>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
@ -259,11 +262,26 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
|
||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||
cutlass::layout::RowMajor{},
|
||||
Shape<_128,_128,_128>{},
|
||||
Shape<_1,_1,_1>{});
|
||||
Shape<_1,_1,_1>{},
|
||||
false_type{});
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128x128x128_1x1x1_2x2x32_scale_direct_store) {
|
||||
|
||||
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::K>(
|
||||
Int<2>{}, Int<2>{}, Int<32>{}, false_type{},
|
||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||
cutlass::layout::RowMajor{},
|
||||
Shape<_128,_128,_128>{},
|
||||
Shape<_1,_1,_1>{},
|
||||
true_type{});
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
|
||||
TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x4x32_scale) {
|
||||
|
||||
bool passed = groupwise_test<UMMA::Major::MN, UMMA::Major::MN>(
|
||||
@ -271,7 +289,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
|
||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||
cutlass::layout::RowMajor{},
|
||||
Shape<_256,_128,_128>{},
|
||||
Shape<_2,_1,_1>{});
|
||||
Shape<_2,_1,_1>{},
|
||||
false_type{});
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
|
||||
@ -284,7 +303,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align16_blockwise, 128
|
||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||
cutlass::layout::RowMajor{},
|
||||
Shape<_128,_128,_128>{},
|
||||
Shape<_1,_1,_1>{});
|
||||
Shape<_1,_1,_1>{},
|
||||
false_type{});
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
|
||||
@ -297,7 +317,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
|
||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||
cutlass::layout::RowMajor{},
|
||||
Shape<_256,_128,_128>{},
|
||||
Shape<_2,_1,_1>{});
|
||||
Shape<_2,_1,_1>{},
|
||||
false_type{});
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
|
||||
@ -311,7 +332,8 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256
|
||||
cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{},
|
||||
cutlass::layout::RowMajor{},
|
||||
Shape<_256,_128,_128>{},
|
||||
Shape<_2,_1,_1>{});
|
||||
Shape<_2,_1,_1>{},
|
||||
false_type{});
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user