refine the implementation
@@ -39,6 +39,7 @@
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
  128x128_64x3_64x64x64) {

@@ -80,6 +81,7 @@ TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
  128x128_64x3_64x64x64) {

@@ -121,4 +123,115 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
  128x128_64x3_64x64x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  /// Device-level Conv2d instance
  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kUnity,
    4,
    4
  >::Kernel;

  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
}
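
// Editorial note (not part of the commit): the *_align4 variants exist because
// the hand-picked problem has C = 12 input channels and K = 8 output channels,
// so the default 8-element (128-bit) half_t vector access does not divide
// C = 12, while 4-element (64-bit) accesses divide both C = 12 and K = 8.
// The problem is also tiny by design: with H = W = 4, R = S = 3, zero padding
// and stride 3, the output is P = Q = (4 + 0 - 3) / 3 + 1 = 1, so the specific
// case runs quickly before the full TestAllConv2d sweep.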

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
  128x128_64x3_64x64x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  /// Device-level Conv2d instance
  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kUnity,
    4,
    4
  >::Kernel;

  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

@@ -38,46 +38,7 @@
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kUnity
  >::Kernel;

  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;

  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>());
}

////////////////////////////////////////////////////////////////////////////////
TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride,
  128x128_32x2_64x64x32) {

@@ -118,6 +79,7 @@ TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride,
  128x128_32x2_64x64x32) {

@@ -158,4 +120,113 @@ TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kUnity,
    2,
    2
  >::Kernel;

  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kUnity,
    2,
    2
  >::Kernel;

  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED

@@ -38,6 +38,7 @@
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
  128x128_64x3_64x64x64) {

@@ -78,6 +79,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16,
  128x128_64x3_64x64x64) {

@@ -118,9 +120,10 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
}

////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,

TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
  128x128_64x3_64x64x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
@@ -141,28 +144,41 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      8,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kStrided,
    2,
    2
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,

TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
  128x128_64x3_64x64x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
@@ -183,7 +199,7 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      8,
      ElementAccumulator,
      ElementCompute
    >,
@@ -191,14 +207,81 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    2,
    2
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4,
  128x128_64x3_64x64x64) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  /// Device-level Conv2d instance
  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      8,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

@@ -119,49 +119,6 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align1,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    1,
    1
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align2,
  128x128_32x2_64x64x32) {

@@ -185,7 +142,7 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      8,
      ElementAccumulator,
      ElementCompute
    >,
@@ -193,14 +150,26 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    2,
    2
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////
@@ -228,7 +197,7 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      8,
      ElementAccumulator,
      ElementCompute
    >,
@@ -236,15 +205,83 @@ TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      8,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED

@@ -60,7 +60,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      8,
      ElementAccumulator,
      ElementCompute
    >,
@@ -79,53 +79,9 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_te

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1,
  128x128_32x3_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::tfloat32_t;
  using ElementB = cutlass::tfloat32_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 16>,
    cutlass::gemm::GemmShape<64, 64, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    1,
    1
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2,
  128x128_32x3_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::tfloat32_t;
  using ElementB = cutlass::tfloat32_t;
@@ -146,7 +102,7 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_t
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      8,
      ElementAccumulator,
      ElementCompute
    >,
@@ -154,15 +110,26 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_t
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    2,
    2
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

@@ -174,6 +174,61 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n
  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>());
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
  128x128_32x3_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  /// Device-level Conv2d instance
  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
}
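
// Editorial note (not part of the commit): unlike the unity-stride Dgrad
// instances earlier, this variant declares StrideSupport::kStrided and swaps
// GemmIdentityThreadblockSwizzle for StridedDgradIdentityThreadblockSwizzle,
// the swizzle the strided-dgrad kernels use to remap threadblocks onto the
// input-gradient tiles produced by a non-unit convolution stride.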

////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

@@ -116,6 +116,7 @@ public:
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    //cutlass::reference::host::TensorFill(view, Element(1.0f));
    if (dist_kind == cutlass::Distribution::Uniform) {

      int scope;

@@ -36,6 +36,8 @@

#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32,
  128x128_32x2_64x64x32) {

@@ -74,5 +76,114 @@ TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM75_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
  128x128_32x2_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED

@@ -146,8 +146,7 @@ TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
@@ -157,5 +156,113 @@ TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten
}

////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED

TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
  128x128_32x3_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 16>,
    cutlass::gemm::GemmShape<64, 64, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4,
  128x128_32x3_64x64x32) {

  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 16>,
    cutlass::gemm::GemmShape<64, 64, 16>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
      ElementC,
      4,
      ElementAccumulator,
      ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    4,
    4
  >::Kernel;

  using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;

  test::conv::device::Conv2dProblemVector problem_size_list;

  // run specific problem size in the unit test first
  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 4, 4, 12},    // input size (NHWC)
    {8, 3, 3, 12},    // filter size (KRSC)
    {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
    {3, 3},           // stride (stride_h, stride_w)
    {1, 1}            // dilation (dilation_h, dilation_w)
  ));

  /// Run all unit test sizes with device-level Conv2d instance
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>(problem_size_list));
}

////////////////////////////////////////////////////////////////////////////////

#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
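
////////////////////////////////////////////////////////////////////////////////

For context, the call pattern that every EXPECT_TRUE(test::conv::device::TestAllConv2d<...>(...)) above ultimately exercises can be sketched as below. This is an editorial illustration, not part of the commit: the CUTLASS 2.x device-level entry points (can_implement, initialize, operator()) are the library's public API, while the helper name run_one_fprop, the constant fill values, and alpha = 1 / beta = 0 are illustrative assumptions.

#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

/// Hypothetical helper: runs a single problem size through a device-level
/// Fprop instance such as the Conv2dFprop types defined in the tests above.
template <typename Conv2dFprop>
cutlass::Status run_one_fprop(cutlass::conv::Conv2dProblemSize const &problem_size) {

  using ElementA = typename Conv2dFprop::ElementA;
  using ElementB = typename Conv2dFprop::ElementB;
  using ElementC = typename Conv2dFprop::ElementC;
  using ElementCompute = typename Conv2dFprop::ElementCompute;

  // Allocate NHWC tensors sized from the problem: activations (A),
  // filters (B), and source/destination outputs (C/D).
  cutlass::HostTensor<ElementA, cutlass::layout::TensorNHWC> tensor_a(problem_size.activation_extent());
  cutlass::HostTensor<ElementB, cutlass::layout::TensorNHWC> tensor_b(problem_size.filter_extent());
  cutlass::HostTensor<ElementC, cutlass::layout::TensorNHWC> tensor_c(problem_size.output_extent());
  cutlass::HostTensor<ElementC, cutlass::layout::TensorNHWC> tensor_d(problem_size.output_extent());

  // Constant fills keep the sketch short; the real testbed uses random data.
  cutlass::reference::host::TensorFill(tensor_a.host_view(), ElementA(1));
  cutlass::reference::host::TensorFill(tensor_b.host_view(), ElementB(1));
  cutlass::reference::host::TensorFill(tensor_c.host_view(), ElementC(0));

  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();

  // Arguments mirror the implicit GEMM: D = alpha * conv(A, B) + beta * C.
  typename Conv2dFprop::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),
    tensor_b.device_ref(),
    tensor_c.device_ref(),
    tensor_d.device_ref(),
    {ElementCompute(1), ElementCompute(0)}
  };

  Conv2dFprop conv_op;

  // The align2/align4 instances reject problem sizes whose channel counts do
  // not satisfy their vector-access alignment; can_implement() is where that
  // check surfaces.
  cutlass::Status status = conv_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = conv_op.initialize(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Launch the implicit GEMM convolution kernel.
  return conv_op();
}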