Fix type bug in conv2d/gemm with broadcast (#796)

add ElementVector

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Shuai Shao
2023-02-09 17:53:25 -08:00
committed by GitHub
parent 2e10404d26
commit ce8597dc14
7 changed files with 21 additions and 11 deletions

View File

@ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast {
ImplicitGemmBase::Epilogue::kPartitionsK,
ElementC,
typename EpilogueOutputOp::ElementT,
ElementC,
typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp,
ImplicitGemmBase::Epilogue::kElementsPerAccess
>::Epilogue;

View File

@ -61,7 +61,8 @@ template <
typename ElementT_,
int ElementsPerAccess,
typename ElementwiseOp_ = Identity<ElementCompute_>,
typename BinaryOp_ = plus<ElementCompute_>
typename BinaryOp_ = plus<ElementCompute_>,
typename ElementVector_ = ElementC_
>
class LinearCombinationBiasElementwise {
public:
@ -72,6 +73,7 @@ public:
using ElementCompute = ElementCompute_;
using ElementZ = ElementZ_;
using ElementT = ElementT_;
using ElementVector = ElementVector_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;

View File

@ -204,7 +204,8 @@ template <
typename ElementCompute_,
typename ElementZ_,
int ElementsPerAccess,
bool StoreT = true
bool StoreT = true,
typename ElementVector_ = ElementC_
>
class LinearCombinationBiasRelu {
public:
@ -214,6 +215,7 @@ public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementZ = ElementZ_;
using ElementVector = ElementVector_;
using ElementT = uint1b_t;

View File

@ -59,7 +59,8 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_ = detail::NoOp>
template <typename T> class BinaryOp2_ = detail::NoOp,
typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock {
public:
static bool const kIsSingleSource = false;
@ -68,6 +69,7 @@ public:
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementVector = ElementVector_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
@ -179,11 +181,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_>
template <typename T> class UnaryOp_,
typename ElementVector_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
ElementCompute_, ElementC_, ElementsPerAccess,
ActivationOp_, BinaryOp1_, UnaryOp_,
detail::NoOp> {
detail::NoOp, ElementVector_> {
public:
static bool const kIsSingleSource = true;
@ -191,6 +194,7 @@ public:
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementVector = ElementVector_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;

View File

@ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast {
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
ElementC_,
typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;
@ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast<
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
ElementC_,
typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;