Fix type bug in conv2d/gemm with broadcast (#796)

Add ElementVector and use it as the element type of the broadcast tensor.

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Shuai Shao
2023-02-09 17:53:25 -08:00
committed by GitHub
parent 2e10404d26
commit ce8597dc14
7 changed files with 21 additions and 11 deletions

View File

@@ -120,6 +120,7 @@ public:
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
using ElementZ = typename EpilogueOutputOp::ElementZ;
using ElementT = typename EpilogueOutputOp::ElementT;
+using ElementVector = typename EpilogueOutputOp::ElementVector;
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
static const bool kAddBroadcastFirst = AddBroadcastFirst;
@@ -142,7 +143,7 @@ public:
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
-cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast; // Input Broadcast
+cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
public: