Example 43 - DualGemm (#670)

* Ex50 wip

* IS_PROFILING mode

* MultiStage2 - but is slower

* Add SwiGLU

* Support SplitKSerial reduction
Support not storing D0/D1
Cleanup code

* Option to disable bias

* Renumber example

* Fix build

* Remove references to pb_size_0 / pb_size_1

* Add support for bf16 inputs with float accum

* small changes

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Authored by dan_the_3rd on 2022-10-26 20:04:42 +02:00, committed by GitHub
parent 8c1bf9b784
commit 1b4e24470a
12 changed files with 3728 additions and 0 deletions

@@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
43_dual_gemm
dual_gemm.cu
)
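# A typical way to build and run this example from a CUTLASS build tree
# (the target name comes from the rule above; exact paths depend on your setup):
#   cmake .. -DCUTLASS_NVCC_ARCHS=80
#   make 43_dual_gemm
#   ./examples/43_dual_gemm/43_dual_gemm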

@@ -0,0 +1,457 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Performs a dual gemm in one fused kernel:
```
D0 = epilogue0(X @ B0, C0)
D1 = epilogue1(X @ B1, C1)
D2 = element_wise(D0, D1)
```
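For example, instantiating the element-wise operator with `LeftSiLUAndMul`
(as the accompanying example does) yields a SwiGLU-style gated activation:
D2 = SiLU(X @ B0 + C0) * (X @ B1 + C1), assuming alpha = 1 and beta = 1.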
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "../kernel/dual_gemm.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp0_,
typename EpilogueOutputOp1_,
typename EpilogueOutputOp2_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
bool StoreD0 = true,
bool StoreD1 = true,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator>
class DualGemm {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp0 = EpilogueOutputOp0_;
using EpilogueOutputOp1 = EpilogueOutputOp1_;
using EpilogueOutputOp2 = EpilogueOutputOp2_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount;
static bool const kSplitKSerial = SplitKSerial;
static bool constexpr kStoreD0 = StoreD0;
static bool constexpr kStoreD1 = StoreD1;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
using LayoutScaleBias = layout::RowMajor;
/// Define the kernel
/// Define the threadblock-scoped matrix multiply-accumulate
static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
static_assert(kStages >= 3, "Only multistage is implemented");
using Mma = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape,
InstructionShape, Stages, Operator>::ThreadblockMma;
using DualMma = threadblock::DualMmaMultistage<
typename Mma::Shape,
typename Mma::IteratorA,
typename Mma::SmemIteratorA,
Mma::kCacheOpA,
typename Mma::IteratorB,
typename Mma::SmemIteratorB,
Mma::kCacheOpB,
typename Mma::ElementC,
typename Mma::LayoutC,
typename Mma::Policy,
Mma::kStages,
SharedMemoryClearOption::kNone
>;
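// Note: the dual mainloop stages the A operand once and shares it between the
// two GEMMs; only the B0 and B1 tiles are loaded and staged separately.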
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue0 =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0,
EpilogueOutputOp0::kCount>::Epilogue;
using Epilogue1 =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1,
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using DualGemmKernel = kernel::DualGemm<
DualMma,
Epilogue0, Epilogue1, EpilogueOutputOp2,
ThreadblockSwizzle, kSplitKSerial,
kStoreD0, kStoreD1>;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementC, LayoutC> ref_D0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
TensorRef<ElementC, LayoutC> ref_D2;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
typename EpilogueOutputOp2::Params epilogue2;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementC, LayoutC> ref_D0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
TensorRef<ElementC, LayoutC> ref_D2_,
typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
typename EpilogueOutputOp2::Params epilogue2_ =
typename EpilogueOutputOp2::Params(),
int split_k_slices_ = 1
):
problem_size(problem_size_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_D0(ref_D0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
ref_D2(ref_D2_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
epilogue2(epilogue2_),
split_k_slices(split_k_slices_) {
}
};
private:
/// Kernel parameters object
typename DualGemmKernel::Params params_;
public:
/// Constructs the GEMM.
DualGemm() = default;
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
if (kStoreD0 != (args.ref_D0.data() != nullptr)) {
return Status::kErrorInternal;
}
if (kStoreD1 != (args.ref_D1.data() != nullptr)) {
return Status::kErrorInternal;
}
Status status = DualGemmKernel::can_implement(
args.problem_size,
args.ref_A0.non_const_ref(),
args.ref_B0.non_const_ref(),
args.ref_C0.non_const_ref(),
args.ref_D0,
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
args.ref_D2
);
if (status != Status::kSuccess) {
return status;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes;
}
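// Worked example: with the 4096x4096x8192 problem from the accompanying
// example, 128x64 threadblock tiles, and (hypothetically) kSplitKSerial
// enabled with 2 slices, the tiled shape is (32, 64, 2) and the workspace is
// 32 * 64 * sizeof(int) = 8192 bytes -- one lock per output tile.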
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// Initialize the Params structure
params_ = typename DualGemmKernel::Params{
args.problem_size,
grid_shape,
args.ref_A0.non_const_ref(),
args.ref_B0.non_const_ref(),
args.ref_C0.non_const_ref(),
args.ref_D0,
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
args.ref_D2,
args.epilogue0,
args.epilogue1,
args.epilogue2,
reinterpret_cast<int *>(workspace),
};
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
params_.ref_D0.reset(args.ref_D0.data());
params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
params_.ref_D1.reset(args.ref_D1.data());
params_.ref_D2.reset(args.ref_D2.data());
params_.output_op_0 = args.epilogue0;
params_.output_op_1 = args.epilogue1;
params_.output_op_2 = args.epilogue2;
params_.semaphore = reinterpret_cast<int *>(workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(DualGemmKernel::kThreadCount, 1, 1);
cudaError_t result;
int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<DualGemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,262 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief CUTLASS Dual-GEMM Example.
Fused kernel that outputs `D0` and `D1`.
We assume that B0/B1 have the same shape/layout
```
D0 = epilogue0(X @ B0, C0)
D1 = epilogue1(X @ B1, C1)
D2 = element_wise(D0, D1)
```
D0 and D1 are optionally written to global memory (controlled by `kStoreD0` / `kStoreD1`)
*/
// #define IS_PROFILING
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/dual_gemm.h"
#include "thread/left_silu_and_mul.h"
#include "dual_gemm_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192);
constexpr int kStages = 3;
constexpr bool kSplitKSerial = false;
constexpr bool kUseBias = true;
#if 0
using ElementOperandA = cutlass::bfloat16_t;
using ElementOperandB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ElementCompute = float;
#else
using ElementOperandA = cutlass::half_t;
using ElementOperandB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
#endif
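// Flipping the `#if 0` above to `#if 1` selects the bf16-input / float-accumulation
// path mentioned in the commit message; the default path below runs entirely in
// half precision.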
constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : (
// No bias
kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing
);
using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
kScaleType
>;
using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
kScaleType
>;
using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput,
ElementCompute
>;
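// Element-wise, the second-stage epilogue computes D2 = SiLU(D0) * D1,
// i.e. the SwiGLU gating referenced in the commit message.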
const ElementCompute alpha0 = ElementCompute(1);
const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0);
const ElementCompute alpha1 = ElementCompute(1);
const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0);
bool run_nonfused_gemm_f16_sm80() {
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Gemm0 = cutlass::gemm::device::Gemm<
ElementOperandA,
cutlass::layout::RowMajor,
ElementOperandB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp0,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
kStages,
8,
8,
kSplitKSerial
>;
using Gemm1 = cutlass::gemm::device::Gemm<
ElementOperandA,
cutlass::layout::RowMajor,
ElementOperandB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
kStages,
8,
8,
kSplitKSerial
>;
NonFusedDualGemmRun<Gemm0, Gemm1> nonFusedGemm;
std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n";
bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return pass;
}
template <typename T>
struct LeftSiLUAndMul {
struct Params{};
CUTLASS_HOST_DEVICE LeftSiLUAndMul(Params p) {}
CUTLASS_HOST_DEVICE void set_k_partition(int, int) {}
CUTLASS_HOST_DEVICE T operator() (
T const &lhs,
T const &rhs) const {
cutlass::epilogue::thread::SiLu<T> silu;
cutlass::multiplies<T> mul;
auto silu_lhs = silu(lhs);
return mul(silu_lhs, rhs);
}
template <int kCount>
CUTLASS_HOST_DEVICE cutlass::Array<T, kCount> operator() (
cutlass::Array<T, kCount> const &lhs,
cutlass::Array<T, kCount> const &rhs) const {
cutlass::epilogue::thread::SiLu<T> silu;
cutlass::multiplies<T> mul;
auto silu_lhs = silu(lhs);
return mul(silu_lhs, rhs);
}
};
bool run_fused_gemm_f16_sm80_shmem() {
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// Whether the intermediate GEMM outputs D0/D1 are stored to global memory
constexpr bool kStoreD0 = true;
constexpr bool kStoreD1 = true;
using DualGemm = cutlass::gemm::device::DualGemm<
ElementOperandA,
cutlass::layout::RowMajor,
ElementOperandB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
EpilogueOutputOp2,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
kStages,
kStoreD0,
kStoreD1,
kSplitKSerial
>;
DualFusedGemmRun<DualGemm> fusedGemm;
std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n";
bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_f16_sm80,
&run_fused_gemm_f16_sm80_shmem
};
std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial);
return testRun(80, funcs, test_name);
}
////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,829 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
if((val1) <= (val2)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
if(!(val)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
template <
typename OutputOp,
typename Element,
typename Layout>
struct TensorEpilogueForEachFunc {
/// View type
using TensorView = cutlass::TensorView<Element, Layout>;
/// Coordinate in tensor's index space
using TensorCoord = typename TensorView::TensorCoord;
/// Parameters structure
struct Params {
//
// Data members
//
TensorView view_x0;
TensorView view_x1;
TensorView view_y;
OutputOp output_op;
//
// Methods
//
Params(
TensorView view_x0_ = TensorView(),
TensorView view_x1_ = TensorView(),
TensorView view_y_ = TensorView(),
OutputOp output_op_ = OutputOp(typename OutputOp::Params{})
):
view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) {
}
};
Params params;
CUTLASS_DEVICE
TensorEpilogueForEachFunc(Params const &params): params(params) {
}
CUTLASS_DEVICE
void operator()(TensorCoord const &coord) {
Element const & x0 = params.view_x0.at(coord);
Element const & x1 = params.view_x1.at(coord);
Element& y = params.view_y.at(coord);
y = params.output_op(x0, x1);
}
};
template <
typename OutputOp,
typename Element,
typename Layout>
void TensorEpilogueForEach(
cutlass::TensorView<Element, Layout> x0,
cutlass::TensorView<Element, Layout> x1,
cutlass::TensorView<Element, Layout> y) {
using Func = TensorEpilogueForEachFunc<OutputOp, Element, Layout>;
using Params = typename Func::Params;
cutlass::reference::device::TensorForEach<Func, Layout::kRank, Params>(
y.extent(),
Params(x0, x1, y)
);
}
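// Helper used to build the D2 reference on the device:
// y[coord] = output_op(x0[coord], x1[coord]) for every coordinate in y's extent.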
////////////////////////////////////////////////////////////////////////////////
template <typename Gemm0_, typename Gemm1_>
struct NonFusedDualGemmRun
{
using Gemm0 = Gemm0_;
using Gemm1 = Gemm1_;
using ElementAccumulator = typename Gemm0::ElementAccumulator;
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;
//
// Methods
//
NonFusedDualGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = false,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
cutlass::reference::host::TensorFill(
tensor_D0.host_view());
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_C0.sync_device();
tensor_Bias0.sync_device();
tensor_D0.sync_device();
reference_D0.sync_device();
tensor_B1.sync_device();
tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device();
reference_D1.sync_device();
//
// Initialize the GEMM operator
//
int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1;
typename Gemm0::Arguments arguments_0{
problem_size,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
tensor_D0.device_ref(),
{alpha0, beta0},
split_k_slices
};
split_k_slices = Gemm1::kSplitKSerial ? 2 : 1;
typename Gemm1::Arguments arguments_1{
problem_size,
tensor_A0.device_ref(),
tensor_B1.device_ref(),
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
tensor_D1.device_ref(),
{alpha1, beta1},
split_k_slices
};
Gemm0 gemm_op_0;
Gemm1 gemm_op_1;
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace0(gemm_op_0.get_workspace_size(arguments_0));
cutlass::device_memory::allocation<uint8_t> workspace1(gemm_op_1.get_workspace_size(arguments_1));
cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get());
CUTLASS_CHECK(status);
status = gemm_op_1.initialize(arguments_1, workspace1.get());
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
status = gemm_op_1();
CUTLASS_CHECK(status);
}
#ifdef IS_PROFILING
return true;
#endif
//
// Run the GEMM
//
cudaEvent_t start, stop1, stop2;
cudaEventCreate(&start);
cudaEventCreate(&stop1);
cudaEventCreate(&stop2);
cudaEventRecord(start);
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop2);
cudaDeviceSynchronize();
float gemm0Time, gemm1Time, totalTime;
cudaEventElapsedTime(&gemm0Time, start, stop1);
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
tensor_D0.sync_host();
tensor_D1.sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename Gemm0::ElementA, typename Gemm0::LayoutA,
typename Gemm0::ElementB, typename Gemm0::LayoutB,
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
ElementAccumulator, typename Gemm0::Operator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename Gemm1::ElementA, typename Gemm1::LayoutA,
typename Gemm1::ElementB, typename Gemm1::LayoutB,
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
ElementAccumulator, typename Gemm1::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size,
alpha1,
tensor_A0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed0 = cutlass::reference::host::TensorEquals(
reference_D0.host_view(),
tensor_D0.host_view());
CHECK_TRUE(passed0);
bool passed1 = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed1);
if (!passed0 || !passed1) {
std::stringstream fname;
fname << "error_DualGemm_device_nonfused.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}
return passed0 && passed1;
}
};
template <typename DualGemm_>
struct DualFusedGemmRun
{
using DualGemm = DualGemm_;
using ElementAccumulator = typename DualGemm::ElementAccumulator;
using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute;
using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;
//
// Methods
//
DualFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(1),
bool relu = false,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename DualGemm::ElementA,
typename DualGemm::LayoutA> tensor_A0(problem_size.mk());
cutlass::HostTensor<
typename DualGemm::ElementB,
typename DualGemm::LayoutB> tensor_B0(problem_size.kn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_C0(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()});
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_D0(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> reference_D0(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementB,
typename DualGemm::LayoutB> tensor_B1(problem_size.kn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_C1(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()});
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_D1(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_D2(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> reference_D1(problem_size.mn());
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> reference_D2(problem_size.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
cutlass::reference::host::TensorFill(
tensor_D0.host_view());
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
tensor_D2.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D2.host_view());
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_C0.sync_device();
tensor_Bias0.sync_device();
tensor_B1.sync_device();
tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D0.sync_device();
tensor_D1.sync_device();
tensor_D2.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
reference_D2.sync_device();
//
// Initialize the GEMM operator
//
int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> nullptr_ref{};
decltype(nullptr_ref) ref_B0, ref_B1;
if (beta0 != ElementCompute(0)) {
ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)};
}
if (beta1 != ElementCompute(0)) {
ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
}
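// A default-constructed (null) TensorRef disables the corresponding bias read;
// likewise, kStoreD0 / kStoreD1 control whether the intermediate D0 / D1 are
// written out, and can_implement() checks that the refs passed below agree
// with those flags.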
typename DualGemm::Arguments arguments{
problem_size,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
ref_B0,
DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
tensor_B1.device_ref(),
ref_B1,
DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
tensor_D2.device_ref(),
{alpha0, beta0},
{alpha1, beta1},
{},
split_k_slices
};
DualGemm b2b_gemm_op;
cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = b2b_gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
#ifdef IS_PROFILING
return true;
#endif
//
// Run the GEMM
//
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
for(int i = 0; i < runs; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop);
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
tensor_D0.sync_host();
tensor_D1.sync_host();
tensor_D2.sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename DualGemm::ElementA, typename DualGemm::LayoutA,
typename DualGemm::ElementB, typename DualGemm::LayoutB,
typename DualGemm::ElementC, typename DualGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename DualGemm::ElementA, typename DualGemm::LayoutA,
typename DualGemm::ElementB, typename DualGemm::LayoutB,
typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute,
ElementAccumulator, typename DualGemm::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size,
alpha1,
tensor_A0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
TensorEpilogueForEach<EpilogueOutputOp2>(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
reference_D2.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0);
bool passed_out0 = true;
if (DualGemm::kStoreD0) {
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
passed_out0 = cutlass::reference::host::TensorEquals(
reference_D0.host_view(),
tensor_D0.host_view());
}
CHECK_TRUE(passed_out0);
bool passed_out1 = true;
if (DualGemm::kStoreD1) {
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
passed_out1 = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
}
CHECK_TRUE(passed_out1);
bool passed_out2 = cutlass::reference::host::TensorEquals(
reference_D2.host_view(),
tensor_D2.host_view());
CHECK_TRUE(passed_out2);
bool passed = passed_out0 && passed_out1 && passed_out2;
if (!passed)
{
std::stringstream fname;
fname << "error_DualGemm_device_fused.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference0 =\n" << reference_D0.host_view()
<< "\nComputed0 =\n" << tensor_D0.host_view()
<< "\n\nReference1 =\n" << reference_D1.host_view()
<< "\nComputed1 =\n" << tensor_D1.host_view()
<< "\n\nReference2 =\n" << reference_D2.host_view()
<< "\nComputed2 =\n" << tensor_D2.host_view();
}
//std::cout << "A0 " << tensor_A0.host_view() << std::endl;
// std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
// std::cout << "reference_D1 " << reference_D1.host_view() << std::endl;
// std::cout << "reference_D2 " << reference_D2.host_view() << std::endl;
//std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
return passed;
}
};
////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,489 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for the fused dual-GEMM kernel. Batching is not supported; split-K is supported only through serial reduction (see `SplitKSerial`).
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "../threadblock/dual_mma_multistage.h"
#include "../threadblock/dual_epilogue.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue0_, ///! Epilogue for the first GEMM
typename Epilogue1_, ///! Epilogue for the second GEMM
typename OutputOp2_, ///! Element-wise output operator combining D0 and D1
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled.
bool StoreD0,
bool StoreD1
>
struct DualGemm {
using DualMma = DualMma_;
using Epilogue0 = Epilogue0_;
using Epilogue1 = Epilogue1_;
using OutputOp0 = typename Epilogue0::OutputOp;
using OutputOp1 = typename Epilogue1::OutputOp;
using OutputOp2 = OutputOp2_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static constexpr bool kStoreD0 = StoreD0;
static constexpr bool kStoreD1 = StoreD1;
using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue<
typename Epilogue0::Shape,
typename Epilogue0::WarpMmaOperator,
Epilogue0::kPartitionsK,
typename Epilogue0::OutputTileIterator,
typename Epilogue0::AccumulatorFragmentIterator,
typename Epilogue0::WarpTileIterator,
typename Epilogue0::SharedLoadIterator,
OutputOp0,
OutputOp1,
OutputOp2,
typename Epilogue0::Padding,
kStoreD0,
kStoreD1,
Epilogue0::kFragmentsPerIteration,
true // IterationsUnroll
>;
static bool const kSplitKSerial = SplitKSerial;
static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1),
"Split-K serial requires buffers for D0/D1 for reduction");
/// Warp count (concept: GemmShape)
using WarpCount0 = typename DualMma::WarpCount;
static int const kThreadCount = 32 * WarpCount0::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
// Mma0
typename DualMma::IteratorA::Params params_A0;
typename DualMma::IteratorA::TensorRef ref_A0;
typename DualMma::IteratorB::Params params_B0;
typename DualMma::IteratorB::TensorRef ref_B0;
typename Epilogue0::OutputTileIterator::Params params_C0;
typename Epilogue0::OutputTileIterator::TensorRef ref_C0;
typename Epilogue0::OutputTileIterator::Params params_D0;
typename Epilogue0::OutputTileIterator::TensorRef ref_D0;
typename OutputOp0::Params output_op_0;
// Mma1
typename DualMma::IteratorB::Params params_B1;
typename DualMma::IteratorB::TensorRef ref_B1;
typename Epilogue1::OutputTileIterator::Params params_C1;
typename Epilogue1::OutputTileIterator::TensorRef ref_C1;
typename Epilogue1::OutputTileIterator::Params params_D1;
typename Epilogue1::OutputTileIterator::TensorRef ref_D1;
typename OutputOp1::Params output_op_1;
typename Epilogue1::OutputTileIterator::Params params_D2;
typename Epilogue1::OutputTileIterator::TensorRef ref_D2;
typename OutputOp2::Params output_op_2;
int *semaphore;
int gemm_k_size;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
// Mma0: D0 = A @ B0 + C0
typename DualMma::IteratorA::TensorRef ref_A0,
typename DualMma::IteratorB::TensorRef ref_B0,
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
// Mma1: D1 = A @ B1 + C1
typename DualMma::IteratorB::TensorRef ref_B1,
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
typename Epilogue1::OutputTileIterator::TensorRef ref_D2,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(),
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
// Mma0
params_A0(ref_A0.layout()),
ref_A0(ref_A0),
params_B0(ref_B0.layout()),
ref_B0(ref_B0),
params_C0(ref_C0.layout()),
ref_C0(ref_C0),
params_D0(ref_D0.layout()),
ref_D0(ref_D0),
// Mma1
params_B1(ref_B1.layout()),
ref_B1(ref_B1),
params_C1(ref_C1.layout()),
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
params_D2(ref_D2.layout()),
ref_D2(ref_D2),
output_op_0(output_op_0),
output_op_1(output_op_1),
output_op_2(output_op_2) {
int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * DualMma::Shape::kK;
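// Example: with K = 8192, DualMma::Shape::kK = 32, and 2 split-K slices,
// total_gemm_k_iterations = 256, gemm_k_iterations = 128, and each slice
// covers gemm_k_size = 4096 elements along K.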
semaphore = workspace;
}
};
/// Shared memory storage structure
union SharedStorage {
typename DualMma::SharedStorage main_loop;
typename DualEpilogue::SharedStorage epilogue;
};
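// The mainloop and the epilogue never use shared memory concurrently, so a
// union overlaps their storage and the CTA only needs
// max(sizeof(main_loop), sizeof(epilogue)) bytes of shared memory.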
//
// Methods
//
CUTLASS_HOST_DEVICE
DualGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size,
typename DualMma::IteratorA::TensorRef ref_A0,
typename DualMma::IteratorB::TensorRef ref_B0,
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
typename DualMma::IteratorB::TensorRef ref_B1,
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {
static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D0, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D2, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * DualMma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size,
threadblock_tile_offset.n() * DualMma::Shape::kN
};
cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size,
threadblock_tile_offset.n() * DualMma::Shape::kN
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k =
(params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ?
params.problem_size.k() :
(threadblock_tile_offset.k() + 1) * params.gemm_k_size;
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename DualMma::IteratorA iterator_A0(
params.params_A0,
params.ref_A0.data(),
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A0);
typename DualMma::IteratorB iterator_B0(
params.params_B0,
params.ref_B0.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B0);
typename DualMma::IteratorB iterator_B1(
params.params_B1,
params.ref_B1.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
typename DualMma::FragmentC accum0;
typename DualMma::FragmentC accum1;
accum0.clear();
accum1.clear();
DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
if (!kSplitKSerial || gemm_k_iterations > 0) {
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations,
accum0, accum1,
iterator_A0, iterator_B0, iterator_B1,
accum0, accum1);
}
//
// Epilogue
//
OutputOp0 output_op_0(params.output_op_0);
OutputOp1 output_op_1(params.output_op_1);
OutputOp2 output_op_2(params.output_op_2);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * DualMma::Shape::kM,
threadblock_tile_offset.n() * DualMma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue0::OutputTileIterator iterator_C0(
params.params_C0,
params.ref_C0.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue1::OutputTileIterator iterator_C1(
params.params_C1,
params.ref_C1.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue0::OutputTileIterator iterator_D0(
params.params_D0,
params.ref_D0.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue1::OutputTileIterator iterator_D1(
params.params_D1,
params.ref_D1.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue1::OutputTileIterator iterator_D2(
params.params_D2,
params.ref_D2.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
DualEpilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C0 = iterator_D0;
iterator_C1 = iterator_D1;
}
semaphore.wait(threadblock_tile_offset.k());
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
typename Epilogue0::OutputTileIterator source_iters[] = {
iterator_C0, iterator_C1
};
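// D2 (e.g. the SwiGLU output) depends on the fully reduced D0/D1, so under serial
// split-K only the last k-slice forms and stores it; otherwise every CTA writes it.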
const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1);
epilogue(
output_op_0, output_op_1, output_op_2,
iterator_D0, iterator_D1, iterator_D2,
accum0, accum1,
source_iters,
writeToD2
);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
__threadfence();
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
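Read as a whole, the kernel above fuses two GEMMs that share the A operand and combines their results in one pass. In the informal notation below, which restates the code rather than introducing identifiers from it:

D_0 = \alpha_0 \, A B_0 + \beta_0 \, C_0, \qquad
D_1 = \alpha_1 \, A B_1 + \beta_1 \, C_1, \qquad
D_2 = f(D_0, D_1)

where f is the second-level output operator (SiLU-and-multiply in the SwiGLU configuration). Under serial split-K, partial sums for D0/D1 are accumulated across k-slices through the semaphore, and only the final slice evaluates and stores D2.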

View File

@ -0,0 +1,95 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include <string>
#include <vector>
// Run tests on GPUs
int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
bool supported = false;
int arch_major = arch / 10;
int arch_minor = arch - arch / 10 * 10;
if(arch_major >= 8) {
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with the CUDA 11 Toolkit to run these examples.
if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
supported = true;
}
}
else if(arch_major >= 7) {
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
//
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
supported = true;
}
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!(props.major == arch_major && props.minor == arch_minor)) {
supported = false;
}
if (!supported) {
// Returning zero so this test passes on older toolkits and unsupported architectures. Its actions are a no-op.
std::cout << "This example isn't supported on the current architecture" << std::endl;
return 0;
}
bool pass = true;
std::cout << "Device: " << props.name << std::endl;
std::cout << "Arch: SM" << arch << std::endl;
std::cout << "Test: " << test_name << std::endl;
for(auto func : test_funcs) {
pass &= func();
}
if(pass)
return 0;
else
return -1;
}
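A minimal usage sketch for the harness above; the placeholder test functions and the main below are illustrative assumptions standing in for the example's real checks, not part of this commit.

#include <string>
#include <vector>

// Placeholder checks standing in for the real test functions the example defines.
static bool placeholder_test_a() { return true; }   // placeholder, not from this commit
static bool placeholder_test_b() { return true; }   // placeholder, not from this commit

int main() {
  std::vector<bool (*)()> tests = {
    &placeholder_test_a,
    &placeholder_test_b,
  };
  // Target SM80 (Ampere). testRun returns 0 on pass or when the device/toolkit
  // cannot run the test, and -1 on failure.
  return testRun(80, tests, "dual_gemm placeholder suite");
}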

View File

@ -0,0 +1,150 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/epilogue/thread/linear_combination_params.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies an elementwise SiLU-and-multiply (the SwiGLU combination) to arrays of elements.
///
/// D = SiLU(lhs) * rhs
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation.
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but 64 or 32 is sometimes used when there is not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LeftSiLUAndMul {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
struct Params{};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object; Params carries no state for this operator
CUTLASS_HOST_DEVICE
LeftSiLUAndMul(Params const &/*params*/) {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return true;
}
/// Required by the epilogue interface for serial split-K reduction; not supported by this operator
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
assert(false);
}
/// Computes the combined output: D = SiLU(lhs) * rhs
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert accumulators to the internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_to_compute;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> compute_to_output;
ComputeFragment converted_lhs = accumulator_to_compute(lhs);
ComputeFragment converted_rhs = accumulator_to_compute(rhs);
cutlass::epilogue::thread::SiLu<ComputeFragment> silu;
cutlass::multiplies<ComputeFragment> mul;
auto silu_lhs = silu(converted_lhs);
return compute_to_output(mul(silu_lhs, converted_rhs));
}
CUTLASS_HOST_DEVICE
ElementOutput operator()(
ElementAccumulator const& lhs,
ElementAccumulator const& rhs
) const {
ElementCompute convert_lhs(lhs);
ElementCompute convert_rhs(rhs);
cutlass::epilogue::thread::SiLu<ElementCompute> silu;
cutlass::multiplies<ElementCompute> mul;
auto silu_lhs = silu(convert_lhs);
return ElementOutput(mul(silu_lhs, convert_rhs));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
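For reference, the elementwise operation implemented by LeftSiLUAndMul above (a restatement of the code, with \sigma denoting the logistic sigmoid):

\mathrm{SiLU}(z) = z \, \sigma(z) = \frac{z}{1 + e^{-z}}, \qquad
D = \mathrm{SiLU}(\mathrm{lhs}) \cdot \mathrm{rhs}

When lhs and rhs are the two GEMM outputs of the dual kernel, this is the SwiGLU gating used in this example.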

View File

@ -0,0 +1,430 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
///< Output operator
typename OutputOp0_,
typename OutputOp1_,
typename OutputOp2_,
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
bool StoreD0 = true,
bool StoreD1 = true,
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp0_>::value)
>
class DualEpilogue {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
static bool constexpr kStoreD0 = StoreD0;
static bool constexpr kStoreD1 = StoreD1;
using OutputTileIterator = OutputTileIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp0 = OutputOp0_;
using OutputOp1 = OutputOp1_;
using OutputOp2 = OutputOp2_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
struct SharedStorage {
using Element = typename WarpTileIterator::Element;
/// Tensor reference to shared memory allocation
using TensorRef = typename WarpTileIterator::TensorRef;
/// Logical shape of the shared memory tile written to by all warps.
using Shape = typename Base::Shape;
/// Shape of the shared memory allocation for the epilogue
using StorageShape = typename Base::SharedStorage::StorageShape;
//
// Data members
//
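/// Two staging buffers: storage[0] receives accumulators from the first GEMM and
/// storage[1] those from the second, so both tiles can be swizzled through shared
/// memory within the same epilogue iteration.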
AlignedBuffer<Element, StorageShape::kCount> storage[2];
//
// Methods
//
/// Returns a tensor reference to the shared memory buffer
CUTLASS_DEVICE
TensorRef reference(int i) {
return TensorRef(
storage[i].data(),
Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
}
};
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator0_;
SharedLoadIterator shared_load_iterator1_;
/// Stores a warp's fragment of accumulators to SMEM
WarpTileIterator warp_tile_iterator0_;
WarpTileIterator warp_tile_iterator1_;
public:
/// Constructor
CUTLASS_DEVICE
DualEpilogue(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
):
shared_load_iterator0_(shared_storage.reference(0), thread_idx),
shared_load_iterator1_(shared_storage.reference(1), thread_idx),
warp_tile_iterator0_(shared_storage.reference(0), lane_idx),
warp_tile_iterator1_(shared_storage.reference(1), lane_idx)
{
int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
int warp_m = warp_mn % WarpCount::kM;
int warp_n = warp_mn / WarpCount::kM;
MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
warp_tile_iterator0_.add_tile_offset(warp_offset);
warp_tile_iterator1_.add_tile_offset(warp_offset);
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp0 const &output_op0,
OutputOp1 const &output_op1,
OutputOp2 const &output_op2,
OutputTileIterator dest0,
OutputTileIterator dest1,
OutputTileIterator dest2,
AccumulatorTile const &accumulator0,
AccumulatorTile const &accumulator1,
OutputTileIterator source_iterator[2],
bool writeToD2 // true if it's the final split-k
) {
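// Per-iteration flow (as implemented below): load the two source fragments, stage
// both accumulator tiles through shared memory, optionally reduce across
// k-partitions, apply the three output operators, and store D0/D1/D2.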
// TODO: Implement when no source is needed
typename OutputTileIterator::Fragment source_fragment[2];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; ++i) {
source_fragment[i].clear();
}
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1};
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; ++i) {
source_iterator[i].load(source_fragment[i]);
++source_iterator[i];
}
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator[0], this->warp_tile_iterator0_);
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator[1], this->warp_tile_iterator1_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK];
typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK];
shared_load_iterator0_.load(aligned_accum_fragment0[0]);
shared_load_iterator1_.load(aligned_accum_fragment1[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
if (kPartitionsK > 1) {
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator0_.load(aligned_accum_fragment0[i]);
shared_load_iterator1_.load(aligned_accum_fragment1[i]);
aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]);
aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]);
}
shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment[3];
apply_output_operator_(output_fragment,
output_op0, output_op1, output_op2,
aligned_accum_fragment0[0], aligned_accum_fragment1[0],
source_fragment);
//
// Store the final result
//
if (kStoreD0) {
dest0.store(output_fragment[0]);
++dest0;
}
if (kStoreD1) {
dest1.store(output_fragment[1]);
++dest1;
}
if (writeToD2) {
dest2.store(output_fragment[2]);
++dest2;
}
}
}
private:
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
template<class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template<int Advance>
CUTLASS_DEVICE
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
typename OutputTileIterator::Fragment (&output_fragment)[3],
OutputOp0 const &output_op0,
OutputOp1 const &output_op1,
OutputOp2 const &output_op2,
typename SharedLoadIterator::Fragment const& aligned_accum_fragment0,
typename SharedLoadIterator::Fragment const& aligned_accum_fragment1,
typename OutputTileIterator::Fragment const (&source_fragment)[2]) {
OutputAccessType* output_frag_ptr[3] = {
reinterpret_cast<OutputAccessType *>(&output_fragment[0]),
reinterpret_cast<OutputAccessType *>(&output_fragment[1]),
reinterpret_cast<OutputAccessType *>(&output_fragment[2])
};
AccumulatorAccessType const *compute_frag_ptr[2] = {
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment0),
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment1)
};
OutputAccessType const *source_frag_ptr[2] = {
reinterpret_cast<OutputAccessType const *>(&source_fragment[0]),
reinterpret_cast<OutputAccessType const *>(&source_fragment[1])
};
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operators
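//   output0 = op0(accum0, source0),  output1 = op1(accum1, source1),
//   output2 = op2(output0, output1)  -- e.g. SiLU(output0) * output1 for SwiGLU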
output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]);
output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]);
output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]);
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,218 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/threadblock/mma_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class DualMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
static_assert((kWarpGemmIterations % 2) == 0,
"Inner loop iteration must be an even number.");
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages +
Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB =
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B0;
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B1;
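/// Note: the dual GEMM keeps one shared-memory tile of A but two tiles of B
/// (B0 and B1), so its shared-memory footprint exceeds that of a single GEMM of
/// the same threadblock shape by one B-operand buffer.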
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B0_ref() {
return TensorRefB{operand_B0.data(), LayoutB()};
}
CUTLASS_HOST_DEVICE
TensorRefB operand_B1_ref() {
return TensorRefB{operand_B1.data(), LayoutB()};
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B0_;
typename Operator::IteratorB warp_tile_iterator_B1_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DualMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx),
warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,760 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "dual_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class DualMmaMultistage :
public DualMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = DualMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
private:
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B0_;
SmemIteratorB smem_iterator_B1_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DualMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx),
smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B0_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
this->warp_tile_iterator_B1_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1,
int group_start_A = 0, int group_start_B = 0) {
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B0.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
iterator_B1.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B0_.set_iteration_index(group_start_B);
this->smem_iterator_B1_.set_iteration_index(group_start_B);
// Async Copy for operand B0
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B0_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B0.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B0.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B0.valid());
}
++iterator_B0;
}
++this->smem_iterator_B0_;
}
}
// Async Copy for operand B1
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B1_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B1.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B1.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B1.valid());
}
++iterator_B1;
}
++this->smem_iterator_B1_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC &accum0,
FragmentC &accum1,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B0,
IteratorB iterator_B1,
///< initial value of accumulator
FragmentC const &src_accum0,
FragmentC const &src_accum1
) {
//
// Prologue
//
// Issue several complete stages
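// Each prologue stage issues cp.async copies of one K tile of A, B0, and B1 into
// shared memory before any math executes; once gemm_k_iterations reaches zero the
// iterator masks are cleared so out-of-range accesses are predicated off or zero-filled.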
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B0.clear_mask(gemm_k_iterations == 0);
iterator_B1.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B0.set_iteration_index(0);
iterator_B1.set_iteration_index(0);
this->smem_iterator_B0_.set_iteration_index(0);
this->smem_iterator_B1_.set_iteration_index(0);
// Async Copy for operand B0
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B0_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
++iterator_B0;
}
++this->smem_iterator_B0_;
}
// Async Copy for operand B1
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B1_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
++iterator_B1;
}
++this->smem_iterator_B1_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B0.add_tile_offset({1, 0});
iterator_B1.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B0_.add_tile_offset({1, 0});
this->smem_iterator_B1_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum0 = src_accum0;
accum1 = src_accum1;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
// so that all accumulator elements outside the GEMM footprint are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
typename IteratorB::AccessType zero_B;
zero_B.clear();
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_);
last_smem_iterator_B0.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
last_smem_iterator_B0.get());
*dst_ptr = zero_B;
++last_smem_iterator_B0;
}
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_);
last_smem_iterator_B1.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
last_smem_iterator_B1.get());
*dst_ptr = zero_B;
++last_smem_iterator_B1;
}
}
// Waits until at most (kStages - 2) committed cp.async stages remain in flight.
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA warp_loaded_frag_A[2];
WarpLoadedFragmentB warp_loaded_frag_B0[2];
WarpLoadedFragmentB warp_loaded_frag_B1[2];
WarpTransformedFragmentA warp_transformed_frag_A[2];
WarpTransformedFragmentB warp_transformed_frag_B0[2];
WarpTransformedFragmentB warp_transformed_frag_B1[2];
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B0_.set_kgroup_index(0);
this->warp_tile_iterator_B1_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B0_;
++this->warp_tile_iterator_B1_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B0.clear_mask(gemm_k_iterations == 0);
iterator_B1.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
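// smem_write_stage_idx / smem_read_stage_idx walk the circular buffer of kStages
// shared-memory tiles: writes resume at stage kStages - 1 (the slot left open after
// the prologue) while reads begin at stage 0.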
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
// accumulator and this temporary accumulator is added to the final
// accumulator once in every mainloop iteration.
plus<FragmentC> plus_accum;
FragmentC tmp_accum0, tmp_accum1;
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value
|| platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
tmp_accum0.clear();
tmp_accum1.clear();
}
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B0_;
++this->warp_tile_iterator_B1_;
if (warp_mma_k > 0) {
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B0[warp_mma_k % 2],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B0[warp_mma_k % 2]);
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B1[warp_mma_k % 2]);
}
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value
|| platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
warp_mma(
tmp_accum0,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B0[warp_mma_k % 2],
tmp_accum0
);
warp_mma(
tmp_accum1,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
tmp_accum1
);
if (warp_mma_k == 0) {
accum0 = plus_accum(accum0, tmp_accum0);
accum1 = plus_accum(accum1, tmp_accum1);
tmp_accum0.clear();
tmp_accum1.clear();
}
} else {
warp_mma(
accum0,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B0[warp_mma_k % 2],
accum0
);
warp_mma(
accum1,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
accum1
);
}
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
group_start_iteration_B);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until at most (kStages - 2) committed cp.async stages remain in flight.
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B0.add_tile_offset({1, 0});
iterator_B1.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B0_.add_tile_offset({1, 0});
this->smem_iterator_B1_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B0_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
this->warp_tile_iterator_B1_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B0.clear_mask(gemm_k_iterations == 0);
iterator_B1.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
}
}
}
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value
|| platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
accum0 = plus_accum(accum0, tmp_accum0);
accum1 = plus_accum(accum1, tmp_accum1);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -121,6 +121,7 @@ foreach(EXAMPLE
39_gemm_permute
41_multi_head_attention
42_fused_multi_head_attention
43_dual_gemm
)
add_subdirectory(${EXAMPLE})

View File

@ -34,6 +34,7 @@
#pragma once
#include "cutlass/tensor_ref.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"