releaase 2.11 (#703)
This commit is contained in:
@ -67,6 +67,7 @@ Conv2dOperationProfiler::Conv2dOperationProfiler(Options const &options):
|
||||
{ArgumentTypeID::kInteger, {"s", "filter_s"}, "Filter S dimension of the Conv2d problem space"},
|
||||
{ArgumentTypeID::kInteger, {"p", "output_p"}, "Output P dimension of the Conv2d problem space"},
|
||||
{ArgumentTypeID::kInteger, {"q", "output_q"}, "Output Q dimension of the Conv2d problem space"},
|
||||
{ArgumentTypeID::kInteger, {"g", "groups"}, "Number of convolution groups"},
|
||||
{ArgumentTypeID::kInteger, {"pad_h"}, "Padding in H direction"},
|
||||
{ArgumentTypeID::kInteger, {"pad_w"}, "Padding in W direction"},
|
||||
{ArgumentTypeID::kInteger, {"stride_h"}, "Stride in H direction"},
|
||||
@ -233,6 +234,11 @@ Status Conv2dOperationProfiler::initialize_configuration(
|
||||
problem_.s = 3;
|
||||
}
|
||||
|
||||
if (!arg_as_int(problem_.groups, "g", problem_space, problem)) {
|
||||
// default value
|
||||
problem_.groups = 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(problem_.pad_h, "pad_h", problem_space, problem)) {
|
||||
// default value
|
||||
problem_.pad_h = 1;
|
||||
@ -382,7 +388,7 @@ Status Conv2dOperationProfiler::initialize_configuration(
|
||||
int(problem_.dilation_w),
|
||||
static_cast<conv::Mode>(static_cast<int>(problem_.conv_mode)),
|
||||
int(problem_.split_k_slices),
|
||||
1 // groups
|
||||
int(problem_.groups)
|
||||
);
|
||||
|
||||
conv_workspace_.configuration.split_k_mode = static_cast<conv::SplitKMode>(static_cast<int>(problem_.split_k_mode));
|
||||
@ -454,6 +460,8 @@ void Conv2dOperationProfiler::initialize_result_(
|
||||
set_argument(result, "p", problem_space, problem_.p);
|
||||
set_argument(result, "q", problem_space, problem_.q);
|
||||
|
||||
set_argument(result, "g", problem_space, problem_.groups);
|
||||
|
||||
set_argument(result, "pad_h", problem_space, problem_.pad_h);
|
||||
set_argument(result, "pad_w", problem_space, problem_.pad_w);
|
||||
|
||||
@ -624,6 +632,19 @@ Status Conv2dOperationProfiler::initialize_workspace(
|
||||
conv_workspace_.problem_count
|
||||
);
|
||||
|
||||
if(problem_.groups == problem_.c && problem_.groups == problem_.k){
|
||||
// Depthwise direct conv kernel needs reorder the filter.
|
||||
conv_workspace_.reordered_B = device_context.allocate_tensor(
|
||||
options,
|
||||
"B",
|
||||
operation_desc.B.element,
|
||||
operation_desc.B.layout,
|
||||
problem_.extent_b(operation_desc.conv_kind),
|
||||
conv_workspace_.configuration.stride_b,
|
||||
conv_workspace_.problem_count
|
||||
);
|
||||
}
|
||||
|
||||
conv_workspace_.C = device_context.allocate_tensor(
|
||||
options,
|
||||
"C",
|
||||
@ -738,6 +759,12 @@ bool Conv2dOperationProfiler::verify_cutlass(
|
||||
conv_workspace_.arguments.beta = problem_.beta.data();
|
||||
conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
|
||||
if (conv_workspace_.reordered_B != nullptr){
|
||||
conv_workspace_.arguments.reordered_B = conv_workspace_.reordered_B->data();
|
||||
}else{
|
||||
conv_workspace_.arguments.reordered_B = nullptr;
|
||||
}
|
||||
|
||||
conv_workspace_.Computed->copy_from_device(conv_workspace_.C->data());
|
||||
|
||||
if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) {
|
||||
|
||||
@ -75,6 +75,7 @@ public:
|
||||
struct Conv2dProblem {
|
||||
|
||||
int64_t n, h, w, c, p, q, k, r, s;
|
||||
int64_t groups;
|
||||
int64_t pad_h, pad_w;
|
||||
int64_t stride_h, stride_w;
|
||||
int64_t dilation_h, dilation_w;
|
||||
@ -114,7 +115,7 @@ public:
|
||||
cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const {
|
||||
|
||||
switch (conv_kind) {
|
||||
case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c));
|
||||
case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups));
|
||||
case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s));
|
||||
case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q));
|
||||
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
|
||||
@ -136,7 +137,7 @@ public:
|
||||
std::vector<int> extent_b(library::ConvKind const &conv_kind) const {
|
||||
|
||||
switch (conv_kind) {
|
||||
case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c)};
|
||||
case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)};
|
||||
case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)};
|
||||
case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)};
|
||||
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
|
||||
@ -228,6 +229,7 @@ public:
|
||||
/// Conv device allocations
|
||||
DeviceAllocation *A;
|
||||
DeviceAllocation *B;
|
||||
DeviceAllocation *reordered_B;
|
||||
DeviceAllocation *C;
|
||||
DeviceAllocation *Computed;
|
||||
DeviceAllocation *Reference;
|
||||
@ -270,6 +272,7 @@ public:
|
||||
Conv2dWorkspace()
|
||||
: A(nullptr),
|
||||
B(nullptr),
|
||||
reordered_B(nullptr),
|
||||
C(nullptr),
|
||||
Computed(nullptr),
|
||||
Reference(nullptr) {}
|
||||
@ -317,10 +320,10 @@ public:
|
||||
stride_activations.push_back(int(problem.h) * int(problem.w) *
|
||||
int(problem.c));
|
||||
|
||||
stride_filters.push_back(int(problem.c));
|
||||
stride_filters.push_back(int(problem.s) * int(problem.c));
|
||||
stride_filters.push_back(int(problem.c / problem.groups));
|
||||
stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups));
|
||||
stride_filters.push_back(int(problem.r) * int(problem.s) *
|
||||
int(problem.c));
|
||||
int(problem.c / problem.groups));
|
||||
|
||||
stride_output.push_back(int(problem.k));
|
||||
stride_output.push_back(int(problem.q) * int(problem.k));
|
||||
|
||||
@ -195,7 +195,12 @@ bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescripti
|
||||
return true;
|
||||
}
|
||||
case library::OpcodeClassID::kSimt:
|
||||
return false;
|
||||
#if (defined(CUDNN_VERSION) && CUDNN_VERSION <= 8000)
|
||||
cudnn_math_type = CUDNN_DEFAULT_MATH;
|
||||
#else
|
||||
cudnn_math_type = CUDNN_FMA_MATH;
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@ -245,7 +245,7 @@ struct cudnnConvDispatcher {
|
||||
data_type_filter,
|
||||
layout_filter,
|
||||
configuration.problem_size.K,
|
||||
configuration.problem_size.C,
|
||||
configuration.problem_size.C / configuration.problem_size.groups,
|
||||
configuration.problem_size.R,
|
||||
configuration.problem_size.S
|
||||
));
|
||||
|
||||
Reference in New Issue
Block a user