Example 43 - DualGemm (#670)
* Ex50 wip
* IS_PROFILING mode
* MultiStage2 - but is slower
* Add SwiGLU
* Support SplitKSerial reduction; support not storing D0/D1; cleanup code
* Option to disable bias
* Renumber example
* Fix build
* Remove references to pb_size_0 / pb_size_1
* Add support for bf16 inputs with float accum
* small changes

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
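For orientation, the fused operation added by this example evaluates two GEMMs that share the left-hand operand and combines them with a SiLU-gated product (SwiGLU). The following is a minimal host-side sketch of the math only — plain C++ with hypothetical raw-pointer arguments, not the CUTLASS API used by the files below:

    #include <cmath>

    // Reference for D2 = silu(X @ B0 + C0) * (X @ B1 + C1), row-major, no tiling.
    // Names and signature are illustrative; the CUTLASS kernel fuses both GEMMs
    // and the element-wise epilogue into a single launch.
    void dual_gemm_swiglu_reference(
        const float* X, const float* B0, const float* C0,
        const float* B1, const float* C1, float* D2,
        int M, int N, int K) {
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float acc0 = C0 ? C0[n] : 0.f;   // per-column bias, as in the example
          float acc1 = C1 ? C1[n] : 0.f;
          for (int k = 0; k < K; ++k) {
            acc0 += X[m * K + k] * B0[k * N + n];
            acc1 += X[m * K + k] * B1[k * N + n];
          }
          float silu0 = acc0 / (1.f + std::exp(-acc0));  // SiLU(x) = x * sigmoid(x)
          D2[m * N + n] = silu0 * acc1;                  // LeftSiLUAndMul epilogue
        }
      }
    }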
examples/43_dual_gemm/CMakeLists.txt (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  43_dual_gemm
  dual_gemm.cu
)

examples/43_dual_gemm/device/dual_gemm.h (new file, 457 lines)
@@ -0,0 +1,457 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Performs a dual gemm in one fused kernel:
    ```
      D0 = epilogue0(X @ B0, C0)
      D1 = epilogue1(X @ B1, C1)
      D2 = element_wise(D0, D1)
    ```
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"

#include "../kernel/dual_gemm.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Epilogue output operator
    typename EpilogueOutputOp0_,
    typename EpilogueOutputOp1_,
    typename EpilogueOutputOp2_,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    bool StoreD0 = true,
    bool StoreD1 = true,
    /// If true, kernel supports split-K with serial reduction
    bool SplitKSerial = false,
    /// Access granularity of A matrix in units of elements
    int AlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator>
class DualGemm {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp0 = EpilogueOutputOp0_;
  using EpilogueOutputOp1 = EpilogueOutputOp1_;
  using EpilogueOutputOp2 = EpilogueOutputOp2_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp1::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static bool constexpr kStoreD0 = StoreD0;
  static bool constexpr kStoreD1 = StoreD1;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  using LayoutScaleBias = layout::RowMajor;
  /// Define the kernel
  /// Define the threadblock-scoped matrix multiply-accumulate
  static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
  static_assert(kStages >= 3, "Only multistage is implemented");
  using Mma = typename cutlass::gemm::threadblock::DefaultMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
      ThreadblockShape, WarpShape,
      InstructionShape, Stages, Operator>::ThreadblockMma;
  using DualMma = threadblock::DualMmaMultistage<
      typename Mma::Shape,
      typename Mma::IteratorA,
      typename Mma::SmemIteratorA,
      Mma::kCacheOpA,
      typename Mma::IteratorB,
      typename Mma::SmemIteratorB,
      Mma::kCacheOpB,
      typename Mma::ElementC,
      typename Mma::LayoutC,
      typename Mma::Policy,
      Mma::kStages,
      SharedMemoryClearOption::kNone
  >;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue0 =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0,
          EpilogueOutputOp0::kCount>::Epilogue;
  using Epilogue1 =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1,
          EpilogueOutputOp1::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using DualGemmKernel = kernel::DualGemm<
      DualMma,
      Epilogue0, Epilogue1, EpilogueOutputOp2,
      ThreadblockSwizzle, kSplitKSerial,
      kStoreD0, kStoreD1>;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A0;
    TensorRef<ElementB const, LayoutB> ref_B0;
    TensorRef<ElementC const, LayoutC> ref_C0;
    TensorRef<ElementC, LayoutC> ref_D0;
    TensorRef<ElementB const, LayoutB> ref_B1;
    TensorRef<ElementC const, LayoutC> ref_C1;
    TensorRef<ElementC, LayoutC> ref_D1;
    TensorRef<ElementC, LayoutC> ref_D2;
    typename EpilogueOutputOp0::Params epilogue0;
    typename EpilogueOutputOp1::Params epilogue1;
    typename EpilogueOutputOp2::Params epilogue2;
    int split_k_slices;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
    }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A0_,
      TensorRef<ElementB const, LayoutB> ref_B0_,
      TensorRef<ElementC const, LayoutC> ref_C0_,
      TensorRef<ElementC, LayoutC> ref_D0_,
      TensorRef<ElementB const, LayoutB> ref_B1_,
      TensorRef<ElementC const, LayoutC> ref_C1_,
      TensorRef<ElementC, LayoutC> ref_D1_,
      TensorRef<ElementC, LayoutC> ref_D2_,
      typename EpilogueOutputOp0::Params epilogue0_ =
        typename EpilogueOutputOp0::Params(),
      typename EpilogueOutputOp1::Params epilogue1_ =
        typename EpilogueOutputOp1::Params(),
      typename EpilogueOutputOp2::Params epilogue2_ =
        typename EpilogueOutputOp2::Params(),
      int split_k_slices_ = 1
    ):
      problem_size(problem_size_),
      ref_A0(ref_A0_),
      ref_B0(ref_B0_),
      ref_C0(ref_C0_),
      ref_D0(ref_D0_),
      ref_B1(ref_B1_),
      ref_C1(ref_C1_),
      ref_D1(ref_D1_),
      ref_D2(ref_D2_),
      epilogue0(epilogue0_),
      epilogue1(epilogue1_),
      epilogue2(epilogue2_),
      split_k_slices(split_k_slices_) {
    }
  };

 private:

  /// Kernel parameters object
  typename DualGemmKernel::Params params_;

 public:

  /// Constructs the GEMM.
  DualGemm() = default;

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    if (!kSplitKSerial && args.split_k_slices > 1) {
      return Status::kErrorInvalidProblem;
    }
    if (kStoreD0 != (args.ref_D0.data() != nullptr)) {
      return Status::kErrorInternal;
    }
    if (kStoreD1 != (args.ref_D1.data() != nullptr)) {
      return Status::kErrorInternal;
    }

    Status status = DualGemmKernel::can_implement(
      args.problem_size,
      args.ref_A0.non_const_ref(),
      args.ref_B0.non_const_ref(),
      args.ref_C0.non_const_ref(),
      args.ref_D0,
      args.ref_B1.non_const_ref(),
      args.ref_C1.non_const_ref(),
      args.ref_D1,
      args.ref_D2
    );

    if (status != Status::kSuccess) {
      return status;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    if (kSplitKSerial && args.split_k_slices > 1) {

      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return bytes;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }

        size_t bytes = get_workspace_size(args);

        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    }
    else {

      if (args.split_k_slices > 1) {
        return Status::kErrorInvalidProblem;
      }
    }

    // Initialize the Params structure
    params_ = typename DualGemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ref_A0.non_const_ref(),
      args.ref_B0.non_const_ref(),
      args.ref_C0.non_const_ref(),
      args.ref_D0,
      args.ref_B1.non_const_ref(),
      args.ref_C1.non_const_ref(),
      args.ref_D1,
      args.ref_D2,
      args.epilogue0,
      args.epilogue1,
      args.epilogue2,
      reinterpret_cast<int *>(workspace),
    };

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    if (kSplitKSerial && args.split_k_slices > 1) {
      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }
    }

    params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
    params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
    params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
    params_.ref_D0.reset(args.ref_D0.data());
    params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
    params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
    params_.ref_D1.reset(args.ref_D1.data());
    params_.ref_D2.reset(args.ref_D2.data());
    params_.output_op_0 = args.epilogue0;
    params_.output_op_1 = args.epilogue1;
    params_.output_op_2 = args.epilogue2;
    params_.semaphore = reinterpret_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(DualGemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<DualGemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////

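For reference, the host-side call sequence for the `DualGemm` operator defined above follows the usual CUTLASS device-level pattern (`can_implement` → `get_workspace_size` → `initialize` → `run`). A condensed sketch, assuming a `DualGemmType` instantiation like the one in `dual_gemm.cu` below and pre-allocated device tensors (tensor names here are hypothetical; the harness in `dual_gemm_run.h` further down is the authoritative usage):

    // Sketch only: mirrors the DualFusedGemmRun harness shown later in this commit.
    DualGemmType::Arguments args(
        problem_size,                  // GemmCoord {M, N, K}
        tensor_x.device_ref(),         // shared A operand (X)
        tensor_b0.device_ref(),        // B0
        bias0_ref,                     // C0 (a stride-0 ref broadcasts a 1xN bias)
        tensor_d0.device_ref(),        // D0 (must be null exactly when kStoreD0 is false)
        tensor_b1.device_ref(),        // B1
        bias1_ref,                     // C1
        tensor_d1.device_ref(),        // D1 (must be null exactly when kStoreD1 is false)
        tensor_d2.device_ref(),        // D2 = element_wise(D0, D1)
        {alpha0, beta0},               // epilogue0 params
        {alpha1, beta1},               // epilogue1 params
        {},                            // epilogue2 params (LeftSiLUAndMul takes none)
        /*split_k_slices=*/1);

    DualGemmType dual_gemm;
    CUTLASS_CHECK(dual_gemm.can_implement(args));

    cutlass::device_memory::allocation<uint8_t> workspace(
        DualGemmType::get_workspace_size(args));
    CUTLASS_CHECK(dual_gemm.initialize(args, workspace.get()));
    CUTLASS_CHECK(dual_gemm());   // or dual_gemm.run(stream)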
examples/43_dual_gemm/dual_gemm.cu (new file, 262 lines)
@@ -0,0 +1,262 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief CUTLASS Dual-GEMM Example.

    Fused kernel that outputs `D0` and `D1`.
    We assume that B0/B1 have the same shape/layout.

    ```
      D0 = epilogue0(X @ B0, C0)
      D1 = epilogue1(X @ B1, C1)
      D2 = element_wise(D0, D1)
    ```
    D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`)
*/

// #define IS_PROFILING

#include <iostream>
#include <string>   // std::string / std::to_string in main()
#include <vector>   // std::vector of test functions in main()

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/dual_gemm.h"
#include "thread/left_silu_and_mul.h"
#include "dual_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192);

constexpr int kStages = 3;
constexpr bool kSplitKSerial = false;
constexpr bool kUseBias = true;

#if 0
using ElementOperandA = cutlass::bfloat16_t;
using ElementOperandB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ElementCompute = float;
#else
using ElementOperandA = cutlass::half_t;
using ElementOperandB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
#endif

constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : (
  // No bias
  kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing
);
using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
  ElementOutput,
  128 / cutlass::sizeof_bits<ElementOutput>::value,
  ElementAccumulator,
  ElementCompute,
  kScaleType
>;
using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
  ElementOutput,
  128 / cutlass::sizeof_bits<ElementOutput>::value,
  ElementAccumulator,
  ElementCompute,
  kScaleType
>;
using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul<
  ElementOutput,
  128 / cutlass::sizeof_bits<ElementOutput>::value,
  ElementOutput,
  ElementCompute
>;

const ElementCompute alpha0 = ElementCompute(1);
const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0);
const ElementCompute alpha1 = ElementCompute(1);
const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0);

bool run_nonfused_gemm_f16_sm80() {
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    ElementOperandA,
    cutlass::layout::RowMajor,
    ElementOperandB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp0,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    kStages,
    8,
    8,
    kSplitKSerial
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    ElementOperandA,
    cutlass::layout::RowMajor,
    ElementOperandB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    kStages,
    8,
    8,
    kSplitKSerial
  >;

  NonFusedDualGemmRun<Gemm0, Gemm1> nonFusedGemm;

  std::cout << "Running non-fused FP16 TN GEMMs...\n";
  bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

template <typename T>
struct LeftSiLUAndMul {
  struct Params{};
  CUTLASS_HOST_DEVICE LeftSiLUAndMul(Params p) {}

  CUTLASS_HOST_DEVICE void set_k_partition(int, int) {}

  CUTLASS_HOST_DEVICE T operator() (
    T const &lhs,
    T const &rhs) const {
    cutlass::epilogue::thread::SiLu<T> silu;
    cutlass::multiplies<T> mul;
    auto silu_lhs = silu(lhs);
    return mul(silu_lhs, rhs);
  }

  template <int kCount>
  CUTLASS_HOST_DEVICE cutlass::Array<T, kCount> operator() (
    cutlass::Array<T, kCount> const &lhs,
    cutlass::Array<T, kCount> const &rhs) const {
    cutlass::epilogue::thread::SiLu<T> silu;
    cutlass::multiplies<T> mul;
    auto silu_lhs = silu(lhs);
    return mul(silu_lhs, rhs);
  }
};

bool run_fused_gemm_f16_sm80_shmem() {
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Optionally, we might not need intermediate GEMM outputs
  constexpr bool kStoreD0 = true;
  constexpr bool kStoreD1 = true;

  using DualGemm = cutlass::gemm::device::DualGemm<
    ElementOperandA,
    cutlass::layout::RowMajor,
    ElementOperandB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    EpilogueOutputOp2,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    kStages,
    kStoreD0,
    kStoreD1,
    kSplitKSerial
  >;

  DualFusedGemmRun<DualGemm> fusedGemm;

  std::cout << "Running fused FP16 TN GEMMs + Epilogue2...\n";
  bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_f16_sm80,
    &run_fused_gemm_f16_sm80_shmem
  };

  std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial);
  return testRun(80, funcs, test_name);
}

////////////////////////////////////////////////////////////////////////////////

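A note on the split-K serial path configured above: when `kSplitKSerial` is enabled and `split_k_slices > 1`, `DualGemm::get_workspace_size` reserves one `int` lock/counter per output tile for the serial reduction (see the device header earlier in this commit). With the problem size and fused tile shape used in this example, that is only a few kilobytes. A sketch of the arithmetic (the helper function below is illustrative, not part of the example):

    #include <cstddef>

    // Workspace bytes for the serial split-K reduction: one int per (M, N) tile.
    // Mirrors the body of DualGemm::get_workspace_size().
    constexpr size_t split_k_workspace_bytes(int M, int N, int tile_m, int tile_n) {
      size_t tiles_m = (M + tile_m - 1) / tile_m;
      size_t tiles_n = (N + tile_n - 1) / tile_n;
      return sizeof(int) * tiles_m * tiles_n;
    }

    // Fused config above: problem 4096x4096x8192, ThreadblockShape 128x64x32.
    // => (4096/128) * (4096/64) * 4 bytes = 32 * 64 * 4 = 8192 bytes.
    static_assert(split_k_workspace_bytes(4096, 4096, 128, 64) == 8192, "");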
examples/43_dual_gemm/dual_gemm_run.h (new file, 829 lines)
@@ -0,0 +1,829 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "helper.h"

#define CHECK_GT(val1, val2) \
    if((val1) <= (val2)) \
        std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
    if(!(val)) \
        std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";

template <
  typename OutputOp,
  typename Element,
  typename Layout>
struct TensorEpilogueForEachFunc {
  /// View type
  using TensorView = cutlass::TensorView<Element, Layout>;

  /// Coordinate in tensor's index space
  using TensorCoord = typename TensorView::TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view_x0;
    TensorView view_x1;
    TensorView view_y;
    OutputOp output_op;

    //
    // Methods
    //

    Params(
      TensorView view_x0_ = TensorView(),
      TensorView view_x1_ = TensorView(),
      TensorView view_y_ = TensorView(),
      OutputOp output_op_ = OutputOp(typename OutputOp::Params{})
    ):
      view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) {
    }
  };

  Params params;

  CUTLASS_DEVICE
  TensorEpilogueForEachFunc(Params const &params): params(params) {
  }

  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {
    Element const & x0 = params.view_x0.at(coord);
    Element const & x1 = params.view_x1.at(coord);
    Element& y = params.view_y.at(coord);
    y = params.output_op(x0, x1);
  }
};

template <
  typename OutputOp,
  typename Element,
  typename Layout>
void TensorEpilogueForEach(
  cutlass::TensorView<Element, Layout> x0,
  cutlass::TensorView<Element, Layout> x1,
  cutlass::TensorView<Element, Layout> y) {

  using Func = TensorEpilogueForEachFunc<OutputOp, Element, Layout>;
  using Params = typename Func::Params;

  cutlass::reference::device::TensorForEach<Func, Layout::kRank, Params>(
    y.extent(),
    Params(x0, x1, y)
  );
}

////////////////////////////////////////////////////////////////////////////////

template <typename Gemm0_, typename Gemm1_>
struct NonFusedDualGemmRun
{

  using Gemm0 = Gemm0_;
  using Gemm1 = Gemm1_;
  using ElementAccumulator = typename Gemm0::ElementAccumulator;
  using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  NonFusedDualGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    cutlass::gemm::GemmCoord problem_size,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {

    //
    // Allocate the GEMM workspace
    //

    cutlass::HostTensor<
      typename Gemm0::ElementA,
      typename Gemm0::LayoutA> tensor_A0(problem_size.mk());

    cutlass::HostTensor<
      typename Gemm0::ElementB,
      typename Gemm0::LayoutB> tensor_B0(problem_size.kn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_C0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()});

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_D0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> reference_D0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementB,
      typename Gemm1::LayoutB> tensor_B1(problem_size.kn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_C1(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()});

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_D1(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> reference_D1(problem_size.mn());

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

    cutlass::reference::host::TensorFill(
      tensor_D0.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());
    cutlass::reference::host::TensorFill(
      reference_D0.host_view());
    cutlass::reference::host::TensorFill(
      reference_D1.host_view());

    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_D0.sync_device();
    reference_D0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D1.sync_device();

    //
    // Initialize the GEMM operator
    //

    int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1;
    typename Gemm0::Arguments arguments_0{
      problem_size,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      tensor_D0.device_ref(),
      {alpha0, beta0},
      split_k_slices
    };

    split_k_slices = Gemm1::kSplitKSerial ? 2 : 1;
    typename Gemm1::Arguments arguments_1{
      problem_size,
      tensor_A0.device_ref(),
      tensor_B1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha1, beta1},
      split_k_slices
    };

    Gemm0 gemm_op_0;
    Gemm1 gemm_op_1;

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace0(gemm_op_0.get_workspace_size(arguments_0));
    cutlass::device_memory::allocation<uint8_t> workspace1(gemm_op_1.get_workspace_size(arguments_1));

    cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get());

    CUTLASS_CHECK(status);

    status = gemm_op_1.initialize(arguments_1, workspace1.get());

    CUTLASS_CHECK(status);

    for(int i = 0; i < warm_ups; i++) {
      status = gemm_op_0();
      CUTLASS_CHECK(status);
      status = gemm_op_1();
      CUTLASS_CHECK(status);
    }
#ifdef IS_PROFILING
    return true;
#endif
    //
    // Run the GEMM
    //
    cudaEvent_t start, stop1, stop2;
    cudaEventCreate(&start);
    cudaEventCreate(&stop1);
    cudaEventCreate(&stop2);

    cudaEventRecord(start);

    for(int i = 0; i < runs; i++) {
      status = gemm_op_0();

      CUTLASS_CHECK(status);
    }
    cudaEventRecord(stop1);
    for(int i = 0; i < runs; i++) {
      status = gemm_op_1();

      CUTLASS_CHECK(status);
    }

    cudaEventRecord(stop2);
    cudaDeviceSynchronize();
    float gemm0Time, gemm1Time, totalTime;
    cudaEventElapsedTime(&gemm0Time, start, stop1);
    cudaEventElapsedTime(&gemm1Time, stop1, stop2);
    cudaEventElapsedTime(&totalTime, start, stop2);
    std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
    std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
    std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";

    tensor_D0.sync_host();
    tensor_D1.sync_host();

    //
    // Verify
    //
    cutlass::reference::device::Gemm<
      typename Gemm0::ElementA, typename Gemm0::LayoutA,
      typename Gemm0::ElementB, typename Gemm0::LayoutB,
      typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
      ElementAccumulator, typename Gemm0::Operator>
      reference_gemm_0;

    cutlass::reference::device::Gemm<
      typename Gemm1::ElementA, typename Gemm1::LayoutA,
      typename Gemm1::ElementB, typename Gemm1::LayoutB,
      typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
      ElementAccumulator, typename Gemm1::Operator>
      reference_gemm_1;

    reference_gemm_0(
      problem_size,
      alpha0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      reference_D0.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
    }

    reference_gemm_1(
      problem_size,
      alpha1,
      tensor_A0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    // Wait for kernels to finish
    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);

    bool passed0 = cutlass::reference::host::TensorEquals(
      reference_D0.host_view(),
      tensor_D0.host_view());
    CHECK_TRUE(passed0);

    bool passed1 = cutlass::reference::host::TensorEquals(
      reference_D1.host_view(),
      tensor_D1.host_view());
    CHECK_TRUE(passed1);
    if (!passed0 || !passed1) {

      std::stringstream fname;

      fname << "error_DualGemm_device_nonfused.txt";
      std::cerr << "Dumping results in " << fname.str() << "\n";

      std::ofstream file(fname.str());

      file
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nD0 =\n" << tensor_D0.host_view()
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }
    return passed0 && passed1;
  }
};

template <typename DualGemm_>
|
||||
struct DualFusedGemmRun
|
||||
{
|
||||
|
||||
using DualGemm = DualGemm_;
|
||||
using ElementAccumulator = typename DualGemm::ElementAccumulator;
|
||||
using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute;
|
||||
using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
DualFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(1),
|
||||
bool relu = false,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementA,
|
||||
typename DualGemm::LayoutA> tensor_A0(problem_size.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementB,
|
||||
typename DualGemm::LayoutB> tensor_B0(problem_size.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_C0(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D0(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D0(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementB,
|
||||
typename DualGemm::LayoutB> tensor_B1(problem_size.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_C1(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D1(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D2(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D1(problem_size.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D2(problem_size.mn());
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D2.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D1.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D2.host_view());
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D0.sync_device();
|
||||
tensor_D1.sync_device();
|
||||
tensor_D2.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
reference_D2.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
|
||||
typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> nullptr_ref{};
|
||||
decltype(nullptr_ref) ref_B0, ref_B1;
|
||||
if (beta0 != ElementCompute(0)) {
|
||||
ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)};
|
||||
}
|
||||
if (beta1 != ElementCompute(0)) {
|
||||
ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
|
||||
}
|
||||
typename DualGemm::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
ref_B0,
|
||||
DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
|
||||
tensor_B1.device_ref(),
|
||||
ref_B1,
|
||||
DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
|
||||
tensor_D2.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
{},
|
||||
split_k_slices
|
||||
};
|
||||
|
||||
DualGemm b2b_gemm_op;
|
||||
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments, workspace.get());
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
#ifdef IS_PROFILING
|
||||
return true;
|
||||
#endif
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
tensor_D2.sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename DualGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size,
|
||||
alpha1,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
TensorEpilogueForEach<EpilogueOutputOp2>(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
reference_D2.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0);
|
||||
|
||||
bool passed_out0 = true;
|
||||
if (DualGemm::kStoreD0) {
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
||||
passed_out0 = cutlass::reference::host::TensorEquals(
|
||||
reference_D0.host_view(),
|
||||
tensor_D0.host_view());
|
||||
}
|
||||
CHECK_TRUE(passed_out0);
|
||||
|
||||
bool passed_out1 = true;
|
||||
if (DualGemm::kStoreD1) {
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
passed_out1 = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
}
|
||||
CHECK_TRUE(passed_out1);
|
||||
|
||||
bool passed_out2 = cutlass::reference::host::TensorEquals(
|
||||
reference_D2.host_view(),
|
||||
tensor_D2.host_view());
|
||||
CHECK_TRUE(passed_out2);
|
||||
|
||||
bool passed = passed_out0 && passed_out1 && passed_out2;
|
||||
if (!passed)
|
||||
{
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_DualGemm_device_fused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
||||
<< "\n\nReference0 =\n" << reference_D0.host_view()
|
||||
<< "\nComputed0 =\n" << tensor_D0.host_view()
|
||||
<< "\n\nReference1 =\n" << reference_D1.host_view()
|
||||
<< "\nComputed1 =\n" << tensor_D1.host_view()
|
||||
<< "\n\nReference2 =\n" << reference_D2.host_view()
|
||||
<< "\nComputed2 =\n" << tensor_D2.host_view();
|
||||
}
|
||||
//std::cout << "A0 " << tensor_A0.host_view() << std::endl;
|
||||
// std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
|
||||
// std::cout << "reference_D1 " << reference_D1.host_view() << std::endl;
|
||||
// std::cout << "reference_D2 " << reference_D2.host_view() << std::endl;
|
||||
//std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
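For readers tracing the verification above: the two reference GEMMs plus TensorEpilogueForEach amount to D0 = alpha0 * (A @ B0) + beta0 * bias0, D1 = alpha1 * (A @ B1) + beta1 * bias1, and D2 = op2(D0, D1) applied elementwise; the stride-0 C operand is what broadcasts each 1 x N bias row across all M output rows. A minimal host-side sketch of that reference (naive loops, float everywhere, hypothetical helper name, not part of this commit):

#include <vector>

// D0 = alpha0 * (A @ B0) + beta0 * bias0   (bias0 broadcast over rows via stride 0)
// D1 = alpha1 * (A @ B1) + beta1 * bias1
// D2 = op2(D0, D1) elementwise, e.g. SiLU(D0) * D1 in the SwiGLU configuration
void reference_dual_gemm(int M, int N, int K,
                         const std::vector<float>& A,      // M x K, row-major
                         const std::vector<float>& B0,     // K x N, row-major
                         const std::vector<float>& B1,     // K x N, row-major
                         const std::vector<float>& bias0,  // 1 x N
                         const std::vector<float>& bias1,  // 1 x N
                         float alpha0, float beta0, float alpha1, float beta1,
                         std::vector<float>& D0, std::vector<float>& D1) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc0 = 0.f, acc1 = 0.f;
      for (int k = 0; k < K; ++k) {
        acc0 += A[i * K + k] * B0[k * N + j];
        acc1 += A[i * K + k] * B1[k * N + j];
      }
      D0[i * N + j] = alpha0 * acc0 + beta0 * bias0[j];  // bias re-read for every row i
      D1[i * N + j] = alpha1 * acc1 + beta1 * bias1[j];
    }
  }
}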
|
||||
489
examples/43_dual_gemm/kernel/dual_gemm.h
Normal file
@ -0,0 +1,489 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined dual-GEMM kernel. Supports serial split-K reduction but not batching.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "../threadblock/dual_mma_multistage.h"
|
||||
#include "../threadblock/dual_epilogue.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue0_, ///! Epilogue
|
||||
typename Epilogue1_, ///! Epilogue
|
||||
typename OutputOp2_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
bool StoreD0,
|
||||
bool StoreD1
|
||||
>
|
||||
struct DualGemm {
|
||||
|
||||
using DualMma = DualMma_;
|
||||
|
||||
using Epilogue0 = Epilogue0_;
|
||||
using Epilogue1 = Epilogue1_;
|
||||
using OutputOp0 = typename Epilogue0::OutputOp;
|
||||
using OutputOp1 = typename Epilogue1::OutputOp;
|
||||
using OutputOp2 = OutputOp2_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static constexpr bool kStoreD0 = StoreD0;
|
||||
static constexpr bool kStoreD1 = StoreD1;
|
||||
|
||||
using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue<
|
||||
typename Epilogue0::Shape,
|
||||
typename Epilogue0::WarpMmaOperator,
|
||||
Epilogue0::kPartitionsK,
|
||||
typename Epilogue0::OutputTileIterator,
|
||||
typename Epilogue0::AccumulatorFragmentIterator,
|
||||
typename Epilogue0::WarpTileIterator,
|
||||
typename Epilogue0::SharedLoadIterator,
|
||||
OutputOp0,
|
||||
OutputOp1,
|
||||
OutputOp2,
|
||||
typename Epilogue0::Padding,
|
||||
kStoreD0,
|
||||
kStoreD1,
|
||||
Epilogue0::kFragmentsPerIteration,
|
||||
true // IterationsUnroll
|
||||
>;
|
||||
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1),
|
||||
"Split-K serial requires buffers for D0/D1 for reduction");
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename DualMma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount0::kCount;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
// Mma0
|
||||
typename DualMma::IteratorA::Params params_A0;
|
||||
typename DualMma::IteratorA::TensorRef ref_A0;
|
||||
typename DualMma::IteratorB::Params params_B0;
|
||||
typename DualMma::IteratorB::TensorRef ref_B0;
|
||||
typename Epilogue0::OutputTileIterator::Params params_C0;
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_C0;
|
||||
typename Epilogue0::OutputTileIterator::Params params_D0;
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_D0;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
|
||||
// Mma1
|
||||
typename DualMma::IteratorB::Params params_B1;
|
||||
typename DualMma::IteratorB::TensorRef ref_B1;
|
||||
typename Epilogue1::OutputTileIterator::Params params_C1;
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue1::OutputTileIterator::Params params_D1;
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
|
||||
typename Epilogue1::OutputTileIterator::Params params_D2;
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D2;
|
||||
typename OutputOp2::Params output_op_2;
|
||||
|
||||
int *semaphore;
|
||||
int gemm_k_size;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
// Mma0: D0 = A @ B0 + C0
|
||||
typename DualMma::IteratorA::TensorRef ref_A0,
|
||||
typename DualMma::IteratorB::TensorRef ref_B0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
|
||||
// Mma1: D1 = A @ B1 + C1
|
||||
typename DualMma::IteratorB::TensorRef ref_B1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
|
||||
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D2,
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
problem_size(problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
// Mma0
|
||||
params_A0(ref_A0.layout()),
|
||||
ref_A0(ref_A0),
|
||||
params_B0(ref_B0.layout()),
|
||||
ref_B0(ref_B0),
|
||||
params_C0(ref_C0.layout()),
|
||||
ref_C0(ref_C0),
|
||||
params_D0(ref_D0.layout()),
|
||||
ref_D0(ref_D0),
|
||||
// Mma1
|
||||
params_B1(ref_B1.layout()),
|
||||
ref_B1(ref_B1),
|
||||
params_C1(ref_C1.layout()),
|
||||
ref_C1(ref_C1),
|
||||
params_D1(ref_D1.layout()),
|
||||
ref_D1(ref_D1),
|
||||
params_D2(ref_D2.layout()),
|
||||
ref_D2(ref_D2),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1),
|
||||
output_op_2(output_op_2) {
|
||||
|
||||
int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
gemm_k_size = gemm_k_iterations * DualMma::Shape::kK;
|
||||
|
||||
semaphore = workspace;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename DualMma::SharedStorage main_loop;
|
||||
typename DualEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
DualGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
typename DualMma::IteratorA::TensorRef ref_A0,
|
||||
typename DualMma::IteratorB::TensorRef ref_B0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
|
||||
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
|
||||
typename DualMma::IteratorB::TensorRef ref_B1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
|
||||
typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {
|
||||
|
||||
static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_D0, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_D2, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * DualMma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
threadblock_tile_offset.n() * DualMma::Shape::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
threadblock_tile_offset.n() * DualMma::Shape::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k =
|
||||
(params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ?
|
||||
params.problem_size.k() :
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename DualMma::IteratorA iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
{params.problem_size.m(), problem_size_k},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename DualMma::IteratorB iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename DualMma::IteratorB iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
typename DualMma::FragmentC accum0;
|
||||
typename DualMma::FragmentC accum1;
|
||||
accum0.clear();
|
||||
accum1.clear();
|
||||
|
||||
DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations,
|
||||
accum0, accum1,
|
||||
iterator_A0, iterator_B0, iterator_B1,
|
||||
accum0, accum1);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
OutputOp0 output_op_0(params.output_op_0);
|
||||
OutputOp1 output_op_1(params.output_op_1);
|
||||
OutputOp2 output_op_2(params.output_op_2);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * DualMma::Shape::kM,
|
||||
threadblock_tile_offset.n() * DualMma::Shape::kN
|
||||
);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue0::OutputTileIterator iterator_C0(
|
||||
params.params_C0,
|
||||
params.ref_C0.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue1::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue0::OutputTileIterator iterator_D0(
|
||||
params.params_D0,
|
||||
params.ref_D0.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue1::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
typename Epilogue1::OutputTileIterator iterator_D2(
|
||||
params.params_D2,
|
||||
params.ref_D2.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
DualEpilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C0 = iterator_D0;
|
||||
iterator_C1 = iterator_D1;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
typename Epilogue0::OutputTileIterator source_iters[] = {
|
||||
iterator_C0, iterator_C1
|
||||
};
|
||||
const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1);
|
||||
epilogue(
|
||||
output_op_0, output_op_1, output_op_2,
|
||||
iterator_D0, iterator_D1, iterator_D2,
|
||||
accum0, accum1,
|
||||
source_iters,
|
||||
writeToD2
|
||||
);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
__threadfence();
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
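As a side note on the split-K partitioning in Params above, here is a small host-side sketch (example numbers assumed, not taken from this commit) of how gemm_k_size is derived:

// gemm_k_size = ceil(ceil(K / ThreadblockShape::kK) / grid.k()) * ThreadblockShape::kK
int compute_gemm_k_size(int problem_k, int threadblock_k, int split_k_slices) {
  int total_iterations = (problem_k + threadblock_k - 1) / threadblock_k;
  int iterations_per_slice = (total_iterations + split_k_slices - 1) / split_k_slices;
  return iterations_per_slice * threadblock_k;
}

// e.g. K = 4096, ThreadblockShape::kK = 32, 3 split-K slices:
//   128 total iterations, 43 iterations per slice, gemm_k_size = 1376,
//   so slices start at k = 0, 1376 and 2752, and the last slice covers the remaining 1344.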
|
||||
|
||||
95
examples/43_dual_gemm/test_run.h
Normal file
@ -0,0 +1,95 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#include <iostream>
#include <string>
#include <vector>
|
||||
|
||||
// Run tests on GPUs
|
||||
|
||||
int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
|
||||
|
||||
bool supported = false;
|
||||
|
||||
int arch_major = arch / 10;
|
||||
int arch_minor = arch % 10;
|
||||
|
||||
if(arch_major >= 8) {
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with the CUDA 11 Toolkit to run these examples.
|
||||
if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
|
||||
supported = true;
|
||||
}
|
||||
}
|
||||
else if(arch_major >= 7) {
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
|
||||
supported = true;
|
||||
}
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == arch_major && props.minor == arch_minor)) {
|
||||
supported = false;
|
||||
}
|
||||
|
||||
if (!supported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
std::cout << "This example isn't supported on current architecture" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::cout << "Device: " << props.name << std::endl;
|
||||
std::cout << "Arch: SM" << arch << std::endl;
|
||||
std::cout << "Test: " << test_name << std::endl;
|
||||
for(auto func : test_funcs) {
|
||||
pass &= func();
|
||||
}
|
||||
|
||||
|
||||
if(pass)
|
||||
return 0;
|
||||
else
|
||||
return -1;
|
||||
|
||||
}
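A minimal usage sketch for testRun (the test function names and the main() layout here are assumptions for illustration, not the example's actual driver):

#include <string>
#include <vector>

bool run_nonfused_gemm_f16();    // hypothetical test entry points,
bool run_fused_dual_gemm_f16();  // assumed to be defined elsewhere in the example

int main() {
  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_f16,
    &run_fused_dual_gemm_f16,
  };
  // 80 selects SM80 (Ampere); testRun returns 0 on pass or when the GPU/toolkit is unsupported.
  return testRun(80, funcs, "dual_gemm f16");
}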
|
||||
|
||||
150
examples/43_dual_gemm/thread/left_silu_and_mul.h
Normal file
@ -0,0 +1,150 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor computing SiLU(lhs) * rhs, used by the dual-GEMM epilogue (SwiGLU).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_params.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Applies SiLU to the first operand and multiplies elementwise by the second.
|
||||
///
|
||||
/// D = SiLU(accumulator0) * accumulator1
|
||||
///
|
||||
template <
|
||||
typename ElementOutput_, ///< Data type used to load and store tensors
|
||||
int Count, ///< Number of elements computed per operation.
|
||||
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
||||
///< but 64 or 32 is sometimes used when there is not enough data to store
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
class LeftSiLUAndMul {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementOutput_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
|
||||
static int const kCount = Count;
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||
using ComputeFragment = Array<ElementCompute, kCount>;
|
||||
|
||||
static FloatRoundStyle const kRound = Round;
|
||||
|
||||
struct Params{};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object, possibly loading from pointers in host memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
LeftSiLUAndMul(Params const &/*params*/) {}
|
||||
|
||||
/// Returns true if source is needed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
assert(false);
|
||||
}
|
||||
|
||||
/// Computes the gated product: D = SiLU(lhs) * rhs
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_to_compute;
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> compute_to_output;
|
||||
|
||||
ComputeFragment converted_lhs = accumulator_to_compute(lhs);
|
||||
ComputeFragment converted_rhs = accumulator_to_compute(rhs);
|
||||
|
||||
cutlass::epilogue::thread::SiLu<ComputeFragment> silu;
|
||||
cutlass::multiplies<ComputeFragment> mul;
|
||||
auto silu_lhs = silu(converted_lhs);
|
||||
return compute_to_output(mul(silu_lhs, converted_rhs));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ElementOutput operator()(
|
||||
ElementAccumulator const& lhs,
|
||||
ElementAccumulator const& rhs
|
||||
) const {
|
||||
ElementCompute convert_lhs(lhs);
|
||||
ElementCompute convert_rhs(rhs);
|
||||
cutlass::epilogue::thread::SiLu<ElementCompute> silu;
|
||||
cutlass::multiplies<ElementCompute> mul;
|
||||
auto silu_lhs = silu(convert_lhs);
|
||||
return ElementOutput(mul(silu_lhs, convert_rhs));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
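A scalar model of what LeftSiLUAndMul computes, useful for spot-checking the epilogue by hand (a minimal sketch assuming float compute; not part of this commit):

#include <cmath>
#include <cstdio>

// out = SiLU(lhs) * rhs, with SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
float left_silu_and_mul(float lhs, float rhs) {
  float silu_lhs = lhs / (1.0f + std::exp(-lhs));
  return silu_lhs * rhs;
}

int main() {
  std::printf("%f\n", left_silu_and_mul(0.0f, 3.0f));  // 0.000000: SiLU(0) = 0 gates the value path off
  std::printf("%f\n", left_silu_and_mul(2.0f, 3.0f));  // ~5.285:   SiLU(2) is about 1.7616
  return 0;
}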
|
||||
430
examples/43_dual_gemm/threadblock/dual_epilogue.h
Normal file
@ -0,0 +1,430 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
||||
|
||||
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
||||
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue operator
|
||||
template <
|
||||
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
||||
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
||||
int PartitionsK, ///< Number of partitions of the K dimension
|
||||
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
||||
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
||||
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
||||
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
||||
///< Output operator
|
||||
typename OutputOp0_,
|
||||
typename OutputOp1_,
|
||||
typename OutputOp2_,
|
||||
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
||||
bool StoreD0 = true,
|
||||
bool StoreD1 = true,
|
||||
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
|
||||
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
||||
(!IsEpilogueFunctorHeavy<OutputOp0_>::value)
|
||||
>
|
||||
class DualEpilogue {
|
||||
|
||||
public:
|
||||
|
||||
using Base = EpilogueBase<
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Padding_,
|
||||
FragmentsPerPartition>;
|
||||
|
||||
using Shape = Shape_;
|
||||
using WarpMmaOperator = WarpMmaOperator_;
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
static bool constexpr kStoreD0 = StoreD0;
|
||||
static bool constexpr kStoreD1 = StoreD1;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
||||
using WarpTileIterator = WarpTileIterator_;
|
||||
using SharedLoadIterator = SharedLoadIterator_;
|
||||
using OutputOp0 = OutputOp0_;
|
||||
using OutputOp1 = OutputOp1_;
|
||||
using OutputOp2 = OutputOp2_;
|
||||
using Padding = Padding_;
|
||||
|
||||
using Layout = layout::RowMajor;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
/// The complete warp-level accumulator tile
|
||||
using AccumulatorTile = typename Base::AccumulatorTile;
|
||||
|
||||
/// Accumulator element
|
||||
using ElementAccumulator = typename WarpTileIterator::Element;
|
||||
|
||||
/// Output element
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
/// Output access size
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Tensor reference to destination tensor
|
||||
using TensorRef = typename OutputTileIterator::TensorRef;
|
||||
|
||||
/// Tensor reference to sync tensor
|
||||
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
||||
|
||||
/// Const tensor reference to source tensor
|
||||
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
||||
|
||||
/// Array type used to output
|
||||
using OutputAccessType = Array<
|
||||
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Array type used by output functor
|
||||
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = typename Base::WarpCount;
|
||||
|
||||
struct SharedStorage {
|
||||
using Element = typename WarpTileIterator::Element;
|
||||
|
||||
/// Tensor reference to shared memory allocation
|
||||
using TensorRef = typename WarpTileIterator::TensorRef;
|
||||
|
||||
/// Logical shape of the shared memory tile written to by all warps.
|
||||
using Shape = typename Base::Shape;
|
||||
|
||||
/// Shape of the shared memory allocation for the epilogue
|
||||
using StorageShape = typename Base::SharedStorage::StorageShape;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
AlignedBuffer<Element, StorageShape::kCount> storage[2];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a tensor reference to the shared memory buffer
|
||||
CUTLASS_DEVICE
|
||||
TensorRef reference(int i) {
|
||||
return TensorRef(
|
||||
storage[i].data(),
|
||||
Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
|
||||
}
|
||||
};
|
||||
|
||||
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
||||
static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles;
|
||||
|
||||
public:
|
||||
|
||||
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
|
||||
"Mismatch between shared load iterator and output tile iterator.");
|
||||
|
||||
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
||||
|
||||
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
||||
"Divisibility");
|
||||
|
||||
private:
|
||||
|
||||
/// Loads fragment from shared memory aligned with output tensor
|
||||
SharedLoadIterator shared_load_iterator0_;
|
||||
SharedLoadIterator shared_load_iterator1_;
|
||||
|
||||
/// Stores a warp's fragment of accumulators to SMEM
|
||||
WarpTileIterator warp_tile_iterator0_;
|
||||
WarpTileIterator warp_tile_iterator1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
DualEpilogue(
|
||||
SharedStorage &shared_storage, ///< Shared storage object
|
||||
int thread_idx, ///< ID of a thread within the threadblock
|
||||
int warp_idx, ///< ID of warp within threadblock
|
||||
int lane_idx ///< Id of thread within warp
|
||||
):
|
||||
shared_load_iterator0_(shared_storage.reference(0), thread_idx),
|
||||
shared_load_iterator1_(shared_storage.reference(1), thread_idx),
|
||||
warp_tile_iterator0_(shared_storage.reference(0), lane_idx),
|
||||
warp_tile_iterator1_(shared_storage.reference(1), lane_idx)
|
||||
{
|
||||
int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
|
||||
int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
|
||||
int warp_m = warp_mn % WarpCount::kM;
|
||||
int warp_n = warp_mn / WarpCount::kM;
|
||||
|
||||
MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
|
||||
|
||||
warp_tile_iterator0_.add_tile_offset(warp_offset);
|
||||
warp_tile_iterator1_.add_tile_offset(warp_offset);
|
||||
}
|
||||
|
||||
/// Streams the result to global memory
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
OutputOp0 const &output_op0,
|
||||
OutputOp1 const &output_op1,
|
||||
OutputOp2 const &output_op2,
|
||||
OutputTileIterator dest0,
|
||||
OutputTileIterator dest1,
|
||||
OutputTileIterator dest2,
|
||||
AccumulatorTile const &accumulator0,
|
||||
AccumulatorTile const &accumulator1,
|
||||
OutputTileIterator source_iterator[2],
|
||||
bool writeToD2 // true if it's the final split-k
|
||||
) {
|
||||
// TODO: Implement when no source is needed
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment[2];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
source_fragment[i].clear();
|
||||
}
|
||||
|
||||
//
|
||||
// Iterator over warp-level accumulator fragment
|
||||
//
|
||||
|
||||
AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1};
|
||||
|
||||
//
|
||||
// Iterate over accumulator tile
|
||||
//
|
||||
|
||||
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
|
||||
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
||||
|
||||
//
|
||||
// Load the source
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
source_iterator[i].load(source_fragment[i]);
|
||||
++source_iterator[i];
|
||||
}
|
||||
|
||||
//
|
||||
// Convert and store fragment
|
||||
//
|
||||
|
||||
__syncthreads();
|
||||
|
||||
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
||||
iter, accum_fragment_iterator[0], this->warp_tile_iterator0_);
|
||||
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
||||
iter, accum_fragment_iterator[1], this->warp_tile_iterator1_);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Load fragments from shared memory
|
||||
//
|
||||
|
||||
typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK];
|
||||
typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK];
|
||||
|
||||
shared_load_iterator0_.load(aligned_accum_fragment0[0]);
|
||||
shared_load_iterator1_.load(aligned_accum_fragment1[0]);
|
||||
|
||||
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
|
||||
if (kPartitionsK > 1) {
|
||||
|
||||
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( int i = 1; i < kPartitionsK; ++i) {
|
||||
shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset);
|
||||
shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset);
|
||||
shared_load_iterator0_.load(aligned_accum_fragment0[i]);
|
||||
shared_load_iterator1_.load(aligned_accum_fragment1[i]);
|
||||
aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]);
|
||||
aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]);
|
||||
}
|
||||
|
||||
shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
||||
shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
||||
}
|
||||
|
||||
//
|
||||
// Compute the output result
|
||||
//
|
||||
|
||||
typename OutputTileIterator::Fragment output_fragment[3];
|
||||
|
||||
apply_output_operator_(output_fragment,
|
||||
output_op0, output_op1, output_op2,
|
||||
aligned_accum_fragment0[0], aligned_accum_fragment1[0],
|
||||
source_fragment);
|
||||
|
||||
|
||||
//
|
||||
// Store the final result
|
||||
//
|
||||
|
||||
if (kStoreD0) {
|
||||
dest0.store(output_fragment[0]);
|
||||
++dest0;
|
||||
}
|
||||
if (kStoreD1) {
|
||||
dest1.store(output_fragment[1]);
|
||||
++dest1;
|
||||
}
|
||||
if (writeToD2) {
|
||||
dest2.store(output_fragment[2]);
|
||||
++dest2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
|
||||
|
||||
template<class Seq>
|
||||
struct acc2smem_source_needed;
|
||||
|
||||
template <size_t... Seq>
|
||||
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
|
||||
template<int Advance>
|
||||
CUTLASS_DEVICE
|
||||
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
||||
WarpTileIterator &warp_tile_iterator) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Advance; i++) {
|
||||
++accum_fragment_iterator;
|
||||
}
|
||||
|
||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||
accum_fragment_iterator.load(accum_fragment);
|
||||
warp_tile_iterator.store(accum_fragment);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void push(size_t pos,
|
||||
AccumulatorFragmentIterator const &iterator_begin,
|
||||
WarpTileIterator &warp_tile_iterator) {
|
||||
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper to invoke the output functor over each vector of output
|
||||
CUTLASS_DEVICE
|
||||
void apply_output_operator_(
|
||||
typename OutputTileIterator::Fragment (&output_fragment)[3],
|
||||
OutputOp0 const &output_op0,
|
||||
OutputOp1 const &output_op1,
|
||||
OutputOp2 const &output_op2,
|
||||
typename SharedLoadIterator::Fragment const& aligned_accum_fragment0,
|
||||
typename SharedLoadIterator::Fragment const& aligned_accum_fragment1,
|
||||
typename OutputTileIterator::Fragment const (&source_fragment)[2]) {
|
||||
|
||||
OutputAccessType* output_frag_ptr[3] = {
|
||||
reinterpret_cast<OutputAccessType *>(&output_fragment[0]),
|
||||
reinterpret_cast<OutputAccessType *>(&output_fragment[1]),
|
||||
reinterpret_cast<OutputAccessType *>(&output_fragment[2])
|
||||
};
|
||||
|
||||
AccumulatorAccessType const *compute_frag_ptr[2] = {
|
||||
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment0),
|
||||
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment1)
|
||||
};
|
||||
|
||||
OutputAccessType const *source_frag_ptr[2] = {
|
||||
reinterpret_cast<OutputAccessType const *>(&source_fragment[0]),
|
||||
reinterpret_cast<OutputAccessType const *>(&source_fragment[1])
|
||||
};
|
||||
|
||||
int const kOutputOpIterations =
|
||||
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kOutputOpIterations; ++i) {
|
||||
|
||||
// Call the output operators
|
||||
output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]);
|
||||
output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]);
|
||||
output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
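The core of DualEpilogue::apply_output_operator_ above is a per-vector loop that produces three output fragments from two accumulator fragments. A scalar sketch of that data flow (hypothetical functor objects and plain float arrays standing in for the fragment types used above):

#include <array>
#include <cstddef>

template <class Op0, class Op1, class Op2, std::size_t N>
void apply_dual_epilogue(Op0 op0, Op1 op1, Op2 op2,
                         const std::array<float, N>& acc0, const std::array<float, N>& acc1,
                         const std::array<float, N>& src0, const std::array<float, N>& src1,
                         std::array<float, N>& d0, std::array<float, N>& d1,
                         std::array<float, N>& d2) {
  for (std::size_t i = 0; i < N; ++i) {
    d0[i] = op0(acc0[i], src0[i]);  // e.g. alpha0 * acc + beta0 * src
    d1[i] = op1(acc1[i], src1[i]);  // e.g. alpha1 * acc + beta1 * src
    d2[i] = op2(d0[i], d1[i]);      // e.g. SiLU(d0) * d1, stored only on the final split-K slice
  }
}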
|
||||
218
examples/43_dual_gemm/threadblock/dual_mma_base.h
Normal file
@ -0,0 +1,218 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Base template for a threadblock-scoped dual-GEMM mainloop that shares the A operand across two B operands.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class DualMmaBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
static_assert(kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
static_assert((kWarpGemmIterations % 2) == 0,
|
||||
"Inner loop iteration must be an even number.");
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages +
|
||||
Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB =
|
||||
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B0;
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B1;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B0_ref() {
|
||||
return TensorRefB{operand_B0.data(), LayoutB()};
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B1_ref() {
|
||||
return TensorRefB{operand_B1.data(), LayoutB()};
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B0_;
|
||||
typename Operator::IteratorB warp_tile_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DualMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx),
|
||||
warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
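A rough sketch of the shared-memory footprint implied by DualMmaBase::SharedStorage: one staged A buffer plus two staged B buffers (tile sizes and half-precision operands are assumed for illustration; the real ShapeA/ShapeB also add SmemPadding):

#include <cstdio>

int main() {
  const int kM = 128, kN = 64, kK = 32, kStages = 3;  // assumed threadblock tile and stage count
  const int bytes_per_element = 2;                     // e.g. cutlass::half_t
  long a_bytes = static_cast<long>(kM) * kK * kStages * bytes_per_element;
  long b_bytes = static_cast<long>(kK) * kN * kStages * bytes_per_element;  // per B operand
  std::printf("A: %ld bytes, B0 + B1: %ld bytes, total ~%ld KiB\n",
              a_bytes, 2 * b_bytes, (a_bytes + 2 * b_bytes) / 1024);  // ~48 KiB with these numbers
  return 0;
}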
|
||||
760
examples/43_dual_gemm/threadblock/dual_mma_multistage.h
Normal file
@ -0,0 +1,760 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Template for a multistage threadblock-scoped dual GEMM kernel (two GEMMs sharing the A operand).
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/threadblock/mma_base.h"
#include "dual_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class DualMmaMultistage :
|
||||
public DualMmaBase<Shape_, Policy_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DualMmaBase<Shape_, Policy_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand A
|
||||
static int const kAccessesPerGroupA =
|
||||
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand B
|
||||
static int const kAccessesPerGroupB =
|
||||
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B0_;
|
||||
SmemIteratorB smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DualMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
}
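  // Illustrative note (not part of the original file): with an assumed
  // WarpCount of {kM = 2, kN = 2, kK = 2}, the decomposition above maps
  // warp_idx = 5 as follows:
  //   warp_idx_mn = 5 % (2 * 2) = 1,  warp_idx_k = 5 / (2 * 2) = 1
  //   warp_idx_m  = 1 % 2 = 1,        warp_idx_n  = 1 / 2 = 0
  // so warp 5 owns the (m = 1, n = 0) warp tile in the k = 1 partition, and its
  // A/B0/B1 warp iterators are offset by kWarpGemmIterations * warp_idx_k along K.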
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1,
|
||||
int group_start_A = 0, int group_start_B = 0) {
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
} else {
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
iterator_B1.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B);
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B0
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B0.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
||||
} else {
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
||||
}
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
// Async Copy for operand B1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B1.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
||||
} else {
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
||||
}
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
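  // Illustrative note (not part of the original file): the per-group copy counts
  // come from Detail::kAccessesPerGroupA/B, a ceiling division of the per-stage
  // copy count by the number of warp-level MMA iterations. For an assumed
  // AsyncCopyIterationsPerStageA = 8 and kWarpGemmIterations = 3:
  //   kAccessesPerGroupA = (8 + 3 - 1) / 3 = 3
  // so successive calls with group_start_A = 0, 3, 6 issue copy iterations
  // {0,1,2}, {3,4,5}, {6,7}; the guard `group_start_A + j < 8` trims the last
  // group, spreading one stage's cp.async traffic across the warp MMAs.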
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum0,
|
||||
FragmentC &accum1,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B0,
|
||||
IteratorB iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum0,
|
||||
FragmentC const &src_accum1
|
||||
) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B1.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B0
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
// Async Copy for operand B1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
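  // Illustrative note (not part of the original file): the prologue commits
  // Base::kStages - 1 full stages before any math is issued. With an assumed
  // kStages = 4 and gemm_k_iterations = 10 on entry, the loop above runs for
  // stages 0, 1, 2 and leaves gemm_k_iterations = 7; the mainloop below then
  // iterates while gemm_k_iterations > -(kStages - 1) = -3, so its tail
  // iterations drain the stages prefetched here.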
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum0 = src_accum0;
|
||||
accum1 = src_accum1;
|
||||
|
||||
//
|
||||
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
||||
// so that all accumulator elements outside the GEMM footprint are zero.
|
||||
//
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
|
||||
typename IteratorA::AccessType zero_A;
|
||||
zero_A.clear();
|
||||
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
typename IteratorB::AccessType zero_B;
|
||||
zero_B.clear();
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_);
|
||||
last_smem_iterator_B0.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
last_smem_iterator_B0.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B0;
|
||||
}
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_);
|
||||
last_smem_iterator_B1.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
last_smem_iterator_B1.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B1;
|
||||
}
|
||||
}
|
||||
|
||||
// Waits until stages up to the previous (kStages-2)th stage have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B0[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B0[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B1[2];
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B1.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
|
||||
// accumulator and this temporary accumulator is added to the final
|
||||
// accumulator once in every mainloop iteration.
|
||||
plus<FragmentC> plus_accum;
|
||||
|
||||
FragmentC tmp_accum0, tmp_accum1;
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
|
||||
tmp_accum0.clear();
|
||||
tmp_accum1.clear();
|
||||
}
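  // Illustrative note (not part of the original file): for the FastF32 math
  // operators the accumulation is staged; in plain scalar terms,
  //   tmp += a_k * b_k;                            // inside the k-group loop
  //   if (k_group == 0) { acc += tmp; tmp = 0; }   // folded once per mainloop iteration
  // which is what the tmp_accum0/tmp_accum1 fragments and plus_accum implement
  // for the two products A*B0 and A*B1.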
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
|
||||
warp_mma(
|
||||
tmp_accum0,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
tmp_accum0
|
||||
);
|
||||
warp_mma(
|
||||
tmp_accum1,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
tmp_accum1
|
||||
);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
accum0 = plus_accum(accum0, tmp_accum0);
|
||||
accum1 = plus_accum(accum1, tmp_accum1);
|
||||
tmp_accum0.clear();
|
||||
tmp_accum1.clear();
|
||||
}
|
||||
} else {
|
||||
warp_mma(
|
||||
accum0,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
warp_mma(
|
||||
accum1,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum1
|
||||
);
|
||||
}
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
|
||||
group_start_iteration_B);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
group_start_iteration_A =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
|
||||
group_start_iteration_B);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until stages up to the previous (kStages-2)th stage have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B1.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
accum0 = plus_accum(accum0, tmp_accum0);
|
||||
accum1 = plus_accum(accum1, tmp_accum1);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// Commit and drain all pending and predicated cp.async (LDGSTS) operations from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
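One detail worth tracing in the mainloop above is the circular shared-memory buffer: the write stage index starts at kStages - 1 (the prologue already filled the earlier stages) while the read index starts at 0, and each advances by one per iteration, snapping back to slot 0 after the last stage. Below is a small stand-alone simulation of that bookkeeping under an assumed kStages = 3 (plain C++, no CUTLASS dependency); it only models the indices, whereas the real code additionally rewinds the smem and warp tile iterators with negative tile offsets.

#include <cstdio>

int main() {
  constexpr int kStages = 3;               // assumed stage count for illustration
  int smem_write_stage_idx = kStages - 1;  // prologue filled stages 0..kStages-2
  int smem_read_stage_idx = 0;

  // Mirror the wraparound logic from the mainloop for a few iterations.
  for (int iter = 0; iter < 6; ++iter) {
    std::printf("iter %d: write stage %d, read stage %d\n",
                iter, smem_write_stage_idx, smem_read_stage_idx);

    // Advance the write stage, returning to slot 0 after the last stage.
    if (smem_write_stage_idx == kStages - 1) {
      smem_write_stage_idx = 0;
    } else {
      ++smem_write_stage_idx;
    }

    // Advance the read stage with the same wraparound.
    if (smem_read_stage_idx == kStages - 1) {
      smem_read_stage_idx = 0;
    } else {
      ++smem_read_stage_idx;
    }
  }
  return 0;
}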
|
||||
@ -121,6 +121,7 @@ foreach(EXAMPLE
  39_gemm_permute
  41_multi_head_attention
  42_fused_multi_head_attention
  43_dual_gemm
  )

  add_subdirectory(${EXAMPLE})
@ -34,6 +34,7 @@

#pragma once

#include "cutlass/tensor_ref.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"