refine the implementation

This commit is contained in:
Haicheng Wu
2021-09-08 13:14:08 +00:00
parent 4e8af93da1
commit 59e2aa505a
45 changed files with 1593 additions and 706 deletions

View File

@ -105,6 +105,18 @@ public:
return status;
}
static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
if (kConvolutionalOperator == conv::Operator::kFprop) {
if (args.problem_size.K % kAlignmentC)
return Status::kErrorMisalignedOperand;
} else if (kConvolutionalOperator == conv::Operator::kDgrad) {
if (args.problem_size.C % kAlignmentC)
return Status::kErrorMisalignedOperand;
} else if (kConvolutionalOperator == conv::Operator::kWgrad) {
if (args.problem_size.C % kAlignmentC)
return Status::kErrorMisalignedOperand;
}
// check for unsupported problem sizes for strided dgrad implementation
if (kConvolutionalOperator == conv::Operator::kDgrad &&
kStrideSupport == conv::StrideSupport::kStrided) {

View File

@ -66,7 +66,11 @@ template <
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
> struct DefaultConv2dDgrad;
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -90,7 +94,9 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -110,7 +116,9 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport::kStrided
StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -121,24 +129,28 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kStrided
StrideSupport::kStrided,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
StrideSupport::kStrided
StrideSupport::kStrided,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -147,6 +159,11 @@ struct DefaultConv2dDgrad <
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
@ -155,7 +172,7 @@ struct DefaultConv2dDgrad <
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
CacheOpB,
MmaPolicy,
Stages
>;
@ -196,7 +213,9 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -216,7 +235,9 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport::kStrided
StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -227,13 +248,15 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kStrided
StrideSupport::kStrided,
AccessTypeA
>
>;
@ -241,13 +264,15 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
StrideSupport::kStrided
StrideSupport::kStrided,
AccessTypeB
>
>;
@ -308,7 +333,9 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -328,7 +355,9 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport::kUnity
StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -339,24 +368,28 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kUnity
StrideSupport::kUnity,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
StrideSupport::kUnity
StrideSupport::kUnity,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -365,6 +398,11 @@ struct DefaultConv2dDgrad <
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
@ -373,7 +411,7 @@ struct DefaultConv2dDgrad <
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
CacheOpB,
MmaPolicy,
Stages
>;
@ -414,7 +452,9 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -434,7 +474,9 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport::kUnity
StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -445,13 +487,15 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kUnity
StrideSupport::kUnity,
AccessTypeA
>
>;
@ -459,13 +503,15 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
StrideSupport::kUnity
StrideSupport::kUnity,
AccessTypeB
>
>;
@ -526,7 +572,9 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -546,7 +594,9 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport::kUnity
StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -557,23 +607,28 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kUnity
StrideSupport::kUnity,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kUnity,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -582,6 +637,11 @@ struct DefaultConv2dDgrad <
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
@ -590,7 +650,7 @@ struct DefaultConv2dDgrad <
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
CacheOpB,
MmaPolicy,
Stages
>;
@ -631,7 +691,9 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -651,7 +713,9 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport::kUnity
StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -662,13 +726,15 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kUnity
StrideSupport::kUnity,
AccessTypeA
>
>;
@ -676,12 +742,15 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kUnity,
AccessTypeB
>
>;
@ -744,7 +813,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag>
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
@ -763,7 +835,9 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
conv::StrideSupport::kUnity
conv::StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -848,7 +922,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag>
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
@ -867,7 +944,9 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
conv::StrideSupport::kStrided
conv::StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -955,7 +1034,9 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -975,7 +1056,9 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport::kUnity
StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -1040,10 +1123,8 @@ struct DefaultConv2dDgrad <
ThreadblockSwizzle,
conv::Operator::kDgrad
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1063,7 +1144,9 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -1083,7 +1166,9 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
conv::StrideSupport::kUnity
conv::StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -1169,7 +1254,9 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -1189,7 +1276,9 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
conv::StrideSupport::kStrided
conv::StrideSupport::kStrided,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -1257,7 +1346,6 @@ struct DefaultConv2dDgrad <
ThreadblockSwizzle,
conv::Operator::kDgrad
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1278,7 +1366,9 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dDgrad <
ElementA,
@ -1298,7 +1388,9 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport::kUnity
StrideSupport::kUnity,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -1368,10 +1460,10 @@ struct DefaultConv2dDgrad <
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -66,11 +66,11 @@ template <
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
> struct DefaultConv2dFprop;
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -95,6 +95,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -116,6 +117,7 @@ struct DefaultConv2dFprop <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -128,22 +130,26 @@ struct DefaultConv2dFprop <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA, LayoutA,
ThreadMapA
ThreadMapA,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB, LayoutB,
ThreadMapB
ThreadMapB,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -152,6 +158,11 @@ struct DefaultConv2dFprop <
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
@ -160,7 +171,7 @@ struct DefaultConv2dFprop <
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
CacheOpB,
MmaPolicy,
Stages
>;
@ -203,9 +214,10 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
int InterleavedK,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
int AlignmentB,
int InterleavedK
>
struct DefaultConv2dFprop <
ElementA,
@ -225,6 +237,7 @@ struct DefaultConv2dFprop <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -325,6 +338,7 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -346,6 +360,7 @@ struct DefaultConv2dFprop <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -358,12 +373,14 @@ struct DefaultConv2dFprop <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA, LayoutA,
ThreadMapA
ThreadMapA,
AccessTypeA
>
>;
@ -371,12 +388,14 @@ struct DefaultConv2dFprop <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB, LayoutB,
ThreadMapB
ThreadMapB,
AccessTypeB
>
>;
@ -435,9 +454,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
int InterleavedK,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
int AlignmentB,
int InterleavedK
>
struct DefaultConv2dFprop <
ElementA,
@ -457,6 +477,7 @@ struct DefaultConv2dFprop <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -561,6 +582,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -582,6 +604,7 @@ struct DefaultConv2dFprop <
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -595,26 +618,28 @@ struct DefaultConv2dFprop <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA,
AlignmentA
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB,
AlignmentB
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -624,7 +649,7 @@ struct DefaultConv2dFprop <
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementA>::value * AlignmentB) == 128)
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
@ -679,9 +704,10 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
int InterleavedK,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
int AlignmentB,
int InterleavedK
>
struct DefaultConv2dFprop <
ElementA,
@ -701,6 +727,7 @@ struct DefaultConv2dFprop <
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -774,6 +801,8 @@ struct DefaultConv2dFprop <
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
/// and 2 stage pipeline.
template <
@ -791,6 +820,7 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -812,6 +842,7 @@ struct DefaultConv2dFprop <
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -824,6 +855,7 @@ struct DefaultConv2dFprop <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
@ -831,7 +863,7 @@ struct DefaultConv2dFprop <
ElementA,
LayoutA,
ThreadMapA,
AlignmentA
AccessTypeA
>
>;
@ -839,6 +871,7 @@ struct DefaultConv2dFprop <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
@ -846,7 +879,7 @@ struct DefaultConv2dFprop <
ElementB,
LayoutB,
ThreadMapB,
AlignmentB
AccessTypeB
>
>;
@ -905,9 +938,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
int InterleavedK,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
int AlignmentB,
int InterleavedK
>
struct DefaultConv2dFprop <
ElementA,
@ -927,6 +961,7 @@ struct DefaultConv2dFprop <
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -1023,6 +1058,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -1044,6 +1080,7 @@ struct DefaultConv2dFprop <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -1132,6 +1169,7 @@ template <
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -1153,6 +1191,7 @@ struct DefaultConv2dFprop <
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -1241,6 +1280,7 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -1262,6 +1302,7 @@ struct DefaultConv2dFprop <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
@ -1351,6 +1392,7 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
@ -1372,6 +1414,7 @@ struct DefaultConv2dFprop <
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {

View File

@ -65,7 +65,11 @@ template <
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
>
struct DefaultConv2dFpropWithBroadcast {
@ -84,7 +88,9 @@ struct DefaultConv2dFpropWithBroadcast {
Stages,
MathOperatorTag,
IteratorAlgorithm,
StrideSupport
StrideSupport,
AlignmentA,
AlignmentB
>::Kernel;
// Replace epilogue

View File

@ -66,7 +66,11 @@ template <
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
>
struct DefaultConv2dFpropWithReduction {
@ -85,7 +89,9 @@ struct DefaultConv2dFpropWithReduction {
Stages,
MathOperatorTag,
IteratorAlgorithm,
StrideSupport
StrideSupport,
AlignmentA,
AlignmentB
>::Kernel;
// Replace epilogue

View File

@ -67,8 +67,13 @@ template <
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
> struct DefaultConv2dWgrad;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -93,7 +98,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dWgrad <
ElementA,
@ -112,7 +120,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -123,22 +134,26 @@ struct DefaultConv2dWgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA
ThreadMapA,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -179,6 +194,7 @@ struct DefaultConv2dWgrad <
conv::Operator::kWgrad
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two
@ -198,7 +214,10 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dWgrad <
ElementA,
@ -217,7 +236,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -228,12 +250,14 @@ struct DefaultConv2dWgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA
ThreadMapA,
AccessTypeA
>
>;
@ -241,12 +265,14 @@ struct DefaultConv2dWgrad <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
AccessTypeB
>
>;
@ -308,7 +334,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dWgrad <
ElementA,
@ -327,7 +356,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -338,22 +370,26 @@ struct DefaultConv2dWgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA
ThreadMapA,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
@ -394,6 +430,7 @@ struct DefaultConv2dWgrad <
conv::Operator::kWgrad
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dWgrad specialzation for Optimized IteratorAlgorithm and two
@ -413,7 +450,10 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dWgrad <
ElementA,
@ -432,7 +472,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
@ -443,12 +486,14 @@ struct DefaultConv2dWgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA
ThreadMapA,
AccessTypeA
>
>;
@ -456,12 +501,14 @@ struct DefaultConv2dWgrad <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
AccessTypeB
>
>;
@ -524,7 +571,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AccessTypeA,
int AccessTypeB
>
struct DefaultConv2dWgrad <
ElementA,
@ -543,7 +593,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic
IteratorAlgorithm::kAnalytic,
StrideSupport,
AccessTypeA,
AccessTypeB
> {
// Define the core components from GEMM
@ -629,7 +682,10 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AccessTypeA,
int AccessTypeB
>
struct DefaultConv2dWgrad <
ElementA,
@ -648,7 +704,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized
IteratorAlgorithm::kOptimized,
StrideSupport,
AccessTypeA,
AccessTypeB
> {
// Define the core components from GEMM
@ -732,7 +791,10 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AccessTypeA,
int AccessTypeB
>
struct DefaultConv2dWgrad <
ElementA,
@ -751,7 +813,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic
IteratorAlgorithm::kAnalytic,
StrideSupport,
AccessTypeA,
AccessTypeB
> {
// Define the core components from GEMM
@ -817,7 +882,6 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
conv::Operator::kWgrad
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -838,7 +902,10 @@ template <
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AccessTypeA,
int AccessTypeB
>
struct DefaultConv2dWgrad <
ElementA,
@ -857,7 +924,10 @@ struct DefaultConv2dWgrad <
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kOptimized
IteratorAlgorithm::kOptimized,
StrideSupport,
AccessTypeA,
AccessTypeB
> {
// Define the core components from GEMM
@ -925,12 +995,11 @@ struct DefaultConv2dWgrad <
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -228,7 +228,6 @@ struct DefaultConv3dDgrad <
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,

View File

@ -501,4 +501,3 @@ struct DefaultConv3dWgrad <
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -59,7 +59,8 @@ template <
typename Shape_,
typename Element_,
typename ThreadMap_,
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dDgradFilterTileAccessIteratorAnalytic;
@ -70,13 +71,15 @@ class Conv2dDgradFilterTileAccessIteratorAnalytic;
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradFilterTileAccessIteratorAnalytic <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kStrided
conv::StrideSupport::kStrided,
AccessType_
> {
public:
@ -88,7 +91,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -97,7 +100,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or larger.");
@ -107,14 +115,13 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
// For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
@ -162,8 +169,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -213,7 +222,7 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -223,13 +232,19 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -249,7 +264,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
@ -263,13 +278,15 @@ public:
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradFilterTileAccessIteratorAnalytic <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kUnity
conv::StrideSupport::kUnity,
AccessType_
>{
public:
@ -281,7 +298,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -290,7 +307,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or larger.");
@ -306,6 +328,7 @@ private:
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
// For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
@ -348,8 +371,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -395,7 +420,7 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
return coord.n() < problem_size_.K && (coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -405,13 +430,18 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -431,7 +461,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
@ -446,5 +476,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -60,7 +60,8 @@ template <
typename Shape_,
typename Element_,
typename ThreadMap_,
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dDgradFilterTileAccessIteratorOptimized;
@ -71,13 +72,15 @@ class Conv2dDgradFilterTileAccessIteratorOptimized;
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradFilterTileAccessIteratorOptimized <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kUnity
conv::StrideSupport::kUnity,
AccessType_
> {
public:
@ -89,7 +92,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -98,9 +101,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Parameters structure
//
@ -141,9 +147,10 @@ private:
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
uint32_t predicates_;
uint32_t predicates_[kAccessesPerVector];
int filter_rs_;
int filter_k_;
@ -169,7 +176,7 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_(0),
predicates_{0},
filter_rs_(0),
filter_k_(0) {
@ -186,11 +193,15 @@ public:
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
int filter_c = column + c * ThreadMap::Delta::kContiguous;
uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
predicates_ |= (pred << pred_idx);
uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0);
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
predicates_[v] |= (pred << pred_idx);
}
}
}
@ -204,8 +215,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -234,7 +247,11 @@ public:
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
predicates_ = (predicates_ & (~kClearMask));
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
predicates_[v] = (predicates_[v] & (~kClearMask));
}
}
}
@ -245,19 +262,25 @@ public:
CUTLASS_HOST_DEVICE
bool valid() {
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
return (predicates_ & (1u << pred_idx));
return (predicates_[iteration_vector_] & (1u << pred_idx));
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
return reinterpret_cast<AccessType const *>(pointer_ +
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8);
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -282,7 +305,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
@ -297,5 +320,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -59,7 +59,8 @@ template <
typename Shape_,
typename Element_,
typename ThreadMap_,
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -69,13 +70,15 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kStrided
conv::StrideSupport::kStrided,
AccessType_
> {
public:
@ -86,7 +89,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -95,7 +98,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or greater.");
@ -112,14 +120,13 @@ public:
using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
int filter_k_;
@ -211,8 +218,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -277,7 +286,7 @@ public:
coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.P &&
coord.w() >= 0 && coord.w() < problem_size_.Q &&
coord.c() < problem_size_.K;
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -287,12 +296,18 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -312,14 +327,14 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
if (problem_size.K % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by
@ -327,13 +342,15 @@ public:
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kUnity
conv::StrideSupport::kUnity,
AccessType_
> {
public:
@ -344,7 +361,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -353,7 +370,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or greater.");
@ -368,8 +390,6 @@ public:
// Parameters structure
//
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
struct Params {
Layout layout;
@ -395,6 +415,7 @@ private:
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
int filter_k_;
@ -446,8 +467,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -497,7 +520,6 @@ public:
}
/// Returns true if the current coordinate is within the output tensor Dy
CUTLASS_HOST_DEVICE
bool valid() const {
@ -507,7 +529,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.P &&
coord.w() >= 0 && coord.w() < problem_size_.Q &&
coord.c() < problem_size_.K;
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -517,12 +539,18 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -548,7 +576,7 @@ public:
}
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
if (problem_size.K % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
@ -556,7 +584,9 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass

View File

@ -61,7 +61,8 @@ template <
typename Shape_,
typename Element_,
typename ThreadMap_,
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -74,14 +75,16 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
Shape_,
Element_,
ThreadMap_,
conv::StrideSupport::kUnity
> {
conv::StrideSupport::kUnity,
AccessType_
> {
public:
//
@ -93,7 +96,7 @@ public:
using Layout = layout::TensorNHWC;
using TensorCoord = typename Layout::TensorCoord;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
@ -101,7 +104,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
using Mask = uint64_t;
//
@ -116,14 +124,13 @@ public:
using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
// One pointer per access
char const *pointer_[ThreadMap::Iterations::kStrided];
@ -133,7 +140,7 @@ private:
int filter_s_;
int filter_k_;
Index masks_[ThreadMap::Iterations::kStrided][2];
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
public:
@ -201,7 +208,11 @@ public:
int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h;
bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);
masks_[s_idx][0] |= (pred << r);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
masks_[s_idx][v_idx][0] |= (pred << r);
}
}
}
@ -218,12 +229,17 @@ public:
int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w;
bool pred = (q >= 0 && q < problem_size_.Q);
masks_[s_idx][1] |= (pred << s);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
masks_[s_idx][v_idx][1] |= (pred << s);
}
}
}
if (filter_k_ >= problem_size.K) {
clear_mask();
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, filter_k_ >= problem_size.K);
}
set_iteration_index(0);
@ -269,62 +285,15 @@ private:
}
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask_(bool clear) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
// artifact in which control flow instructions are generated. Instead, our
// intent is to predicate the mov instructions.
#if defined(__CUDA_ARCH__)
asm volatile(
"{\n"
" .reg .pred p;\n"
" .reg .u32 m;"
" mov.u32 m, %2;"
" setp.ne.b32 p, %1, 0;\n"
" @p mov.u32 m, 0;\n"
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(masks_[s][0])
:
"r"((int)clear),
"r"(masks_[s][0])
);
asm volatile(
"{\n"
" .reg .pred p;\n"
" .reg .u32 m;"
" mov.u32 m, %2;"
" setp.ne.b32 p, %1, 0;\n"
" @p mov.u32 m, 0;\n"
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(masks_[s][1])
:
"r"((int)clear),
"r"(masks_[s][1])
);
#else
if (clear) {
masks_[s][0] = 0;
masks_[s][1] = 0;
}
#endif
}
}
public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of element
@ -359,16 +328,32 @@ public:
filter_k_ += params_.filter_k_delta;
}
clear_mask_(filter_k_ >= problem_size_.K);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
}
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask() {
void clear_mask(bool clear = true) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
masks_[s][0] = Mask(0);
masks_[s][1] = Mask(0);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
}
}
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask(int v, bool clear = true) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
}
}
@ -376,20 +361,25 @@ public:
bool valid() {
return
(masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
(masks_[iteration_strided_][1] & (Index(1) << filter_s_));
(masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
(masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
@ -416,7 +406,7 @@ public:
}
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
if (problem_size.K % AccessType::kElements) {
return Status::kErrorNotSupported;
}

View File

@ -60,7 +60,8 @@ template <
typename Shape_,
typename Element_,
typename Layout_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dFpropActivationTileAccessIteratorAnalytic {
public:
@ -74,7 +75,7 @@ public:
using Layout = Layout_;
using TensorCoord = typename Layout::TensorCoord;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
@ -82,7 +83,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
@ -95,14 +101,13 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
int filter_c_;
@ -156,8 +161,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -214,7 +221,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.H &&
coord.w() >= 0 && coord.w() < problem_size_.W &&
coord.c() < problem_size_.C;
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -224,7 +231,7 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
return ptr;
}
@ -232,6 +239,12 @@ public:
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -252,7 +265,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}

View File

@ -61,7 +61,7 @@ template <
typename Element_,
typename Layout_,
typename ThreadMap_,
int AccessSize = ThreadMap_::kElementsPerAccess
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dFpropActivationTileAccessIteratorOptimized {
public:
@ -75,7 +75,7 @@ public:
using Layout = Layout_;
using TensorCoord = typename Layout::TensorCoord;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, AccessSize>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
@ -86,6 +86,11 @@ public:
using Mask = uint64_t;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
@ -98,8 +103,6 @@ public:
using Params = Conv2dFpropActivationIteratorOptimizedParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
@ -213,10 +216,10 @@ public:
}
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
}
set_iteration_index(0);
}
@ -260,56 +263,7 @@ private:
pointer_[s] += byte_offset;
}
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask_(bool clear, int index) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
// artifact in which control flow instructions are generated. Instead, our
// intent is to predicate the mov instructions.
#if defined(__CUDA_ARCH__)
asm volatile(
"{\n"
" .reg .pred p;\n"
" .reg .u32 m;"
" mov.u32 m, %2;"
" setp.ne.b32 p, %1, 0;\n"
" @p mov.u32 m, 0;\n"
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(masks_[s][index][0])
:
"r"((int)clear),
"r"(masks_[s][index][0])
);
asm volatile(
"{\n"
" .reg .pred p;\n"
" .reg .u32 m;"
" mov.u32 m, %2;"
" setp.ne.b32 p, %1, 0;\n"
" @p mov.u32 m, 0;\n"
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(masks_[s][index][1])
:
"r"((int)clear),
"r"(masks_[s][index][1])
);
#else
if (clear) {
masks_[s][index][0] = 0;
masks_[s][index][1] = 0;
}
#endif
}
}
public:
/// Overrides the internal iteration index
@ -354,23 +308,33 @@ public:
filter_c_ += params_.filter_c_delta;
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
}
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask() {
void clear_mask(bool clear = true) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
masks_[s][v][0] = Mask(0);
masks_[s][v][1] = Mask(0);
masks_[s][v][0] = clear ? 0 : masks_[s][v][0];
masks_[s][v][1] = clear ? 0 : masks_[s][v][1];
}
}
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask(int v, bool clear = true) {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
masks_[s][v][0] = clear ? 0 : masks_[s][v][0];
masks_[s][v][1] = clear ? 0 : masks_[s][v][1];
}
}
CUTLASS_HOST_DEVICE
@ -396,7 +360,6 @@ public:
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
@ -419,7 +382,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % AccessSize) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}

View File

@ -59,7 +59,8 @@ template <
typename Shape_,
typename Element_,
typename Layout_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dFpropFilterTileAccessIteratorAnalytic {
public:
@ -72,7 +73,7 @@ public:
using Element = Element_;
using Layout = Layout_;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -81,7 +82,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
@ -94,14 +100,13 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
int filter_r_;
@ -142,8 +147,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -187,7 +194,7 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K &&
coord.c() < problem_size_.C;
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -197,12 +204,18 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -223,7 +236,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
if (problem_size.K % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
@ -250,5 +263,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -61,7 +61,7 @@ template <
typename Element_,
typename Layout_,
typename ThreadMap_,
int AccessSize = ThreadMap_::kElementsPerAccess
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dFpropFilterTileAccessIteratorOptimized{
public:
@ -74,7 +74,7 @@ public:
using Element = Element_;
using Layout = Layout_;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, AccessSize>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -83,15 +83,18 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
static_assert(ThreadMap::Iterations::kContiguous == 1,
"Require Iterations::kContiguous == 1");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//
@ -170,6 +173,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
predicates_[v_idx] |= (pred << s);
@ -178,7 +182,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
}
pointer_ += (
@ -188,41 +192,11 @@ public:
set_iteration_index(0);
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask_(bool clear, int index) {
// We are using inline PTX assembly here to avoid an CUDA C++ compilation
// artifact in which control flow instructions are generated. Instead, our
// intent is to predicate the mov instructions.
#if defined(__CUDA_ARCH__)
asm volatile(
"{\n"
" .reg .pred p;\n"
" .reg .u32 m;"
" mov.u32 m, %2;"
" setp.ne.b32 p, %1, 0;\n"
" @p mov.u32 m, 0;\n"
" mov.u32 %0, m;\n"
"}\n"
:
"=r"(predicates_[index])
:
"r"((int)clear),
"r"(predicates_[index])
);
#else
if (clear) {
predicates_[index] = 0;
}
#endif
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
@ -246,15 +220,21 @@ public:
next = params_.inc_next_c;
filter_c_ += params_.filter_c_delta;
}
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask_(filter_c_ + v_idx * AccessSize >= problem_size_.C, v_idx);
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
}
pointer_ += next;
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask(int v, bool clear = true) {
predicates_[v] = clear ? 0u : predicates_[v];
}
/// Returns true if the current coordinate is within the filter tensor W
CUTLASS_HOST_DEVICE
bool valid() {
@ -274,7 +254,6 @@ public:
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
@ -301,7 +280,7 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % AccessSize) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}

View File

@ -68,6 +68,7 @@ public:
using Params = typename TileAccessIterator::Params;
static int const kConvDim = TileAccessIterator::kConvDim;
using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<
@ -130,18 +131,20 @@ public:
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < tile_access_iterator_.kAccessesPerVector; ++v) {
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[(c + s * ThreadMap::Iterations::kContiguous) * tile_access_iterator_.kAccessesPerVector + v],
frag_ptr[idx],
tile_access_iterator_.get() + pointer_offset,
tile_access_iterator_.valid()
);
++tile_access_iterator_;
}
}

View File

@ -58,7 +58,8 @@ namespace threadblock {
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dWgradActivationTileAccessIteratorAnalytic {
public:
@ -70,7 +71,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -79,7 +80,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
@ -89,14 +95,13 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
// Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k
@ -149,8 +154,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -173,9 +180,19 @@ public:
/// by the iterator.
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int r, s, c;
int r = filter_r_[iteration_contiguous_];
int s = filter_s_[iteration_contiguous_];
if (kAccessesPerVector == 1) {
r = filter_r_[iteration_contiguous_];
s = filter_s_[iteration_contiguous_];
c = filter_c_[iteration_contiguous_];
} else {
c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C;
int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C;
s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S;
int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S;
r = filter_r_[iteration_contiguous_] + wrap_s;
}
if (problem_size_.mode == Mode::kConvolution) {
r = (problem_size_.R - 1 - r);
@ -184,14 +201,14 @@ public:
int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q);
int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q);
int p = residual / problem_size_.Q;
int q = residual % problem_size_.Q;
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
return TensorCoord(n, h, w, c);
}
/// Returns true if the current coordinate is within the activation tensor x
@ -201,8 +218,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.H &&
coord.w() >= 0 && coord.w() < problem_size_.W &&
coord.c() < problem_size_.C;
coord.w() >= 0 && coord.w() < problem_size_.W;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -218,6 +234,12 @@ public:
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -237,13 +259,12 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
if (problem_size.K % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -57,7 +57,8 @@ namespace threadblock {
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dWgradActivationTileAccessIteratorOptimized {
public:
@ -69,7 +70,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -78,7 +79,12 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
@ -88,14 +94,13 @@ public:
using Params = Conv2dWgradActivationIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dWgradActivationIteratorOptimizedParams const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
// Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k
@ -153,9 +158,8 @@ public:
s = (problem_size_.S - 1 - s);
}
precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h;
precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w;
precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h;
precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w;
}
// initialize n, p, q offset for every strided iteration
@ -170,8 +174,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -194,6 +200,31 @@ public:
/// by the iterator.
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int r = precomputed_filter_r_[iteration_contiguous_];
int s = precomputed_filter_s_[iteration_contiguous_];
int c = filter_c_[iteration_contiguous_];
if (kAccessesPerVector > 1) {
int wrap_c;
params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements);
if (problem_size_.mode == Mode::kConvolution) {
s -= (problem_size_.dilation_w * wrap_c);
int wrap_s = (s == -problem_size_.pad_w - problem_size_.dilation_w);
s = wrap_s ? (-problem_size_.pad_w + (problem_size_.S - 1) * problem_size_.dilation_w): s;
r -= (problem_size_.dilation_h * wrap_s);
} else {
s += (problem_size_.dilation_w * wrap_c);
int wrap_s = (s == (-problem_size_.pad_w + problem_size_.S * problem_size_.dilation_w));
s = wrap_s ? -problem_size_.pad_w : s;
r += (problem_size_.dilation_h * wrap_s);
}
}
// The subseqnet fast_divmod() operations are equivalent to the following logical computation:
//
@ -209,10 +240,10 @@ public:
params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]);
params_.q_divmod(p, q, residual);
int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];
int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_];
int h = p * problem_size_.stride_h + r;
int w = q * problem_size_.stride_w + s;
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
return TensorCoord(n, h, w, c);
}
/// Returns true if the current coordinate is within the activation tensor x
@ -222,8 +253,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() >= 0 && coord.h() < problem_size_.H &&
coord.w() >= 0 && coord.w() < problem_size_.W &&
coord.c() < problem_size_.C;
coord.w() >= 0 && coord.w() < problem_size_.W;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -239,6 +269,12 @@ public:
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dWgradActivationTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -258,14 +294,14 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
if (problem_size.K % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock

View File

@ -58,7 +58,8 @@ namespace threadblock {
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dWgradOutputGradientTileAccessIteratorAnalytic {
public:
@ -70,7 +71,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -80,6 +81,11 @@ public:
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
@ -89,14 +95,13 @@ public:
using Params = Conv2dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
int filter_k_[ThreadMap::Iterations::kContiguous];
@ -143,8 +148,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -187,7 +194,7 @@ public:
return coord.n() < problem_size_.N &&
coord.h() < problem_size_.P &&
coord.w() < problem_size_.Q &&
coord.c() < problem_size_.K;
(coord.c() + iteration_vector_ * AccessType::kElements) < problem_size_.K;
}
/// Returns a pointer to the vector starting at the current coordinate
@ -197,12 +204,18 @@ public:
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -222,14 +235,14 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
@ -237,5 +250,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -57,7 +57,8 @@ namespace threadblock {
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dWgradOutputGradientTileAccessIteratorOptimized {
public:
@ -69,7 +70,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@ -79,6 +80,11 @@ public:
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
@ -88,17 +94,16 @@ public:
using Params = Conv2dWgradOutputGradientIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv2dWgradOutputGradientIteratorOptimizedParams const &params_;
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
uint32_t predicates_;
uint32_t predicates_[kAccessesPerVector];
int filter_k_;
int offset_npq_;
@ -115,7 +120,7 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_(0),
predicates_{0},
filter_k_(0),
offset_npq_(0) {
@ -132,13 +137,16 @@ public:
int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous;
int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided;
bool predicate = valid_(at_(offset_npq, filter_k));
uint32_t pred = (predicate ? 1u : 0);
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
predicates_ |= (pred << pred_idx);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements));
uint32_t pred = (predicate ? 1u : 0);
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
predicates_[v] |= (pred << pred_idx);
}
}
}
@ -165,8 +173,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
@ -185,7 +195,11 @@ public:
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) {
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
predicates_ = (predicates_ & (~kClearMask));
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
predicates_[v] = (predicates_[v] & (~kClearMask));
}
}
}
@ -231,7 +245,7 @@ public:
bool valid() const {
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
return (predicates_ & (1u << pred_idx));
return (predicates_[iteration_vector_] & (1u << pred_idx));
}
/// Returns a pointer to the vector starting at the current coordinate
@ -242,12 +256,18 @@ public:
pointer_ +
iteration_strided_ * params_.offset_next_strided +
iteration_contiguous_ * params_.offset_next_contiguous
);
) + iteration_vector_;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@ -267,14 +287,14 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
@ -282,5 +302,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -79,11 +79,10 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or larger.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
@ -261,5 +260,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -82,8 +82,7 @@ public:
static StrideSupport const kStrideSupport = StrideSupport_;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static int const kAccessesPerVector = 1;
//
// Parameters structure
@ -217,7 +216,8 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
predicates_ = (predicates_ & (~kClearMask));
}
}
@ -281,5 +281,3 @@ public:
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -93,11 +93,10 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Simpligying assertions
@ -328,11 +327,11 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -86,7 +86,7 @@ public:
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
using Coord3D = Coord<3>;
static int const kAccessesPerVector = 1;
using Mask = uint64_t;
//
@ -101,8 +101,6 @@ public:
using Params = Conv3dDgradOutputGradientIteratorOptimizedParams;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;
@ -403,7 +401,6 @@ public:
}
clear_mask_(filter_k_ >= problem_size_.K);
}

View File

@ -81,6 +81,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
//
// Simplifying assertions
@ -94,8 +95,6 @@ public:
using Params = Conv3dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -82,7 +82,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
using Mask = uint64_t;
//
@ -97,8 +97,6 @@ public:
using Params = Conv3dFpropActivationIteratorOptimizedParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Conv3dFpropActivationIteratorOptimizedParams<Layout> const &params_;

View File

@ -80,6 +80,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
//
// Simplifying assertions
@ -93,8 +94,6 @@ public:
using Params = Conv3dAnalyticParams<Layout>;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
private:
Params const &params_;

View File

@ -82,6 +82,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
//
// Simplifying assertions
@ -89,8 +90,6 @@ public:
static_assert(ThreadMap::Iterations::kContiguous == 1,
"Require Iterations::kContiguous == 1");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//
@ -156,7 +155,7 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_(0),
predicates_{0},
filter_trs_(0),
filter_c_(0) {

View File

@ -79,11 +79,11 @@ public:
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -79,12 +79,10 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -78,12 +78,10 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -79,12 +79,10 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
//
// Parameters structure
//

View File

@ -216,6 +216,7 @@ public:
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
@ -244,6 +245,7 @@ public:
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
@ -289,16 +291,17 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
@ -313,17 +316,18 @@ public:
this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}