Fix type bug in conv2d/gemm with broadcast (#796)
add ElementVector --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -107,7 +107,7 @@ struct DefaultConv2dFpropWithBroadcast {
|
||||
ImplicitGemmBase::Epilogue::kPartitionsK,
|
||||
ElementC,
|
||||
typename EpilogueOutputOp::ElementT,
|
||||
ElementC,
|
||||
typename EpilogueOutputOp::ElementVector,
|
||||
EpilogueOutputOp,
|
||||
ImplicitGemmBase::Epilogue::kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
@ -61,7 +61,8 @@ template <
|
||||
typename ElementT_,
|
||||
int ElementsPerAccess,
|
||||
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
||||
typename BinaryOp_ = plus<ElementCompute_>
|
||||
typename BinaryOp_ = plus<ElementCompute_>,
|
||||
typename ElementVector_ = ElementC_
|
||||
>
|
||||
class LinearCombinationBiasElementwise {
|
||||
public:
|
||||
@ -72,6 +73,7 @@ public:
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementZ = ElementZ_;
|
||||
using ElementT = ElementT_;
|
||||
using ElementVector = ElementVector_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
|
||||
@ -204,7 +204,8 @@ template <
|
||||
typename ElementCompute_,
|
||||
typename ElementZ_,
|
||||
int ElementsPerAccess,
|
||||
bool StoreT = true
|
||||
bool StoreT = true,
|
||||
typename ElementVector_ = ElementC_
|
||||
>
|
||||
class LinearCombinationBiasRelu {
|
||||
public:
|
||||
@ -214,6 +215,7 @@ public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementZ = ElementZ_;
|
||||
using ElementVector = ElementVector_;
|
||||
|
||||
using ElementT = uint1b_t;
|
||||
|
||||
|
||||
@ -59,7 +59,8 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_,
|
||||
template <typename T> class BinaryOp2_ = detail::NoOp>
|
||||
template <typename T> class BinaryOp2_ = detail::NoOp,
|
||||
typename ElementVector_ = ElementC_>
|
||||
class LinearCombinationResidualBlock {
|
||||
public:
|
||||
static bool const kIsSingleSource = false;
|
||||
@ -68,6 +69,7 @@ public:
|
||||
using ElementC = ElementC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementVector = ElementVector_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
@ -179,11 +181,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
|
||||
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
|
||||
template <typename T> class ActivationOp_,
|
||||
template <typename T> class BinaryOp1_,
|
||||
template <typename T> class UnaryOp_>
|
||||
template <typename T> class UnaryOp_,
|
||||
typename ElementVector_>
|
||||
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
|
||||
ElementCompute_, ElementC_, ElementsPerAccess,
|
||||
ActivationOp_, BinaryOp1_, UnaryOp_,
|
||||
detail::NoOp> {
|
||||
detail::NoOp, ElementVector_> {
|
||||
public:
|
||||
static bool const kIsSingleSource = true;
|
||||
|
||||
@ -191,6 +194,7 @@ public:
|
||||
using ElementC = ElementC_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementVector = ElementVector_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kCount = kElementsPerAccess;
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ struct DefaultGemmWithBroadcast {
|
||||
GemmBase::Epilogue::kPartitionsK,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementT,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementVector,
|
||||
EpilogueOutputOp,
|
||||
GemmBase::Epilogue::kElementsPerAccess
|
||||
>::Epilogue;
|
||||
@ -221,7 +221,7 @@ struct DefaultGemmWithBroadcast<
|
||||
GemmBase::Epilogue::kPartitionsK,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementT,
|
||||
ElementC_,
|
||||
typename EpilogueOutputOp::ElementVector,
|
||||
EpilogueOutputOp,
|
||||
GemmBase::Epilogue::kElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
Reference in New Issue
Block a user