CUTLASS 2.10 updates (#622)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
  \brief A file contains the generic epilogue visitor that forwards callbacks to the epilogue visitor tree
|
||||
|
||||
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
||||
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
|
||||
|
||||
#include "epilogue_visitor_op/visitor_op_linear_combination.h"
|
||||
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
|
||||
#include "epilogue_visitor_op/visitor_op_accumulator.h"
|
||||
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
|
||||
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
|
||||
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
|
||||
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
|
||||
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
|
||||
#include "epilogue_visitor_op/visitor_op_unary.h"
|
||||
#include "epilogue_visitor_op/visitor_op_binary.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Generic Epilogue Visitor.
///
/// Thin adaptor that forwards every epilogue callback (begin_*/visit/end_*)
/// to the root visitor node (OutputOp_) of the epilogue visitor tree.
template <
  typename OutputOp_                      ///< Root visitor node of the epilogue tree
>
class EpilogueVisitorGeneric {
public:

  using OutputOp = OutputOp_;
  /// Fragment type of the accumulators handed to visit()
  using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
  static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
  using ElementOutput = typename OutputOp::ElementOutput;
  using OutputTileIterator = typename OutputOp::OutputTileIterator;

  /// Number of epilogue steps (one per output tile iteration)
  static int const kIterations = OutputTileIterator::kIterations;

  ///
  /// End Epilogue Tree
  ///

  /// Additional SMEM buffer is not required in the broadcast epilogue visitor;
  /// only the storage owned by the visitor tree itself is aggregated here.
  struct SharedStorage {

    typename OutputOp::SharedStorage output_smem;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

public:

  /// Argument structure (host-constructable)
  struct Arguments {
    typename OutputOp::Arguments output_op_args;
    //
    // Methods
    //
    Arguments() { }

    Arguments(
      typename OutputOp::Arguments output_op_args
    ):
      output_op_args(output_op_args)
    {

    }
  };

  /// Parameter structure derived from Arguments
  struct Params {
    typename OutputOp::Params output_op_params;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      output_op_params(args.output_op_args)
    {

    }
  };

private:

  // Root of the epilogue visitor tree; every callback below delegates to it
  OutputOp output_op;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueVisitorGeneric(
    Params const &params,                   ///< Parameters routed to the epilogue
    SharedStorage &shared_storage,          ///< Shared storage needed by the functors here
    MatrixCoord threadblock_offset,
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    MatrixCoord problem_size
  ):
    output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
  { }

  /// Helper to indicate split-K behavior (intentionally a no-op: the visitor
  /// tree does not specialize per split-K slice)
  CUTLASS_DEVICE
  void set_k_partition(
    int split_k_index,                      ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices) {                   ///< Total number of split-K slices

  }

  /// Called to set the batch index
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    output_op.set_batch_index(batch_idx);
  }

  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() {
    output_op.begin_epilogue();
  }

  /// Called at the start of one step before starting accumulator exchange
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    output_op.begin_step(step_idx);
  }

  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    output_op.begin_row(row_idx);
  }

  /// Called after accumulators have been exchanged for each accumulator vector
  CUTLASS_DEVICE
  void visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum) {
    output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
  }

  /// Called at the end of a row
  CUTLASS_DEVICE
  void end_row(int row_idx) {
    output_op.end_row(row_idx);

  }

  /// Called after all accumulator elements have been visited
  CUTLASS_DEVICE
  void end_step(int step_idx) {
    output_op.end_step(step_idx);
  }

  /// Called after all steps have been completed
  CUTLASS_DEVICE
  void end_epilogue() {
    output_op.end_epilogue();
  }

};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,84 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the binary ops
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int N>
|
||||
struct VectorAdd {
|
||||
|
||||
struct Arguments {
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():tmp(0){ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VectorAdd(
|
||||
Params const ¶ms
|
||||
) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
||||
cutlass::plus<Array<T, N>> add_op;
|
||||
return add_op(lhs, rhs);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this layernormware without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the unary ops
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int N>
|
||||
struct Mult {
|
||||
|
||||
struct Arguments {
|
||||
T alpha;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():alpha(T(1.0)){ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(T alpha): alpha(alpha) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
T alpha; ///< scales accumulators
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():alpha(T(1.0)){ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args): alpha(args.alpha) { }
|
||||
};
|
||||
|
||||
T alpha_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Mult(
|
||||
Params const ¶ms
|
||||
):
|
||||
alpha_(params.alpha)
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &source) const {
|
||||
cutlass::multiplies<Array<T, N>> multiply_op;
|
||||
return multiply_op(source, alpha_);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return alpha_ != T(0);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/// ReLU
|
||||
template <typename T, int N>
|
||||
struct ReLUVisitor {
|
||||
struct Arguments {
|
||||
T threshold;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():threshold(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(T threshold): threshold(threshold) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
T threshold;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():threshold(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args): threshold(args.threshold) { }
|
||||
};
|
||||
|
||||
T threshold_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ReLUVisitor(Params const ¶ms):
|
||||
threshold_(params.threshold) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
maximum<Array<T, N>> mx;
|
||||
return mx(frag, threshold_);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/// leakyReLU
|
||||
template <typename T, int N>
|
||||
struct LeakyReLUVisitor {
|
||||
struct Arguments {
|
||||
T leaky_alpha;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():leaky_alpha(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
T leaky_alpha;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():leaky_alpha(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
|
||||
};
|
||||
|
||||
T leaky_alpha_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
LeakyReLUVisitor(Params const ¶ms):
|
||||
leaky_alpha_(params.leaky_alpha) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
|
||||
return leaky_op(frag, leaky_alpha_);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Tanh
|
||||
template <typename T, int N>
|
||||
struct TanhVisitor {
|
||||
/// Argument
|
||||
struct Arguments {
|
||||
// a placeholder argument to ensure correctness of ctypes
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): tmp(0) { };
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { };
|
||||
};
|
||||
|
||||
/// Param
|
||||
struct Params {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(){ };
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TanhVisitor(Params const ¶ms) { }
|
||||
|
||||
// scalar operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
T tanh_op(T const &scalar) const {
|
||||
return fast_tanh(scalar);
|
||||
}
|
||||
|
||||
/// vector operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
Array<T, N> y;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i=0; i < N; ++i) {
|
||||
y[i] = tanh_op(frag[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,148 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with accumulator
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following Computation
|
||||
///
|
||||
/// ElementAccumulator accum;
|
||||
/// return accum;
|
||||
///
|
||||
/// It can only be the leaf node of the epilogue tree
|
||||
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
int kElementsPerAccess_ ///< Number of elements computed per operation
|
||||
>
|
||||
class VisitorOpAccumulator{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
/// Fragment type for Accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = AccumulatorAccessType;
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
// Note: it is strange that ctypes will return issue with empty arguments
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpAccumulator(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
return accum;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,246 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Binary op
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "binary_ops.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B))
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename VisitorA_, ///< Child node A
|
||||
typename VisitorB_, ///< Child node B
|
||||
template<typename T, int N> typename BinaryOp_
|
||||
>
|
||||
class VisitorOpBinary{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using VisitorA = VisitorA_;
|
||||
using VisitorB = VisitorB_;
|
||||
|
||||
/// Fragment type returned from VisitorA.visit
|
||||
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
|
||||
using ElementA = typename VisitAccessTypeA::Element;
|
||||
|
||||
/// Fragment type returned from VisitorB.visit
|
||||
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
|
||||
using ElementB = typename VisitAccessTypeB::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Combination Op TODO: generalize this
|
||||
using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename VisitorA::SharedStorage storage_a;
|
||||
typename VisitorB::SharedStorage storage_b;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
typename BinaryOp::Arguments binary_arg;
|
||||
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
|
||||
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():binary_arg() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
typename BinaryOp::Arguments binary_arg,
|
||||
typename VisitorA::Arguments visitor_a_arg,
|
||||
typename VisitorB::Arguments visitor_b_arg
|
||||
):
|
||||
binary_arg(binary_arg),
|
||||
visitor_a_arg(visitor_a_arg),
|
||||
visitor_b_arg(visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
typename BinaryOp::Params binary_param;
|
||||
typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a
|
||||
typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
binary_param(args.binary_arg),
|
||||
visitor_a_param(args.visitor_a_arg),
|
||||
visitor_b_param(args.visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
BinaryOp binary_op;
|
||||
|
||||
VisitorA visitor_a_op;
|
||||
VisitorB visitor_b_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpBinary(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
binary_op(params.binary_param),
|
||||
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
|
||||
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_a_op.begin_epilogue();
|
||||
visitor_b_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
visitor_a_op.set_batch_index(batch_idx);
|
||||
visitor_b_op.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
visitor_a_op.begin_step(step_idx);
|
||||
visitor_b_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_a_op.begin_row(row_idx);
|
||||
visitor_b_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
/// Evaluates BinaryOp(A, B) on one access-sized vector.
///
/// Both children are visited at the same fragment coordinates; their results
/// are converted to the common compute type before being combined.
CUTLASS_DEVICE
VisitAccessType visit(
  int iter_idx,
  int row_idx,
  int column_idx,
  int frag_idx,
  AccumulatorAccessType const &accum
) {
  // Evaluate both child visitors on the same fragment coordinates
  VisitAccessTypeA frag_a = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
  VisitAccessTypeB frag_b = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

  // Convert each operand to ElementCompute before applying the binary functor
  NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> convert_a;
  NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> convert_b;

  return binary_op(convert_a(frag_a), convert_b(frag_b));
}
|
||||
|
||||
/// Called at the end of each row of the output tile; forwarded to both children.
CUTLASS_DEVICE
void end_row(int row_idx) {
  visitor_a_op.end_row(row_idx);
  visitor_b_op.end_row(row_idx);
}
|
||||
|
||||
/// Called at the end of each epilogue step; forwarded to both children.
CUTLASS_DEVICE
void end_step(int step_idx) {
  visitor_a_op.end_step(step_idx);
  visitor_b_op.end_step(step_idx);
}
|
||||
|
||||
/// Called once at the end of the epilogue; forwarded to both children.
CUTLASS_DEVICE
void end_epilogue() {
  visitor_a_op.end_epilogue();
  visitor_b_op.end_epilogue();
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with broadcasting vector to all columns
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementVector T[i][j] <- device-memory Td[i]
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementFragment_, ///< Data type used to cache vector in register
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
|
||||
>
|
||||
class VisitorOpColumnBroadcast {
|
||||
public:
|
||||
using InputTileIterator = InputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementVector = typename InputTileIterator::Element;
|
||||
using ElementFragment = ElementFragment_;
|
||||
|
||||
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by input tile iterators
|
||||
using ThreadMap = typename InputTileIterator::ThreadMap;
|
||||
|
||||
/// Fragment object used to store the broadcast values
|
||||
using BroadcastFragment = Array<
|
||||
ElementFragment, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Used for the broadcast
|
||||
struct BroadcastDetail {
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
// /// Number of iterations (accesses) the threadblock takes to reduce a row
|
||||
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
||||
};
|
||||
|
||||
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
||||
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementVector *broadcast_ptr,
|
||||
int64_t batch_stride
|
||||
):
|
||||
broadcast_ptr(broadcast_ptr),
|
||||
batch_stride(batch_stride) { }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
broadcast_ptr(args.broadcast_ptr),
|
||||
batch_stride(args.batch_stride) { }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementVector *broadcast_ptr;
|
||||
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
|
||||
MatrixCoord threadblock_offset_;
|
||||
int thread_idx_;
|
||||
MatrixCoord problem_size;
|
||||
|
||||
int thread_start_row_;
|
||||
int state_[3];
|
||||
int thread_offset_row_;
|
||||
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpColumnBroadcast(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
broadcast_ptr(params.broadcast_ptr),
|
||||
threadblock_offset_(threadblock_offset),
|
||||
thread_idx_(thread_idx),
|
||||
problem_size(problem_size),
|
||||
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
|
||||
batch_stride_(params.batch_stride)
|
||||
{
|
||||
state_[0] = state_[1] = state_[2] = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
broadcast_ptr += batch_idx * batch_stride_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
// get pointer
|
||||
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
|
||||
|
||||
ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));
|
||||
|
||||
broadcast_fragment.fill(broadcast_data);
|
||||
|
||||
return broadcast_fragment;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
// run operator ++
|
||||
++state_[0];
|
||||
|
||||
thread_start_row_ += ThreadMap::Shape::kRow;
|
||||
if (state_[0] == ThreadMap::Count::kRow) {
|
||||
state_[0] = 0;
|
||||
++state_[1];
|
||||
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
|
||||
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
|
||||
|
||||
if (state_[1] == ThreadMap::Count::kGroup) {
|
||||
state_[1] = 0;
|
||||
++state_[2];
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
|
||||
|
||||
if (state_[2] == ThreadMap::Count::kCluster) {
|
||||
state_[2] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,342 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with reduction over columns in CTA
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
|
||||
/// device memory <- ElementReduction(R[j])
|
||||
///
|
||||
template <
|
||||
typename ThreadblockShape_, /// Threadblock shape
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementReduction_, ///< Data type of the output reduction in device memory
|
||||
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
|
||||
typename OutputTileIterator_, ///< Tile Iterator type
|
||||
typename Visitor_ ///< preceeding visitor op
|
||||
>
|
||||
class VisitorOpColumnReduction {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementReductionAccumulator = ElementReductionAccumulator_;
|
||||
using ElementReduction = ElementReduction_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using Visitor = Visitor_;
|
||||
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
// TODO: generalize the reduction op
|
||||
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
|
||||
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
|
||||
|
||||
/// Fragment type returned from Visitor
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
using VisitAccessType = VisitAccessTypeVisitor;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of redcution
|
||||
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by output tile iterators
|
||||
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
||||
/// Used for the reduction
|
||||
struct ReductionDetail {
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
/// Number of iterations (accesses) the threadblock takes to reduce a row
|
||||
static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);
|
||||
|
||||
using StorageShape = MatrixShape<
|
||||
kThreadRows,
|
||||
ThreadblockShape::kN
|
||||
>;
|
||||
};
|
||||
|
||||
using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementReduction *reduction_ptr,
|
||||
int64_t batch_stride,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
reduction_ptr(reduction_ptr),
|
||||
batch_stride(batch_stride),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Params visitor_param; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
reduction_ptr(args.reduction_ptr),
|
||||
batch_stride(args.batch_stride),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
|
||||
ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory
|
||||
ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction
|
||||
Visitor visitor_; ///< visitor
|
||||
int thread_idx_;
|
||||
MatrixCoord threadblock_offset;
|
||||
MatrixCoord problem_size_;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpColumnReduction(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
visitor_(params.visitor_param, shared_storage.storage_visitor,
|
||||
thread_idx, threadblock_offset, problem_size),
|
||||
reduction_smem_ptr_(shared_storage.reduction.data()),
|
||||
reduction_output_ptr_(params.reduction_ptr),
|
||||
thread_idx_(thread_idx),
|
||||
threadblock_offset(threadblock_offset),
|
||||
problem_size_(problem_size),
|
||||
batch_stride_(params.batch_stride)
|
||||
{ }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
reduction_output_ptr_ += batch_idx * batch_stride_;
|
||||
visitor_.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_.begin_epilogue();
|
||||
|
||||
// clear the reduction fragment
|
||||
reduction_fragment.clear();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
visitor_.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor
|
||||
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
|
||||
ReductionOp reduction_op;
|
||||
ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast<ReductionAccumulatorAccessType*>(&reduction_fragment);
|
||||
reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_.end_epilogue();
|
||||
//
|
||||
// Store the partially reduced value to SMEM
|
||||
//
|
||||
|
||||
// Guard against uses of the existing SMEM tile
|
||||
__syncthreads();
|
||||
|
||||
using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;
|
||||
|
||||
//
|
||||
// Determine a compact thread arrangement to store to SMEM
|
||||
//
|
||||
|
||||
MatrixCoord thread_offset(
|
||||
thread_idx_ / ReductionDetail::kThreadsPerRow,
|
||||
(thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
|
||||
);
|
||||
|
||||
//
|
||||
// Each thread store its fragment to a SMEM
|
||||
//
|
||||
AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
|
||||
&reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
|
||||
);
|
||||
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
|
||||
&reduction_fragment
|
||||
);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
||||
int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;
|
||||
|
||||
aligned_reduction_ptr[col_idx] = frag_ptr[column];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Now, threads are assigned several columns of the output. The fetch over all rows from
|
||||
// the compacted SMEM tile and perform a reduction.
|
||||
//
|
||||
|
||||
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
|
||||
int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
ElementReductionAccumulator reduction_element = ElementReductionAccumulator();
|
||||
|
||||
int output_column_idx = threadblock_offset.column() + column_idx;
|
||||
|
||||
if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
|
||||
if (row) {
|
||||
auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
|
||||
reduction_element = reduction_op(reduction_element, frag);
|
||||
}
|
||||
else {
|
||||
|
||||
reduction_element = reduction_smem_ptr_[column_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Store
|
||||
reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,266 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Linear Combination
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B)
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename VisitorA_, ///< Child node A
|
||||
typename VisitorB_ ///< Child node B
|
||||
>
|
||||
class VisitorOpLinearCombination{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using VisitorA = VisitorA_;
|
||||
using VisitorB = VisitorB_;
|
||||
|
||||
/// Fragment type returned from VisitorA.visit
|
||||
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
|
||||
using ElementA = typename VisitAccessTypeA::Element;
|
||||
|
||||
/// Fragment type returned from VisitorB.visit
|
||||
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
|
||||
using ElementB = typename VisitAccessTypeB::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Combination Op TODO: generalize this
|
||||
using CombinationOp = cutlass::plus<VisitAccessType>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename VisitorA::SharedStorage storage_a;
|
||||
typename VisitorB::SharedStorage storage_b;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
|
||||
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0))
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
typename VisitorA::Arguments visitor_a_arg,
|
||||
typename VisitorB::Arguments visitor_b_arg
|
||||
):
|
||||
alpha(alpha),
|
||||
beta(beta),
|
||||
visitor_a_arg(visitor_a_arg),
|
||||
visitor_b_arg(visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a
|
||||
typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
alpha(args.alpha),
|
||||
beta(args.beta),
|
||||
visitor_a_param(args.visitor_a_arg),
|
||||
visitor_b_param(args.visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
VisitorA visitor_a_op;
|
||||
VisitorB visitor_b_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpLinearCombination(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
alpha_(params.alpha),
|
||||
beta_(params.beta),
|
||||
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
|
||||
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
/// Visits one vector access of accumulators and returns the combined
/// fragment: CombinationOp(alpha * A, beta * B), where A and B are the
/// fragments produced by the two child visitors for the same access.
/// (CombinationOp is a template parameter of the enclosing class —
/// typically addition for a linear combination; confirm at instantiation.)
CUTLASS_DEVICE
VisitAccessType visit(
  int iter_idx,
  int row_idx,
  int column_idx,
  int frag_idx,
  AccumulatorAccessType const &accum
) {
  /// Get result from visitor A and visitor B
  VisitAccessTypeA result_A;
  VisitAccessTypeB result_B;

  // A zero coefficient short-circuits the child visit entirely; its
  // fragment is zero-filled so the combination below is unaffected.
  if (alpha_ != ElementCompute(0)) {
    result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
  } else {
    // Fill the result A with zeros
    result_A.clear();
  }

  if (beta_ != ElementCompute(0)) {
    result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
  } else {
    // Fill the result B with zeros
    result_B.clear();
  }

  /// Type conversion: bring both child fragments into ElementCompute
  /// before scaling and combining.
  NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
  NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;

  CombinationOp combination_op;

  // Scales an entire fragment by a scalar coefficient.
  cutlass::multiplies<VisitAccessType> multiply_op;

  return combination_op(
    multiply_op(alpha_, source_converter_A(result_A)),
    multiply_op(beta_, source_converter_B(result_B))
  );
}
|
||||
|
||||
/// Called at the end of each row; forwarded to the active child visitors.
CUTLASS_DEVICE
void end_row(int row_idx) {
  if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
  if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
}
|
||||
|
||||
/// Called at the end of each epilogue step; forwarded to the active child
/// visitors.
CUTLASS_DEVICE
void end_step(int step_idx) {
  if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
  if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
}
|
||||
|
||||
/// Called once after the epilogue completes; forwarded to the active child
/// visitors.
CUTLASS_DEVICE
void end_epilogue() {
  if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
  if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,258 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with broadcasting vector to all rows
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementVector T[i][j] <- device-memory Td[j]
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementFragment_, ///< Data type used to cache vector in register
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
|
||||
>
|
||||
class VisitorOpRowBroadcast {
|
||||
public:
|
||||
using InputTileIterator = InputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementVector = typename InputTileIterator::Element;
|
||||
using ElementFragment = ElementFragment_;
|
||||
|
||||
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by input tile iterators
|
||||
using ThreadMap = typename InputTileIterator::ThreadMap;
|
||||
|
||||
/// Fragment object used to store the broadcast values
|
||||
using BroadcastFragment = Array<
|
||||
ElementFragment,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Used for the broadcast
|
||||
struct BroadcastDetail {
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
// /// Number of iterations (accesses) the threadblock takes to reduce a row
|
||||
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
||||
};
|
||||
|
||||
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
||||
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementVector *broadcast_ptr,
|
||||
int64_t batch_stride
|
||||
):
|
||||
broadcast_ptr(broadcast_ptr),
|
||||
batch_stride(batch_stride) { }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
broadcast_ptr(args.broadcast_ptr),
|
||||
batch_stride(args.batch_stride) { }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementVector *broadcast_ptr;
|
||||
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
|
||||
MatrixCoord threadblock_offset_;
|
||||
int thread_idx_;
|
||||
MatrixCoord problem_size;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpRowBroadcast(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
|
||||
threadblock_offset_(threadblock_offset),
|
||||
thread_idx_(thread_idx),
|
||||
problem_size(problem_size),
|
||||
batch_stride_(params.batch_stride) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
broadcast_ptr += batch_idx * batch_stride_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
// load broadcast fragment
|
||||
load_broadcast_fragment_();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
VisitAccessType* broadcast_fragment_ = reinterpret_cast<VisitAccessType*>(&broadcast_fragment);
|
||||
return broadcast_fragment_[column_idx];
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
|
||||
private:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load_broadcast_fragment_() {
|
||||
|
||||
broadcast_fragment.clear();
|
||||
|
||||
// If no pointer is supplied, set with all zeros and avoid memory accesses
|
||||
if (!broadcast_ptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
|
||||
|
||||
int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
|
||||
broadcast_ptr += thread_initial_column;
|
||||
|
||||
NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
|
||||
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
||||
using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;
|
||||
|
||||
AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
|
||||
|
||||
AccessType loaded;
|
||||
|
||||
loaded.clear();
|
||||
|
||||
if (thread_column_idx < problem_size.column()) {
|
||||
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
|
||||
}
|
||||
|
||||
AccessFragmentType cvt = converter(loaded);
|
||||
frag_ptr[j] = cvt;
|
||||
|
||||
thread_column_idx += ThreadMap::Delta::kColumn;
|
||||
broadcast_ptr += ThreadMap::Delta::kColumn;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,320 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with reduction over rows in CTA
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "stdio.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementReductionAccumulator R[i] = \sum_i ElementReductionAccumulator(T[i][j])
|
||||
/// device memory <- ElementReduction(R[i])
|
||||
///
|
||||
template <
|
||||
typename ThreadblockShape_, /// Threadblock shape
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementReduction_, ///< Data type of the output reduction in device memory
|
||||
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
|
||||
typename OutputTileIterator_, ///< Tile Iterator type
|
||||
typename Visitor_ ///< preceeding visitor op
|
||||
>
|
||||
class VisitorOpRowReduction {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementReductionAccumulator = ElementReductionAccumulator_;
|
||||
using ElementReduction = ElementReduction_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using Visitor = Visitor_;
|
||||
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
// TODO: generalize the reduction op
|
||||
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
|
||||
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
/// Fragment type returned from Visitor
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
using VisitAccessType = VisitAccessTypeVisitor;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of redcution
|
||||
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by output tile iterators
|
||||
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
||||
/// Used for the reduction
|
||||
struct ReductionDetail {
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
|
||||
|
||||
/// Half number of threads per row used for cross-thread reduction
|
||||
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementReduction *reduction_ptr,
|
||||
int64_t batch_stride,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
reduction_ptr(reduction_ptr),
|
||||
batch_stride(batch_stride),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Params visitor_param; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
reduction_ptr(args.reduction_ptr),
|
||||
batch_stride(args.batch_stride),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
|
||||
ElementReductionAccumulator reduction_accum;
|
||||
Visitor visitor_; ///< visitor
|
||||
int thread_idx_;
|
||||
MatrixCoord threadblock_offset;
|
||||
MatrixCoord problem_size_;
|
||||
|
||||
int thread_start_row_; /// used to identify
|
||||
int state_[3]; /// used to track row iterator
|
||||
int thread_offset_row_;
|
||||
int64_t batch_stride_;
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpRowReduction(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
visitor_(params.visitor_param, shared_storage.storage_visitor,
|
||||
thread_idx, threadblock_offset, problem_size),
|
||||
reduction_output_ptr_(params.reduction_ptr),
|
||||
thread_idx_(thread_idx),
|
||||
threadblock_offset(threadblock_offset),
|
||||
problem_size_(problem_size),
|
||||
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
|
||||
batch_stride_(params.batch_stride)
|
||||
{
|
||||
state_[0] = state_[1] = state_[2] = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
reduction_output_ptr_ += batch_idx * batch_stride_;
|
||||
visitor_.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
visitor_.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_.begin_row(row_idx);
|
||||
|
||||
reduction_accum = ElementReductionAccumulator(0);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor
|
||||
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
|
||||
ElementReductionAccumulator reduction_accum_ = reduction(result);
|
||||
|
||||
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
|
||||
reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
|
||||
}
|
||||
reduction_accum = reduction_op(reduction_accum, reduction_accum_);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_.end_row(row_idx);
|
||||
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
|
||||
|
||||
bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
|
||||
int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();
|
||||
|
||||
ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;
|
||||
|
||||
arch::global_store<ElementReduction, sizeof(ElementReduction)>(
|
||||
output_converter(reduction_accum),
|
||||
(void *)curr_ptr_reduction,
|
||||
is_write_thread);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_.end_step(step_idx);
|
||||
|
||||
// run operator ++
|
||||
++state_[0];
|
||||
|
||||
thread_start_row_ += ThreadMap::Shape::kRow;
|
||||
if (state_[0] == ThreadMap::Count::kRow) {
|
||||
state_[0] = 0;
|
||||
++state_[1];
|
||||
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
|
||||
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
|
||||
|
||||
if (state_[1] == ThreadMap::Count::kGroup) {
|
||||
state_[1] = 0;
|
||||
++state_[2];
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
|
||||
|
||||
if (state_[2] == ThreadMap::Count::kCluster) {
|
||||
state_[2] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_.end_epilogue();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) {
|
||||
ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
|
||||
sum_ = reduction_op(sum_, result[i]);
|
||||
}
|
||||
|
||||
return sum_;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,188 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Tensor Output
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementInput C <- device memory
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the tensor
|
||||
>
|
||||
class VisitorOpTensorInput {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using InputTileIterator = InputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
|
||||
using ElementInput = typename InputTileIterator::Element;
|
||||
|
||||
using VisitAccessType = Array<ElementInput, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementInput *input_ptr; ///< Pointer to the input tensor in device memory
|
||||
int ldt; ///< Leading dimension of the input tensor operand
|
||||
int64_t batch_stride; ///< batch stride for batched GEMM
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): input_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementInput *input_ptr,
|
||||
int ldt, int64_t batch_stride
|
||||
):
|
||||
input_ptr(input_ptr),
|
||||
ldt(ldt),
|
||||
batch_stride(batch_stride)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
typename InputTileIterator::Params params_input;
|
||||
ElementInput *input_ptr;
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
input_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
params_input(args.ldt),
|
||||
input_ptr(args.input_ptr),
|
||||
batch_stride(args.batch_stride)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
InputTileIterator iterator_T_;
|
||||
typename InputTileIterator::Fragment fragment_T_;
|
||||
MatrixCoord problem_size;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpTensorInput(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
iterator_T_(
|
||||
InputTileIterator(
|
||||
params.params_input,
|
||||
params.input_ptr,
|
||||
problem_size,
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
)
|
||||
),
|
||||
problem_size(problem_size),
|
||||
batch_stride_(params.batch_stride) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_T_.clear();
|
||||
iterator_T_.load(fragment_T_);
|
||||
++iterator_T_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
|
||||
return source;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,240 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Tensor Output
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "stdio.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementOutput T = ElementOutput(Visitor)
|
||||
/// T-> device memory
|
||||
///
|
||||
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename OutputTileIterator_,  ///< Tile iterator type to write the tensor
  typename Visitor_              ///< Child visitor that produces the output tensor
>
class VisitorOpTensorOutput {
public:
  using ElementAccumulator = ElementAccumulator_;
  using OutputTileIterator = OutputTileIterator_;

  /// Vector width of a single output access
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
  using ElementOutput = typename OutputTileIterator::Element;

  using Visitor = Visitor_;

  /// Fragment type returned from Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  /// This node forwards the child visitor's fragment unchanged to its parent
  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of output
  using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;

  static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");

  /// Shared storage: only the child visitor's storage is required
  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementOutput *output_ptr;                ///< Pointer to the output tensor in device memory
    int ldt;                                  ///< Leading dimension of the output tensor operand
    int64_t batch_stride;                     ///< batch stride (elements between batches)
    typename Visitor::Arguments visitor_arg;  ///< Argument type of visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): output_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementOutput *output_ptr,
      int ldt,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      output_ptr(output_ptr),
      ldt(ldt),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Param structure (device-side view of Arguments)
  struct Params {
    typename OutputTileIterator::Params params_output;
    ElementOutput *output_ptr;
    int64_t batch_stride;
    typename Visitor::Params visitor_param;

    /// Method
    CUTLASS_HOST_DEVICE
    Params():
      output_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      params_output(args.ldt),
      output_ptr(args.output_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  OutputTileIterator iterator_T_;                     // writes staged fragments to global memory
  typename OutputTileIterator::Fragment fragment_T_;  // staging fragment filled in visit(), stored in end_step()
  MatrixCoord problem_size;                           // problem extent; used for the column guard in visit()
  Visitor visitor_;                                   // child visitor producing the values to store
  int64_t batch_stride_;                              // element stride between batches

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpTensorOutput(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
    iterator_T_(
      OutputTileIterator(
        params.params_output,
        params.output_ptr,
        problem_size,
        thread_idx,
        threadblock_offset
      )
    ),
    problem_size(problem_size),
    batch_stride_(params.batch_stride) { }

  /// Repositions the output iterator to the selected batch and propagates to the child
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
    visitor_.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    // Clear so elements skipped by the column guard store as zero
    fragment_T_.clear();
    visitor_.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_.begin_row(row_idx);
  }

  /// Converts the child visitor's result to the output element type, stages it
  /// for the store in end_step(), and forwards the unconverted result upward.
  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    /// Get result from visitor
    VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    // Column guard: only stage accesses that fall inside the problem extent
    MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
    bool column_guard = (thread_offset_.column() < problem_size.column());

    if (column_guard) {
      NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
      OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
      output = output_converter(result);
    }

    return result;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_.end_step(step_idx);
    // Write the staged tile and advance to the next step's tile
    iterator_T_.store(fragment_T_);
    ++iterator_T_;
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_.end_epilogue();
  }

};
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,226 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Unary operation
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "unary_ops.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = UnaryOp(ElementCompute(Visitor))
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename Visitor_, ///< Child node
|
||||
template<typename T, int N> typename UnaryOp_
|
||||
>
|
||||
class VisitorOpUnary{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using Visitor = Visitor_;
|
||||
|
||||
/// Fragment type returned from Visitor.visit
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisit = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Combination Op TODO: generalize this
|
||||
using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
typename UnaryOp::Arguments unary_arg;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type for visitor
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():unary_arg() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
typename UnaryOp::Arguments unary_arg,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
unary_arg(unary_arg),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
typename UnaryOp::Params unary_param;
|
||||
typename Visitor::Params visitor_param; ///< Argument type for visitor
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():unary_param() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
unary_param(args.unary_arg),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
UnaryOp unary_op;
|
||||
|
||||
Visitor visitor_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpUnary(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
unary_op(params.unary_param),
|
||||
visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
visitor_op.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (unary_op.guard()) visitor_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
if (unary_op.guard()) visitor_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
if (unary_op.guard()) visitor_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor A and visitor B
|
||||
VisitAccessTypeVisitor result;
|
||||
|
||||
if (unary_op.guard()){
|
||||
result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
} else {
|
||||
result.clear();
|
||||
}
|
||||
|
||||
/// Type conversion
|
||||
NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;
|
||||
|
||||
cutlass::multiplies<VisitAccessType> multiply_op;
|
||||
|
||||
return unary_op(source_converter(result));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
if (unary_op.guard()) visitor_op.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
if (unary_op.guard()) visitor_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
if (unary_op.guard()) visitor_op.end_epilogue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,481 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
 * this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A file contains all functioning classes needed by GemmLayernorm.
|
||||
|
||||
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
|
||||
+ lightweight full reduction kernel (ApplyFinalReduction)
|
||||
+ GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion)
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ThreadblockShape_,
|
||||
int ThreadCount,
|
||||
typename OutputTileIterator_,
|
||||
typename AccumulatorTile_,
|
||||
typename ElementAccumulator_,
|
||||
typename ElementVariance_,
|
||||
typename ElementMean_,
|
||||
typename ElementLayernormCompute_,
|
||||
typename ElementwiseFunctor_,
|
||||
bool IsShiftedVariance_ = false
|
||||
>
|
||||
class EpilogueVisitorLayerNorm {
|
||||
public:
|
||||
|
||||
using ElementVariance = ElementVariance_;
|
||||
using ElementMean = ElementMean_;
|
||||
using ElementLayernormCompute = ElementLayernormCompute_;
|
||||
|
||||
using AccumulatorTile = AccumulatorTile_;
|
||||
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;
|
||||
|
||||
static int const kThreads = OutputTileIterator::ThreadMap::kThreads;
|
||||
|
||||
static bool const kIsShiftedVariance = IsShiftedVariance_;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;
|
||||
|
||||
/// Array type used in Shift-K Layernorm
|
||||
static int const kRowAccessCount = kIterations * kRowIterations;
|
||||
|
||||
using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;
|
||||
|
||||
// Conducts manual transpose externally (already supported) for column major
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static int const kThreadsInColumn = kThreads / kThreadsPerRow;
|
||||
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
ElementVariance *ptr_Variance;
|
||||
ElementMean *ptr_Mean;
|
||||
ElementOutput *ptr_Shifted_K;
|
||||
MatrixCoord extent;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments():
|
||||
ptr_Variance(nullptr),
|
||||
ptr_Mean(nullptr),
|
||||
ptr_Shifted_K(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
Arguments(
|
||||
typename ElementwiseFunctor::Params elementwise_,
|
||||
ElementVariance *ptr_Variance,
|
||||
ElementMean *ptr_Mean_,
|
||||
ElementOutput *ptr_Shifted_K_ = nullptr,
|
||||
MatrixCoord extent = MatrixCoord(0, 0)
|
||||
):
|
||||
elementwise(elementwise_),
|
||||
ptr_Variance(ptr_Variance),
|
||||
ptr_Mean(ptr_Mean_),
|
||||
ptr_Shifted_K(ptr_Shifted_K_),
|
||||
extent(extent)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
ElementVariance *ptr_Variance;
|
||||
ElementMean *ptr_Mean;
|
||||
ElementOutput *ptr_Shifted_K;
|
||||
MatrixCoord extent;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
ptr_Variance(nullptr),
|
||||
ptr_Mean(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
elementwise(args.elementwise),
|
||||
ptr_Variance(args.ptr_Variance),
|
||||
ptr_Mean(args.ptr_Mean),
|
||||
ptr_Shifted_K(args.ptr_Shifted_K),
|
||||
extent(args.extent)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const & params_;
|
||||
SharedStorage & shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator alpha_;
|
||||
ElementAccumulator beta_;
|
||||
ConvertedShiftFragment shift_k_frag_;
|
||||
|
||||
ElementLayernormCompute accum_sum_square_;
|
||||
ElementLayernormCompute accum_sum_element_;
|
||||
int thread_idx_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
gemm::GemmCoord threadblock_tile_offset_;
|
||||
|
||||
public:
|
||||
|
||||
  /// Constructs the epilogue visitor for one threadblock tile
  CUTLASS_DEVICE
  EpilogueVisitorLayerNorm(
    Params const &params,                       ///< Parameters routed to the epilogue
    SharedStorage &shared_storage,              ///< Shared storage needed by the functors here
    MatrixCoord threadblock_offset,
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    OutputTileIterator destination_iterator,    ///< Tile iterator for destination
    OutputTileIterator source_iterator          ///< Threadblock tile coordinate in GEMM
  ):
    params_(params),
    shared_storage_(shared_storage),
    elementwise_(params.elementwise),
    extent_(params.extent),
    iterator_C_(source_iterator),
    iterator_D_(destination_iterator),
    threadblock_tile_offset_(threadblock_tile_offset),
    thread_idx_(thread_idx)
  {
    // Device-side scalar pointers take precedence over host-provided values
    alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
    beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

    // With beta == 0 the source tensor C never contributes; disable its loads
    if (beta_ == ElementAccumulator()) {
      iterator_C_.clear_mask();
    }
  }
|
||||
|
||||
  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
    int split_k_index,      ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices) {   ///< Total number of split-K slices
    // Intentionally a no-op: this visitor does not specialize for split-K
  }
|
||||
|
||||
  /// Called to set the batch index
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    // Intentionally a no-op: batching is not handled by this visitor
  }
|
||||
|
||||
  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() {

    // If shift-K feature is enabled, we load shift-k fragment
    // at the very beginning of an epilogue
    if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
      shift_k_frag_.clear();
      int thread_offset_row_base = iterator_D_.thread_start_row();

      // One shift-k value per (iteration, row-iteration) pair this thread touches
      CUTLASS_PRAGMA_UNROLL
      for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
        int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
        CUTLASS_PRAGMA_UNROLL
        for (int rid = 0; rid < kRowIterations; ++rid) {
          int row_step_offset = rid * kDeltaRow;
          int row_offset = thread_offset_row_base + step_offset + row_step_offset;
          // Predicate the load so out-of-bounds rows are not read
          bool is_load = (row_offset < extent_.row());
          shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
        }

      }

    }

  }
|
||||
|
||||
  /// Called at the start of one step before starting accumulator exchange
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_D_.clear();

    // The source tensor C is only loaded when beta-scaling participates
    if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      fragment_C_.clear();
      iterator_C_.load(fragment_C_);
      ++iterator_C_;
    }
  }
|
||||
|
||||
  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    /// Reset this row's partial element-sum and square-sum accumulators to 0
    accum_sum_element_ = ElementLayernormCompute(0);
    accum_sum_square_ = ElementLayernormCompute(0);
  }
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorFragment const &accum) {
|
||||
|
||||
using Mul = cutlass::multiplies<ElementLayernormCompute>;
|
||||
using Minus = cutlass::minus<ElementLayernormCompute>;
|
||||
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
|
||||
|
||||
Minus minus;
|
||||
Mul mul;
|
||||
Exp exponential;
|
||||
|
||||
LayernormFragment result;
|
||||
|
||||
thread_offset_ =
|
||||
iterator_D_.thread_start() +
|
||||
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
|
||||
|
||||
NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
|
||||
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
|
||||
|
||||
bool column_guard = (thread_offset_.column() < extent_.column());
|
||||
|
||||
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
||||
result = source_converter(elementwise_(accum));
|
||||
}else{
|
||||
result = source_converter(elementwise_(accum, source_vector));
|
||||
}
|
||||
|
||||
|
||||
ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());
|
||||
|
||||
// Fragment is cleared for non-reachable columns so no need to check against column guard
|
||||
ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);
|
||||
|
||||
// Square sum is different. Non-reachable columns should've been computed for shift-k
|
||||
// Otherwise we will incorrectly have some extra k^2 added into square sum.
|
||||
ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);
|
||||
|
||||
if (column_guard) {
|
||||
accum_sum_square_tmp = (kIsShiftedVariance) ? \
|
||||
square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
|
||||
square_sum_accumulator_(result);
|
||||
}
|
||||
|
||||
accum_sum_element_tmp *= inv_scalar;
|
||||
accum_sum_square_tmp *= inv_scalar;
|
||||
|
||||
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
|
||||
accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
|
||||
accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
|
||||
}
|
||||
accum_sum_element_ += accum_sum_element_tmp;
|
||||
accum_sum_square_ += accum_sum_square_tmp;
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
  /// Called at the end of a row: writes this row's partial mean and partial
  /// sum-of-squares to global memory for the subsequent full-reduction kernel.
  CUTLASS_DEVICE
  void end_row(int row_idx) {

    using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
    using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;

    ConvertVarianceOutput convert_variance_output;
    ConvertMeanOutput convert_mean_output;

    // Only the leading thread of each row writes, and only for in-bounds rows
    // (the warp reduction in visit() left the full partial sums in every lane)
    bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
    // NOTE(review): offsets by tile_offset.n() * rows — presumably one partial-result
    // column per threadblock column; confirm against the final-reduction kernel
    int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();

    ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
    ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;

    // Predicated global stores: no-ops for non-writing threads
    arch::global_store<ElementVariance, sizeof(ElementVariance)>(
              convert_variance_output(accum_sum_square_),
              (void *)curr_ptr_sum_square,
              is_write_thread);

    arch::global_store<ElementMean, sizeof(ElementMean)>(
              convert_mean_output(accum_sum_element_),
              (void *)curr_ptr_element_sum,
              is_write_thread);
  }
|
||||
|
||||
/// Called after all accumulator elements have been visited for this step:
/// commits the staged output fragment and advances the output tile iterator.
CUTLASS_DEVICE
void end_step(int step_idx) {

  iterator_D_.store(fragment_D_);
  ++iterator_D_;
}
|
||||
|
||||
/// Called after all steps have been completed. Intentionally empty: the
/// per-row reduction results are already written out in end_row().
CUTLASS_DEVICE
void end_epilogue() {

}
|
||||
|
||||
private:
|
||||
|
||||
/// Conditionally loads one shift-K element for the given row from global
/// memory and converts it to the layernorm compute type. When `is_load` is
/// false no memory traffic is issued and the returned value is unspecified
/// (matching the predicated-load contract).
CUTLASS_DEVICE
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
  using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;

  // Address of this row's shift-K element.
  ElementOutput *gmem_shift_k_ptr = params_.ptr_Shifted_K + row_offset;

  // Predicated global load.
  ElementOutput raw_shift_k;
  arch::global_load<ElementOutput, sizeof(ElementOutput)>(raw_shift_k, (void *)gmem_shift_k_ptr, is_load);

  // Convert to the compute type before returning.
  ConvertShiftK convert_shift_k;
  return convert_shift_k(raw_shift_k);
}
|
||||
|
||||
/// Returns the sum of squares of all elements in the fragment.
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
  ElementLayernormCompute total = ElementLayernormCompute(0);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < LayernormFragment::kElements; ++idx) {
    auto value = accum[idx];
    total += value * value;
  }

  return total;
}
|
||||
|
||||
/// Returns the sum of squares of all fragment elements after subtracting the
/// shift-K value from each (shifted-data variance trick).
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
  ElementLayernormCompute total = ElementLayernormCompute(0);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < LayernormFragment::kElements; ++idx) {
    auto centered = accum[idx] - shift_k_val;
    total += centered * centered;
  }

  return total;
}
|
||||
|
||||
/// Returns the plain sum of all elements in the fragment.
CUTLASS_DEVICE
ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
  ElementLayernormCompute total = ElementLayernormCompute(0);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < LayernormFragment::kElements; ++idx) {
    total += accum[idx];
  }

  return total;
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,692 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmUniversalwithEpilogueVisitor {
|
||||
public:
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(
|
||||
128 / sizeof_bits<ElementA>::value,
|
||||
128 / sizeof_bits<ElementB>::value
|
||||
);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
void const * ptr_A;
|
||||
void const * ptr_B;
|
||||
void const * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename LayoutA::Stride stride_a;
|
||||
typename LayoutB::Stride stride_b;
|
||||
typename LayoutC::Stride stride_c;
|
||||
typename LayoutC::Stride stride_d;
|
||||
|
||||
typename LayoutA::Stride::LongIndex lda;
|
||||
typename LayoutB::Stride::LongIndex ldb;
|
||||
typename LayoutC::Stride::LongIndex ldc;
|
||||
typename LayoutC::Stride::LongIndex ldd;
|
||||
|
||||
int const * ptr_gather_A_indices;
|
||||
int const * ptr_gather_B_indices;
|
||||
int const * ptr_scatter_D_indices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr) {}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
typename LayoutA::Stride stride_a,
|
||||
typename LayoutB::Stride stride_b,
|
||||
typename LayoutC::Stride stride_c,
|
||||
typename LayoutC::Stride stride_d,
|
||||
int const *ptr_gather_A_indices = nullptr,
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
lda = 0;
|
||||
ldb = 0;
|
||||
ldc = 0;
|
||||
ldd = 0;
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
typename LayoutA::Stride::LongIndex lda,
|
||||
typename LayoutB::Stride::LongIndex ldb,
|
||||
typename LayoutC::Stride::LongIndex ldc,
|
||||
typename LayoutC::Stride::LongIndex ldd,
|
||||
int const *ptr_gather_A_indices = nullptr,
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
stride_a = make_Coord(lda);
|
||||
stride_b = make_Coord(ldb);
|
||||
stride_c = make_Coord(ldc);
|
||||
stride_d = make_Coord(ldd);
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// Returns arguments for the transposed problem
|
||||
Arguments transposed_problem() const {
|
||||
Arguments args(*this);
|
||||
|
||||
std::swap(args.problem_size.m(), args.problem_size.n());
|
||||
std::swap(args.ptr_A, args.ptr_B);
|
||||
std::swap(args.lda, args.ldb);
|
||||
std::swap(args.stride_a, args.stride_b);
|
||||
std::swap(args.batch_stride_A, args.batch_stride_B);
|
||||
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
|
||||
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
int * ptr_gather_A_indices;
|
||||
int * ptr_gather_B_indices;
|
||||
int * ptr_scatter_D_indices;
|
||||
|
||||
int *semaphore;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_D(0),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr),
|
||||
semaphore(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
int gemm_k_size,
|
||||
void *workspace = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
|
||||
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
|
||||
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
|
||||
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
|
||||
epilogue_visitor(args.epilogue_visitor),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(gemm_k_size),
|
||||
ptr_A(const_cast<void *>(args.ptr_A)),
|
||||
ptr_B(const_cast<void *>(args.ptr_B)),
|
||||
ptr_C(const_cast<void *>(args.ptr_C)),
|
||||
ptr_D(args.ptr_D),
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
|
||||
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
|
||||
semaphore(static_cast<int *>(workspace)) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr) {
|
||||
|
||||
ptr_A = const_cast<void *>(args.ptr_A);
|
||||
ptr_B = const_cast<void *>(args.ptr_B);
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
epilogue_visitor = args.epilogue_visitor;
|
||||
|
||||
semaphore = static_cast<int *>(workspace);
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure. A union lets the three stages alias the
/// same storage — assumes the mainloop, epilogue, and visitor scratch are
/// never live simultaneously.
union SharedStorage {
  typename Mma::SharedStorage main_loop;             // mainloop operand staging
  typename Epilogue::SharedStorage epilogue;         // epilogue tile staging
  typename EpilogueVisitor::SharedStorage visitor;   // visitor scratch
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor — the kernel object carries no state of its own.
CUTLASS_DEVICE
GemmUniversalwithEpilogueVisitor() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Determines whether the kernel can execute the given arguments; delegates
/// to the problem-size overload.
static Status can_implement(Arguments const &args) {
  return can_implement(args.problem_size);
}
|
||||
|
||||
/// Returns the number of extra workspace bytes this kernel requires beyond
/// the semaphore workspace — none for this kernel.
static size_t get_extra_workspace_size(Arguments const &args,
                                       cutlass::gemm::GemmCoord const &grid_tiled_shape) {

  return 0;
}
|
||||
|
||||
/// Executes one GEMM: runs the threadblock-scoped mainloop, then applies the
/// epilogue through the fused EpilogueVisitor. Serial split-K slices are
/// ordered via the semaphore in Params.
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

  // Compute threadblock location
  ThreadblockSwizzle threadblock_swizzle;

  cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

  // Early exit if CTA is out of range
  if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
    params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

    return;
  }

  int offset_k = 0;
  int problem_size_k = params.problem_size.k();

  ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
  ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

  //
  // Fetch pointers based on mode.
  //
  if (params.mode == GemmUniversalMode::kGemm ||
    params.mode == GemmUniversalMode::kGemmSplitKParallel) {

    // Split-K: this slice covers the k range [offset_k, problem_size_k).
    // All but the last slice cover exactly gemm_k_size elements; the last
    // slice keeps the full extent to absorb the remainder.
    if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {

      problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
    }

    offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
  }
  else if (params.mode == GemmUniversalMode::kBatched) {
    // Batched mode: the grid's k dimension indexes the batch.
    ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
    ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
  }
  else if (params.mode == GemmUniversalMode::kArray) {
    // Array mode: ptr_A / ptr_B point to arrays of pointers indexed by batch.
    ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
    ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
  }

  __syncthreads();

  // Compute initial location in logical coordinates
  cutlass::MatrixCoord tb_offset_A{
    threadblock_tile_offset.m() * Mma::Shape::kM,
    offset_k,
  };

  cutlass::MatrixCoord tb_offset_B{
    offset_k,
    threadblock_tile_offset.n() * Mma::Shape::kN
  };

  // Compute position within threadblock
  int thread_idx = threadIdx.x;

  // Construct iterators to A and B operands
  typename Mma::IteratorA iterator_A(
    params.params_A,
    ptr_A,
    {params.problem_size.m(), problem_size_k},
    thread_idx,
    tb_offset_A,
    params.ptr_gather_A_indices);

  typename Mma::IteratorB iterator_B(
    params.params_B,
    ptr_B,
    {problem_size_k, params.problem_size.n()},
    thread_idx,
    tb_offset_B,
    params.ptr_gather_B_indices);

  // Broadcast the warp_id computed by lane 0 to ensure dependent code
  // is compiled as warp-uniform.
  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

  int lane_idx = threadIdx.x % 32;

  //
  // Main loop
  //

  // Construct thread-scoped matrix multiply
  Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

  typename Mma::FragmentC accumulators;

  accumulators.clear();

  // Number of mainloop iterations for this k slice.
  int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

  // Compute threadblock-scoped matrix multiply-add
  mma(
    gemm_k_iterations,
    accumulators,
    iterator_A,
    iterator_B,
    accumulators);

  //
  // Epilogue
  //

  // EpilogueOutputOp output_op(params.output_op);

  //
  // Masked tile iterators constructed from members
  //

  threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

  //assume identity swizzle
  MatrixCoord threadblock_offset(
    threadblock_tile_offset.m() * Mma::Shape::kM,
    threadblock_tile_offset.n() * Mma::Shape::kN
  );

  // Flattened CTA index used to select this tile's semaphore slot.
  int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

  // NOTE(review): ptr_C / ptr_D are fetched but not used below — output
  // addressing appears to flow through the epilogue visitor's params instead;
  // confirm before removing.
  ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
  ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

  //
  // Fetch pointers based on mode.
  //

  // Construct the semaphore.
  Semaphore semaphore(params.semaphore + block_idx, thread_idx);

  // if (params.mode == GemmUniversalMode::kGemm) {

  //   // TODO: fix this order
  //   // If performing a reduction via split-K, fetch the initial synchronization
  //   if (params.grid_tiled_shape.k() > 1) {

  //     // Fetch the synchronization lock initially but do not block.
  //     semaphore.fetch();

  //     // Indicate which position in a serial reduction the output operator is currently updating
  //     output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
  //   }
  // }

  // Tile iterator loading from source tensor.

  EpilogueVisitor epilogue_visitor(
    params.epilogue_visitor,
    shared_storage.visitor,
    threadblock_offset,
    threadblock_tile_offset,
    thread_idx,
    params.problem_size.mn()
  );

  // if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
  //   ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
  // }
  if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
    epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
  }

  Epilogue epilogue(
    shared_storage.epilogue,
    thread_idx,
    warp_idx,
    lane_idx);

  // Wait on the semaphore - this latency may have been covered by iterator construction
  if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

    // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
    // TODO: ???
    // if (threadblock_tile_offset.k()) {
    //   iterator_C = iterator_D;
    // }

    // Block until the slice with index k == threadblock_tile_offset.k() may run.
    semaphore.wait(threadblock_tile_offset.k());
  }

  // Execute the epilogue operator to update the destination tensor.
  epilogue(epilogue_visitor, accumulators);

  //
  // Release the semaphore
  //

  if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

    int lock = 0;
    if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

      // The final threadblock resets the semaphore for subsequent grids.
      lock = 0;
    }
    else {
      // Otherwise, the semaphore is incremented
      lock = threadblock_tile_offset.k() + 1;
    }

    semaphore.release(lock);
  }
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -50,7 +50,13 @@ void bind_tensor_coord(py::module &m) {
|
||||
R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
|
||||
.def(py::init<int, int, int, int>(),
|
||||
py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
|
||||
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc");
|
||||
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
|
||||
.def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
|
||||
py::arg("dim"),
|
||||
R"pbdoc(Gets the index of a given Coord element)pbdoc")
|
||||
.def("size", [](const cutlass::Tensor4DCoord & coord) {
|
||||
return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);},
|
||||
R"pbdoc(The size of the tensor coord)pbdoc");
|
||||
|
||||
py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
|
||||
R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
|
||||
|
||||
Reference in New Issue
Block a user