CUTLASS 3.1 (#915)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2023-04-14 20:19:34 -07:00
committed by GitHub
parent 9b8166e3f0
commit d572cc1aab
482 changed files with 37184 additions and 16419 deletions

View File

@ -233,6 +233,17 @@ struct Options {
return false;
}
// Filter size passed through command line does not match filter size template parameter
if (filter_size.h() != FilterShape::kRow || filter_size.w() != FilterShape::kColumn) {
std::cerr << "Filter size passed in (" << filter_size.h() << "x" << filter_size.w() << ") "
<< "must match the FilterShape template parameter of the convolution "
<< "(" << FilterShape::kRow << "x" << FilterShape::kColumn << "). "
<< "To use the filter shape passed in, change the FilterShape template "
<< "parameter and recompile this example."
<< std::endl;
return false;
}
return true;
}
@ -319,9 +330,9 @@ struct Options {
"table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 "
<< "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=32 "
"--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n"
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 "
<< "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=1 "
"--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n";
return out;
@ -515,14 +526,13 @@ Result profile_convolution(Options const &options) {
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue> >(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha,
options.beta);
ElementAccumulator >(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha,
options.beta);
// Check if output from CUTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();