ex42: Fused MHA imported from xFormers (#662)

* ex42: Fused MHA imported from xFormers

* Remove std:: references

* Support K>128 in the example

* Support causal option

* Support different head size for V, and different seqlength for KV

* Update FLOPS counter

* Remove bit_cast

* fix build: Replace M_LOG2E

* Add doc

* Revert "Remove bit_cast"

This reverts commit 9662fa86bb.

* Explicit casts to int32_t for windows build

Co-authored-by: danthe3rd <danthe3rd>
This commit is contained in:
dan_the_3rd
2022-10-17 16:49:33 +02:00
committed by GitHub
parent 3bf95e90c2
commit 4db6a6140e
21 changed files with 12415 additions and 0 deletions

View File

@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
42_fused_multi_head_attention
fused_multihead_attention.cu
)

View File

@ -0,0 +1,482 @@
#pragma once
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
#include "gemm_kernel_utils.h"
namespace {
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
} // namespace
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
We have multiple implementations, because each configuration has a different way
of iterating in the accumulators.
*/
template <typename BASE, typename T, typename accum_t, int kWarpSize>
struct RegisterOps {
template <
int kQueriesPerBlock,
bool kFullColumns,
bool kIsFirst,
bool kKeepOutputInRF>
CUTLASS_DEVICE static void update(
typename T::Fragment& frag_o, // output so far
typename T::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename T::TensorCoord const& tile_offset,
float scaling) {
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
BASE::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
BASE::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (BASE::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSm80
: RegisterOps<
AttentionScalingCoefsUpdaterSm80<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
static int const kElementsPerAccess = InstructionShape::kN / 4;
static int const kRowsPerTile = 8;
static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
return cutlass::MatrixCoord(
quad + tile_offset.row() * Shape::kRow,
lane_in_quad * kElementsPerAccess +
tile_offset.column() * Shape::kColumn);
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset,
FA beginRow,
FB op,
FC endRow) {
// See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < kAccumulatorRows; ++row) {
int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
row * kRowsPerTile + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
(mma_n * Policy::MmaIterations::kRow + mma_m);
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < kElementsPerAccess; ++col) {
int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
col + lane_offset.column();
int idx = mma_accum_start + row * kElementsPerAccess + col;
op(accum_m, accum_n, idx);
}
}
endRow(accum_m);
}
}
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
// In each warp, 4 threads will work on the same row
// - the ones with the same `quad`
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
myValue = fn(myValue, otherV);
otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
myValue = fn(myValue, otherV);
int lane_in_quad = (lane_id & 3);
return lane_in_quad == 0;
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterVolta
: RegisterOps<
AttentionScalingCoefsUpdaterVolta<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
using Element = accum_t;
static int const kElementsPerPartial = 4;
using EleShapePerPatial = typename cutlass::platform::conditional<
cutlass::platform::is_same<Element, float>::value,
cutlass::MatrixShape<2, 2>,
cutlass::MatrixShape<1, 4>>::type;
static int const kElementsPerMma = 8;
static int const kAccumulatorPatials = 2;
using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
int accum_m, accum_n;
if (cutlass::platform::is_same<Element, float>::value) {
// (quad[2],quad[0])+lane_in_quad[0]
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
// (quad[1])+lane_in_quad[1]
accum_n =
((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
(lane_in_quad & 2);
} else {
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
lane_in_quad; // (quad[2],quad[0])
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
}
return cutlass::MatrixCoord(
accum_m + tile_offset.row() * Shape::kRow,
accum_n + tile_offset.column() * Shape::kColumn);
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
static_assert(
cutlass::platform::is_same<Element, float>::value,
"update to support non-float accum");
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
// T0 & T2 share same line within a quad
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
myValue = fn(myValue, otherV);
// quad 0 and quad 2 are on the same lines
otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
myValue = fn(myValue, otherV);
return (lane_id & ((1 << 1) | (1 << 3))) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset,
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
int accum_m = tile_m * Policy::InterleavedTile::kRow +
mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn;
++tile_n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kAccumulatorPatials; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
int mma_accum_start =
(((tile_n * Policy::TileIterations::kRow + tile_m) *
Policy::MmaIterations::kColumn +
mma_n) *
Policy::MmaIterations::kRow +
mma_m) *
kElementsPerMma;
int accum_n = tile_n * Policy::InterleavedTile::kColumn +
mma_n * QuadShapePerPatialMma::kColumn +
p * Policy::InterleavedTile::kColumn / 2 + n +
lane_offset.column();
int idx = mma_accum_start + p * kElementsPerPartial +
m * EleShapePerPatial::kColumn + n;
op(accum_m, accum_n, idx);
}
}
}
}
endRow(accum_m);
}
}
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSimt
: RegisterOps<
AttentionScalingCoefsUpdaterSimt<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
using Delta = typename T::Delta;
using Shape = typename T::Shape;
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
CUTLASS_PRAGMA_UNROLL
for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
myValue = fn(myValue, otherV);
}
return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset,
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
int accum_n =
mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
lane_offset.column();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
int idx = n +
Policy::LaneMmaShape::kN *
(mma_n +
Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
op(accum_m, accum_n + n, idx);
}
}
endRow(accum_m);
}
}
}
static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
static_assert(
cutlass::platform::is_same<
typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value,
"");
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
cutlass::MatrixCoord(Policy::LaneMmaShape::kM,
Policy::LaneMmaShape::kN);
return lane_offset +
tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
}
};
template <typename T, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater;
// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>,
accum_t,
kWarpSize> {
using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>;
using Updater =
AttentionScalingCoefsUpdaterSimt<Iterator, accum_t, kWarpSize>;
};
// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
using Iterator =
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
using Updater =
AttentionScalingCoefsUpdaterVolta<Iterator, accum_t, kWarpSize>;
};
// TensorOp - Sm75+
template <
typename S1,
typename S2,
typename S3,
typename accum_t,
int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>,
accum_t,
kWarpSize> {
using Iterator =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>;
using Updater =
AttentionScalingCoefsUpdaterSm80<Iterator, accum_t, kWarpSize>;
};

View File

@ -0,0 +1,128 @@
#pragma once
#include <float.h>
#include <stdio.h>
#include <cmath>
////////////////////////////////////////////////////////////////////////////////
// Debugging functions
////////////////////////////////////////////////////////////////////////////////
// Nans & inf detection
#define NANCHECK(frag) \
{ \
for (int _i = 0; _i < frag.size(); ++_i) { \
assert(std::isfinite(float(frag[_i]))); \
assert(!std::isnan(float(frag[_i]))); \
} \
}
// Print on the first thread of the first block
#if 0
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
#define PRINT_T0_L0(msg, ...) \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", __VA_ARGS__); \
}
struct __string_view {
char const* data;
std::size_t size;
};
template <class T>
constexpr __string_view __get_type_name() {
char const* p = __PRETTY_FUNCTION__;
while (*p++ != '=')
;
for (; *p == ' '; ++p)
;
char const* p2 = p;
int count = 1;
for (;; ++p2) {
switch (*p2) {
case '[':
++count;
break;
case ']':
--count;
if (!count)
return {p, std::size_t(p2 - p)};
}
}
return {};
}
#else
#define PRINT_T0_L0
#endif
// Print a given array
#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \
PRINT_T0_L0( \
"%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
name, \
int(start), \
int(start + 8), \
float(accum[start + 0]), \
float(accum[start + 1]), \
float(accum[start + 2]), \
float(accum[start + 3]), \
float(accum[start + 4]), \
float(accum[start + 5]), \
float(accum[start + 6]), \
float(accum[start + 7]));
#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0)
#define PRINT_FRAG_T0_L0(name, frag) \
{ \
auto typeStr = __get_type_name<decltype(frag)>(); \
PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \
for (int _start = 0; _start < frag.size(); _start += 8) { \
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
} \
/*__syncthreads(); \
NANCHECK(frag); */ \
}
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \
{ \
PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \
for (int _start = 0; _start < length; _start += incr) { \
PRINT_ACCUM8_T0_L0_START(" ", array, _start); \
} \
}
#define PRINT_ARRAY_T0_L0(name, array, length) \
PRINT_ARRAY_T0_L0_INCR(name, array, length, 8)
// Print a 4x4 matrix
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \
PRINT_T0_L0( \
"%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \
name, \
int(start_x), \
int(start_x + 4), \
int(start_y), \
int(start_y + 4), \
float(ref.at({start_x + 0, start_y + 0})), \
float(ref.at({start_x + 0, start_y + 1})), \
float(ref.at({start_x + 0, start_y + 2})), \
float(ref.at({start_x + 0, start_y + 3})), \
float(ref.at({start_x + 1, start_y + 0})), \
float(ref.at({start_x + 1, start_y + 1})), \
float(ref.at({start_x + 1, start_y + 2})), \
float(ref.at({start_x + 1, start_y + 3})), \
float(ref.at({start_x + 2, start_y + 0})), \
float(ref.at({start_x + 2, start_y + 1})), \
float(ref.at({start_x + 2, start_y + 2})), \
float(ref.at({start_x + 2, start_y + 3})), \
float(ref.at({start_x + 3, start_y + 0})), \
float(ref.at({start_x + 3, start_y + 1})), \
float(ref.at({start_x + 3, start_y + 2})), \
float(ref.at({start_x + 3, start_y + 3})));
#define PRINT_TENSOR4x4_T0_L0(name, ref) \
PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
#define PRINT_PROBLEM_SIZE(name, ps) \
PRINT_T0_L0( \
"%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
name, \
int(ps.m()), \
int(ps.n()), \
int(ps.k()))

View File

@ -0,0 +1,632 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
File copied from "cutlass/epilogue/threadblock/epilogue.h"
then modified to:
(1) load 2 source fragments at the same time (pipelining)
(2) support reading from a different dtype
(3) pass the row id to the OutputOp if it takes it
(see MemoryEfficientAttentionNormalize)
Note that in general the fragment passed to the OutputOp could
span multiple rows but it does not happen with the configurations we have
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename Op>
struct ApplyEpilogueOp {
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentOutput const& source) {
return output_op(accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(accum);
}
};
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
///< gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting
///< accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
///< accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading
///< from SMEM
typename OutputOp_, ///< Output operator
typename Padding_, ///< Padding added to SMEM allocation to avoid bank
///< conflicts (concept: MatrixShape)
int FragmentsPerPartition =
1, ///< Used to coarsten the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is
///< large
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
typename OutputTileSourceIterator_ =
OutputTileIterator_ ///< Tile iterator reading tensors
>
class EpiloguePipelined : public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using OutputTileSourceIterator = OutputTileSourceIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
using ElementSource = typename OutputTileSourceIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef =
typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element,
OutputTileIterator::kElementsPerAccess>;
using SourceAccessType = Array<
typename OutputTileSourceIterator::Element,
OutputTileSourceIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<
typename WarpTileIterator::Element,
OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
? Base::kFragmentsPerIteration
: kPartitionsK;
static int constexpr kSmemPointerOffset =
Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(
OutputTileSourceIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between input tile and output tile iterator (kElements)");
static_assert(
OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
"Mismatch between input tile and output tile iterator (kIterations)");
static_assert(
SharedLoadIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(
OutputTileIterator::kElementsPerAccess,
"OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(
!(OutputTileIterator::Fragment::kElements %
OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
public:
/// Constructor
CUTLASS_DEVICE
EpiloguePipelined(
typename Base::SharedStorage& shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx) {}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator
source_iterator) { ///< Threadblock tile coordinate in GEMM (in units
///< of threadblock tiles)
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
} else {
compute_source_needed_(
output_op, destination_iterator, accumulators, source_iterator);
}
}
CUTLASS_DEVICE
void operator()(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators) { ///< Complete warp-level accumulator tile
compute_source_not_needed_(output_op, destination_iterator, accumulators);
}
private:
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(
AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(
kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(
size_t pos,
AccumulatorFragmentIterator const& iterator_begin,
WarpTileIterator& warp_tile_iterator) {
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(
iterator_begin, warp_tile_iterator),
0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
static_assert(
kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
"One of these must be exactly 1.");
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators ///< Complete warp-level accumulator tile
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll( \
IterationsUnroll \
? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
: 1)
for (int iter = 0; iter < OutputTileIterator::kIterations;
iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<cutlass::make_index_sequence<
OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment
aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
} else if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(
aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset(
(1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(
destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(
kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
}
template <class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(
AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(
size_t pos,
AccumulatorFragmentIterator const& iterator_begin,
WarpTileIterator& warp_tile_iterator) {
int dummy[] = {
(pos == Seq) &&
(helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator
source_iterator ///< Threadblock tile coordinate in GEMM (in units of
///< threadblock tiles)
) {
typename OutputTileSourceIterator::Fragment source_fragment[2];
source_fragment[0].clear();
source_iterator.load(source_fragment[0]);
++source_iterator;
source_fragment[1].clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
if (iter > 0) {
__syncthreads();
}
//
// Load the source for next iteration (pipelining)
//
if (iter + 1 < OutputTileIterator::kIterations) {
source_iterator.load(source_fragment[(iter + 1) % 2]);
}
++source_iterator;
acc2smem_source_needed<
cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment
aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the
// k-slices
if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(
aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset(
(1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_(
destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0],
source_fragment[iter % 2]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment,
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
typename OutputTileSourceIterator::Fragment const& source_fragment) {
OutputAccessType* output_frag_ptr =
reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
SourceAccessType const* source_frag_ptr =
reinterpret_cast<SourceAccessType const*>(&source_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i],
source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment,
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment) {
OutputAccessType* output_frag_ptr =
reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i]);
}
}
// This should be constexpr, but it's only supported on c++14
static int CUTLASS_HOST_DEVICE getRowOffset(int i) {
using ThreadMap = typename OutputTileIterator::ThreadMap;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
int frag_row_idx =
(row +
ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
int frag_idx = ThreadMap::kElementsPerAccess *
(frag_row_idx * ThreadMap::Iterations::kColumn + column);
if (i < frag_idx + ThreadMap::kElementsPerAccess) {
return row_offset;
}
}
}
}
}
return -1;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,231 @@
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
handle "row_id" as a first argument, as uses it to get the corresponding
`m_prime` / `s_prime` to rescale the output.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "epilogue_pipelined.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
// with:
// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
// beta = alpha / m_prime (renormalize the output when the max changes)
// source is the current output
template <
typename ElementOutput_, ///< Data type used to store tensors
typename ElementSource_, //< Data type for source (usually matches
//`ElementOutput`)
int Count, ///< Number of elements computed per operation.
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not enough data
///< to store
typename ElementAccumulator_, ///< Accumulator data type
typename ElementCompute_, ///< Data type used to compute linear combination
bool isFirst,
bool isLast,
typename FragmentAlphaBeta_,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
public:
using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using FragmentAlphaBeta = FragmentAlphaBeta_;
static FloatRoundStyle const kRound = Round;
private:
//
// Data members
//
FragmentAlphaBeta const& s_prime_;
FragmentAlphaBeta const& m_prime_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
MemoryEfficientAttentionNormalize(
FragmentAlphaBeta const& s_prime,
FragmentAlphaBeta const& m_prime)
: s_prime_(s_prime), m_prime_(m_prime) {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return !isFirst;
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
int row,
FragmentAccumulator const& accumulator,
FragmentSource const& source) const {
assert(!isFirst);
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
ElementCompute beta = alpha * m_prime_[row];
intermediate = mul_add_source(beta, converted_source); // X = beta * C
intermediate = mul_add_accumulator(
alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row, FragmentAccumulator const& accumulator)
const {
assert(isFirst);
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
intermediate = mul_accumulator(
alpha, converted_accumulator); // X = alpha * C + uniform
return destination_converter(intermediate);
}
};
} // namespace thread
namespace threadblock {
template <
typename EO,
typename ES,
int Count,
typename EA,
typename EC,
bool F,
bool L,
typename FAB,
FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<
EO,
ES,
Count,
EA,
EC,
F,
L,
FAB,
R>> {
using Op = thread::
MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentSource const& source) {
return output_op(row_id, accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(row_id, accum);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,175 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include <cuda_fp16.h>
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
CUTLASS_HOST_DEVICE
Array<Element, ElementsPerAccess> operator()(
Array<Element, ElementsPerAccess> const& input) const {
Array<Element, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = expf(input[i]);
}
return result;
}
};
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
CUTLASS_DEVICE
Array<half_t, ElementsPerAccess> operator()(
Array<half_t, ElementsPerAccess> const& input) const {
Array<half_t, ElementsPerAccess> result;
int const kVectorCount = ElementsPerAccess / 2;
__half2 const* input_ptr =
reinterpret_cast<__half2 const*>(input.raw_data());
__half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kVectorCount; ++i) {
res_ptr[i] = h2exp(input_ptr[i]);
}
return result;
}
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies:
/// output <- (input - lse).exp()
template <
typename ElementOutput_, // output
typename ElementLSE_, // accumulator from LSE
typename ElementAccumulator_, // accumulator from matmul
typename ElementCompute_, // intermediate compute (and exp calculation)
int ElementsPerAccess>
class ApplyLogSumExp {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
public:
//
// Methods
//
CUTLASS_HOST_DEVICE
ApplyLogSumExp() {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return true;
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const& AB,
FragmentLSE const& scale_unused,
// bias used as LSE
FragmentLSE const& bias) const {
FragmentCompute frag_AB = NumericArrayConverter<
ElementCompute,
ElementAccumulator,
kElementsPerAccess>()(AB);
FragmentCompute frag_lse_compute =
NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(
bias);
FragmentCompute frag_compute;
minus<FragmentCompute> minus_lse;
detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
frag_compute = minus_lse(frag_AB, frag_lse_compute);
frag_compute = apply_exp(frag_compute);
return NumericArrayConverter<
ElementOutput,
ElementCompute,
kElementsPerAccess>()(frag_compute);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,158 @@
/*! \file
\brief Cutlass provides helper template functions to figure out the right
datastructures to instanciate to run a GEMM with various parameters (see
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
instanciation priority rules, it will only create an MmaMultiStage with
kStages=3 (otherwise creates an MmePipelined - which is not compatible with
FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
so we just copy-pasted some code from `default_mma.h` and
`default_mma_core.h` files and wrapped this template to allow our usecase.
This is really only for the FastF32 case - aka using TensorCores with fp32.
*/
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Layout type for C and D matrix operand
typename LayoutC,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation perfomed by GEMM
typename Operator,
typename Enable_ = void>
struct FindDefaultMma {
static constexpr bool AccumulatorsInRowMajor = false;
static constexpr SharedMemoryClearOption SharedMemoryClear =
SharedMemoryClearOption::kNone;
using DefaultMma = cutlass::gemm::threadblock::DefaultMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Stages,
Operator,
AccumulatorsInRowMajor,
SharedMemoryClear>;
};
/// Specialization for sm80 / FastF32 / multistage with kStages=2
template <
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
int kStages,
typename Operator>
struct FindDefaultMma<
ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
using LayoutC = layout::RowMajor;
using OperatorClass = arch::OpClassTensorOp;
using ArchTag = arch::Sm80;
using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<
ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
3,
Operator>;
struct DefaultMma : DefaultMma_ {
using MmaCore_ = typename DefaultMma_::MmaCore;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<
typename MmaCore_::Shape,
typename DefaultMma_::IteratorA,
typename MmaCore_::SmemIteratorA,
MmaCore_::kCacheOpA,
typename DefaultMma_::IteratorB,
typename MmaCore_::SmemIteratorB,
MmaCore_::kCacheOpB,
ElementAccumulator,
LayoutC,
typename MmaCore_::MmaPolicy,
kStages>;
};
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,93 @@
#pragma once
#include "custom_mma_multistage.h"
#include "custom_mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
template <typename Mma, int kMaxK>
struct MakeCustomMma;
template <
typename Shape,
typename IteratorA,
typename SmemIteratorA,
cutlass::arch::CacheOperation::Kind CacheOpA,
typename IteratorB,
typename SmemIteratorB,
cutlass::arch::CacheOperation::Kind CacheOpB,
typename ElementC,
typename LayoutC,
typename Policy,
int Stages,
cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
int kMaxK>
struct MakeCustomMma<
cutlass::gemm::threadblock::MmaMultistage<
Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
Stages,
SharedMemoryClear>,
kMaxK> {
// Reduce the number of stages if we don't need that many
static int constexpr kStages =
kMaxK == cutlass::platform::numeric_limits<int>::max()
? Stages
: cutlass::const_min(
Stages,
(kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
kStages,
SharedMemoryClear,
kMaxK>;
};
template <
typename Shape,
typename IteratorA,
typename SmemIteratorA,
typename IteratorB,
typename SmemIteratorB,
typename ElementC,
typename LayoutC,
typename Policy,
int kMaxK>
struct MakeCustomMma<
cutlass::gemm::threadblock::MmaPipelined<
Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>,
kMaxK> {
using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<
Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>;
};

View File

@ -0,0 +1,183 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<
Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
template <typename Element, typename OperandShape, typename OperandLayout>
struct OperandSharedStorage {
AlignedBuffer<Element, OperandShape::kCount> buffer;
using TensorRef = TensorRef<Element, OperandLayout>;
CUTLASS_DEVICE
static OperandLayout Layout() {
return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
}
/// Returns a TensorRef to the operand
CUTLASS_HOST_DEVICE
TensorRef ref() {
return TensorRef{buffer.data(), Layout()};
}
};
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<
Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<
Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
using SharedStorageA = OperandSharedStorage<
typename Operator::ElementA,
ShapeA,
typename Operator::LayoutA>;
using SharedStorageB = OperandSharedStorage<
typename Operator::ElementB,
ShapeB,
typename Operator::LayoutB>;
using TensorRefA = typename SharedStorageA::TensorRef;
using TensorRefB = typename SharedStorageB::TensorRef;
struct SharedStorage {
/// Buffer for A operand
SharedStorageA operand_A;
/// Buffer for B operand
SharedStorageB operand_B;
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorageA& shared_storageA,
SharedStorageB& shared_storageB,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,767 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "custom_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Upper boundon the K dimension
int kMaxK = cutlass::platform::numeric_limits<int>::max(),
/// Used for partial specialization
typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(
Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load on group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
/// Number of cp.async instructions to load on group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
};
static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
static constexpr int kNumStagesConcurrentLoad =
kSmemContainsEntireMat ? Stages : Stages - 1;
private:
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
bool prologue_done_;
// Set to `True` to ensure the accumulator will be zero outside the GEMM
// footprint
bool zero_outside_bounds_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx),
prologue_done_(false),
zero_outside_bounds_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaMultistage(
st.operand_A,
st.operand_B,
thread_idx,
warp_idx,
lane_idx) {}
CUTLASS_DEVICE
bool set_prologue_done(bool value) {
prologue_done_ = value;
}
CUTLASS_DEVICE
bool set_zero_outside_bounds(bool value) {
zero_outside_bounds_ = value;
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorage& shared_storage,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
prologue<kLoadA, kLoadB>(
shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
_prologue<kLoadA, kLoadB>(
iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
}
CUTLASS_DEVICE
void copy_tiles_and_advance(
IteratorA& iterator_A,
IteratorB& iterator_B,
int group_start_A = 0,
int group_start_B = 0) {
iterator_A.set_iteration_index(
group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (zero_outside_bounds_ ||
SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(
group_start_B * IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (zero_outside_bounds_ ||
SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void _prologue(
IteratorA& iterator_A,
IteratorB& iterator_B,
int32_t& gemm_k_iterations,
SmemIteratorA& smem_iterator_A_,
SmemIteratorB& smem_iterator_B_) {
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad;
++stage, --gemm_k_iterations) {
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
if (kLoadA) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
}
++iterator_A;
}
++smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
if (kLoadB) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
}
++iterator_B;
}
++smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
smem_iterator_A_.add_tile_offset({0, 1});
smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< initial value of accumulator
FragmentC const& src_accum) {
//
// Prologue
//
if (!prologue_done_) {
_prologue<true, true>(
iterator_A,
iterator_B,
gemm_k_iterations,
smem_iterator_A_,
smem_iterator_B_);
} else if (!kSmemContainsEntireMat) {
_prologue<false, false>(
iterator_A,
iterator_B,
gemm_k_iterations,
smem_iterator_A_,
smem_iterator_B_);
} else {
gemm_k_iterations -= kNumStagesConcurrentLoad;
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint
// are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared
/// memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared
/// memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA warp_loaded_frag_A[2];
WarpLoadedFragmentB warp_loaded_frag_B[2];
WarpTransformedFragmentA warp_transformed_frag_A[2];
WarpTransformedFragmentB warp_transformed_frag_B[2];
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma.transform(
warp_transformed_frag_A[0],
warp_transformed_frag_B[0],
warp_loaded_frag_A[0],
warp_loaded_frag_B[0]);
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
// accumulator and this temporary accumulator is added to the final
// accumulator once in every mainloop iteration.
plus<FragmentC> plus_accum;
FragmentC tmp_accum;
if (platform::is_same<
typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<
typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
tmp_accum.clear();
}
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);
// In case of a non-circular buffer ("kSmemContainsEntireMat")
// make sure we don't load out of bounds data.
if (!kSmemContainsEntireMat ||
gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
warp_mma_k < Base::kWarpGemmIterations - 1) {
this->warp_tile_iterator_A_.load(
warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k > 0)
warp_mma.transform(
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B[warp_mma_k % 2]);
if (platform::is_same<
typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<
typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
warp_mma(
tmp_accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
tmp_accum);
if (warp_mma_k == 0) {
accum = plus_accum(accum, tmp_accum);
tmp_accum.clear();
}
} else {
warp_mma(
accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
accum);
}
// Issue global->shared copies for the this stage
if (!kSmemContainsEntireMat &&
warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(
iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
if (!kSmemContainsEntireMat) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(
iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
}
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (!kSmemContainsEntireMat &&
smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
warp_mma.transform(
warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
}
if (platform::is_same<
typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<
typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated LDGSTS pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,401 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "custom_mma_base.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_ = NumericArrayConverter<
typename SmemIteratorA_::Element,
typename IteratorA_::Element,
IteratorA_::Fragment::kElements>,
///
/// Transformation applied to B operand
typename TransformB_ = NumericArrayConverter<
typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, 2>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert(
(Base::kStages == 2),
"MmaPipelined requires kStages set to value 2");
static bool const kSmemContainsEntireMat = false;
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaPipelined(
typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaPipelined(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaPipelined(
st.operand_A,
st.operand_B,
thread_idx,
warp_idx,
lane_idx) {}
CUTLASS_DEVICE
bool set_prologue_done(bool value) {
// NOT IMPLEMENTED FOR PIPELINED
}
CUTLASS_DEVICE
bool set_zero_outside_bounds(bool value) {
// NOT NEEDED FOR PIPELINED
// shared memory will always be zero-filled
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorage& shared_storage,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
prologue<kLoadA, kLoadB>(
shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
// NOT IMPLEMENTED FOR PIPELINED
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const& src_accum, ///< source accumulator tile
TransformA transform_A =
TransformA(), ///< transformation applied to A fragment
TransformB transform_B =
TransformB()) { ///< transformation applied to B fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tighest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
warp_mma(
accum,
warp_frag_A[warp_mma_k % 2],
warp_frag_B[warp_mma_k % 2],
accum);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,264 @@
#pragma once
#include "cutlass/arch/mma.h"
////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
}
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
{ \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
F(); \
} else { \
constexpr bool BOOL_NAME = false; \
F(); \
} \
}
#define DISPATCH_ARCHTAG(CC, func) \
{ \
if (CC >= 80) { \
using ArchTag = cutlass::arch::Sm80; \
func(); \
} else if (CC >= 75) { \
using ArchTag = cutlass::arch::Sm75; \
func(); \
} else if (CC >= 70) { \
using ArchTag = cutlass::arch::Sm70; \
func(); \
} else if (CC >= 50) { \
using ArchTag = cutlass::arch::Sm50; \
func(); \
} else { \
TORCH_CHECK( \
false, \
"Your device is too old. We require compute capability >= 50"); \
} \
}
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK(TENSOR.is_contiguous());
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK( \
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
#ifdef HAS_PYTORCH
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define XFORMERS_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
return false; \
}
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
return false; \
}
#else
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << #COND " failed\n"; \
return false; \
}
#endif
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK( \
B < cutlass::platform::numeric_limits<decltype(A)>::max(), \
#B " overflows"); \
}
namespace gemm_kernel_utils {
#ifdef HAS_PYTORCH
template <typename scalar_t>
struct TypeTraits;
template <>
struct TypeTraits<cutlass::half_t> {
using scalar_t = cutlass::half_t;
static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Half;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};
template <>
struct TypeTraits<cutlass::bfloat16_t> {
using scalar_t = cutlass::bfloat16_t;
static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::BFloat16;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};
template <>
struct TypeTraits<float> {
using scalar_t = float;
static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return tensor.packed_accessor32<scalar_t, nDim>();
}
};
#endif
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}
////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////
// Fallback to Simt (FMA on cuda cores) if not in a special case below
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
static constexpr int ThreadK = 8;
static constexpr int WarpK = 8;
static constexpr int kMinimumAlignment = 1;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using OpClass = cutlass::arch::OpClassSimt;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<
ArchTag,
float,
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 80>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
ArchTag,
scalar_t,
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 75 &&
cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 2;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Enables to do
// `auto x = kCondition ? fa(arg) : fb(arg)`
// when `fa` and `fb` have different types
template <bool kVal, typename TA, typename TB>
struct call_conditional;
template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(ta(arg)) {
return ta(arg);
}
};
template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(tb(arg)) {
return tb(arg);
}
};
////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////
CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
struct {
union {
T* ptr;
uint32_t asInt[2];
};
} p;
p.ptr = ptr;
p.asInt[0] = warp_uniform(p.asInt[0]);
p.asInt[1] = warp_uniform(p.asInt[1]);
return p.ptr;
}
} // namespace gemm_kernel_utils

View File

@ -0,0 +1,752 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue iterator that supports prefetching
Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load and store output tile from global memory in
/// epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator |
/// ForwardTileIterator
///
template <
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert(
ThreadMap::Iterations::kRow > 0,
"ThreadMap::Iterations::kRow must be > 0");
static_assert(
ThreadMap::Iterations::kGroup > 0,
"ThreadMap::Iterations::kGroup must be > 0");
static_assert(
ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
static_assert(
ThreadMap::Iterations::kColumn > 0,
"ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment = Array<
Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorParams {
using Base = PredicatedTileIteratorParams;
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: PredicatedTileIteratorParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>()) {}
CUTLASS_HOST_DEVICE
Params(Base const& base) : Base(base) {}
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() {
enable();
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = false;
}
}
///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask
CUTLASS_DEVICE void enable() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = true;
}
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer
uint8_t* byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in rows
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have
/// been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Scatter indices
int const* indices_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(
sizeof(PredicatedTileIteratorParams::stride) == 8,
"Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorPrefetch(
PredicatedTileIteratorParams const& params,
Element* pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const* indices = nullptr)
: params_(params), indices_(indices) {
TensorCoord thread_offset =
ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] =
((thread_offset.column() + ThreadMap::Delta::kColumn * c) <
extent.column());
}
// Null pointer performs no accesses
if (!pointer) {
mask_.clear();
}
if (ScatterD && !indices) {
mask_.clear();
}
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) /
kElementsPerAccess;
if (ScatterD) {
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.column()) * sizeof(AccessType) /
kElementsPerAccess;
}
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_DEVICE
void prefetch_all() {
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kIterations; ++iter) {
prefetch();
++(*this);
}
}
CUTLASS_DEVICE
void prefetch() {
uint8_t* byte_pointer = byte_pointer_;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
// on windows using unsigned long here gives the error
// error: asm operand type size(4) does not match
// type/size implied by constraint 'l'
uint64_t addr = (uint64_t)(
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess]);
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row +
ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr
[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) const {
load_with_byte_offset(frag, 0);
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row +
ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
if (UseCUDAStore) {
if (guard) {
memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
frag_ptr
[frag_row_idx * ThreadMap::Iterations::kColumn +
column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr
[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) const {
store_with_byte_offset(frag, 0);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void downsample_load_with_byte_offset(
Fragment& frag,
int64_t byte_offset,
int convolution_P,
int convolution_Q,
int add_P,
int add_Q,
int problem_N) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row +
ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
int output_row = row_offset + thread_start_row_;
int output_N = output_row / (convolution_P * convolution_Q);
int output_PQ = output_row % (convolution_P * convolution_Q);
int output_P = output_PQ / convolution_Q;
int output_Q = output_PQ % convolution_Q;
int input_row = output_N * 2 * convolution_P * 2 * convolution_Q +
(2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q;
int64_t byte_offset =
(input_row - output_row) * problem_N * sizeof(float);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr
[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void upsample_load_with_byte_offset(
Fragment& frag,
int64_t byte_offset,
int convolution_P,
int convolution_Q,
int add_P,
int add_Q,
int problem_N) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row +
ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
int output_row = row_offset + thread_start_row_;
int output_N = output_row / (convolution_P * convolution_Q);
int output_PQ = output_row % (convolution_P * convolution_Q);
int output_P = output_PQ / convolution_Q;
int output_Q = output_PQ % convolution_Q;
int row_add_P = add_P;
int row_add_Q = add_Q;
if (output_P > convolution_P - 2)
row_add_P = 0;
if (output_Q > convolution_Q - 2)
row_add_Q = 0;
int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) +
((output_P + row_add_P) / 2) * (convolution_Q / 2) +
(output_Q + row_add_Q) / 2;
int64_t byte_offset =
(input_row - output_row) * problem_N * sizeof(float);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr
[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
CUTLASS_DEVICE
MatrixCoord thread_start() const {
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const {
return thread_start_row_;
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const {
return thread_start_column_;
}
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const {
return extent_row_;
}
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const {
return extent_column_;
}
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorPrefetch& operator++() {
++state_[0];
if (!ScatterD) {
byte_pointer_ += params_.advance_row;
}
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
byte_pointer_ += params_.advance_group;
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow *
ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
}
}
}
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() {
mask_.enable();
}
///< Sets the mask
CUTLASS_DEVICE void get_mask(Mask& mask) const {
mask = mask_;
}
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const& mask) {
mask_ = mask;
}
};
template <typename IT>
struct MakePrefetchableIterator {
using Iterator = PredicatedTileIteratorPrefetch<
typename IT::ThreadMap,
typename IT::Element>;
};
///////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,66 @@
#pragma once
#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"
namespace cutlass {
namespace transform {
namespace threadblock {
template <typename BaseIterator>
struct MakeIteratorResidualLast;
template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>> {
using Iterator = PredicatedTileIteratorResidualLast<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>;
};
template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>> {
using Iterator = PredicatedTileAccessIteratorResidualLast<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>;
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass

View File

@ -0,0 +1,916 @@
#ifdef HAS_PYTORCH
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#endif
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "attention_scaling_coefs_updater.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "debug_utils.h"
#include "epilogue_pipelined.h"
#include "epilogue_rescale_output.h"
#include "find_default_mma.h"
#include "gemm_kernel_utils.h"
#include "mma_from_smem.h"
#include <inttypes.h>
using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
return (
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
? 16
: 12);
}
} // namespace
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
>
struct AttentionKernel {
using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
static constexpr bool kIsAligned = isAligned_;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
static constexpr int kNumWarpsPerBlock =
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
static constexpr int kWarpSize = 32;
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
int32_t* cu_seqlens_q_ptr = nullptr;
int32_t* cu_seqlens_k_ptr = nullptr;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
bool causal;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t o_strideH;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t o_strideB;
int32_t num_batches;
int32_t num_heads;
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value;
}
// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
int64_t q_start, k_start;
// Advance to current batch - in case of different sequence lengths
if (cu_seqlens_q_ptr != nullptr) {
assert(cu_seqlens_k_ptr != nullptr);
cu_seqlens_q_ptr += batch_id;
cu_seqlens_k_ptr += batch_id;
q_start = cu_seqlens_q_ptr[0];
k_start = cu_seqlens_k_ptr[0];
int64_t q_next_start = cu_seqlens_q_ptr[1];
int64_t k_next_start = cu_seqlens_k_ptr[1];
num_queries = q_next_start - q_start;
num_keys = k_next_start - k_start;
if (query_start >= num_queries) {
return false;
}
} else {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += batch_id * o_strideB;
if (output_accum_ptr != nullptr) {
output_accum_ptr += batch_id * o_strideB;
}
q_start = 0;
k_start = 0;
}
// Advance to the current batch / head / query_start
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
}
if (logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
logsumexp_ptr +=
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
}
num_queries -= query_start;
if (causal) {
num_keys = cutlass::fast_min(
int32_t(query_start + kQueriesPerBlock), num_keys);
}
num_batches = 0; // no longer used after
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
output_ptr = warp_uniform(output_ptr);
output_accum_ptr = warp_uniform(output_accum_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
return true;
}
__host__ dim3 getBlocksGrid() const {
return dim3(
ceil_div(num_queries, (int32_t)kQueriesPerBlock),
num_heads,
num_batches);
}
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
};
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
into a shared-memory ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1)
*/
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
scalar_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA =
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::ColumnMajor, // LayoutB,
kAlignmentB,
accum_t,
cutlass::layout::RowMajor, // LayoutC,
OpClass,
ArchTag, // ArchTag
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Updater;
static_assert(
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
MmaCore::WarpCount::kK ==
kNumWarpsPerBlock,
"");
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/**
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
output_accum_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using LayoutB = cutlass::layout::RowMajor;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
LayoutB, // LayoutB,
kAlignmentB,
output_accum_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage>;
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(
WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
"");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
static constexpr int64_t kAlignmentV = 1;
// Shared storage - depends on kernel params
struct ScalingCoefs {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return epilogue;
}
};
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return after_mm0.epilogue;
}
};
using SharedStorage = typename cutlass::platform::conditional<
kSingleValueIteration || kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
static bool __host__ check_supported(Params const& p) {
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
return true;
}
static void CUTLASS_DEVICE attention_kernel(Params& p) {
// In this block, we will only ever:
// - read query[query_start:query_end, :]
// - write to output[query_start:query_end, :]
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
p.output_ptr,
typename OutputTileIterator::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) ->
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n = cutlass::fast_min(
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
int32_t const& problem_size_0_k = p.head_dim;
int32_t const& problem_size_1_n = p.head_dim_value;
int32_t const& problem_size_1_k = problem_size_0_n;
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
iterator_V,
thread_id(),
problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
// Compute threadblock location
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
cutlass::MatrixCoord tb_offset_A{
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
cutlass::MatrixCoord tb_offset_B{
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(
typename MM0::MmaCore::LayoutA(p.q_strideM)),
p.query_ptr,
{problem_size_0_m, problem_size_0_k},
thread_id(),
tb_offset_A);
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(p.k_strideM)),
p.key_ptr + iter_key_start * p.k_strideM,
{problem_size_0_k, problem_size_0_n},
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
iteratorC_tile_offset = {
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
(my_warp_id % MM0::Mma::WarpCount::kM),
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// Mask out last if causal
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::ScalingCoefsUpdater::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = query_start + accum_m - iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<
kQueriesPerBlock,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
}));
}));
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords = cutlass::MatrixCoord{
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
__syncthreads();
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage_si`)
// `V` is read from global memory (with iterator_B)
//
const int64_t nBlockN = kSingleValueIteration
? 1
: ceil_div(
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
}
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
if (!kKeepOutputInRF) {
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys,
kIsLast,
([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp =
typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = call_conditional<
kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kSingleValueIteration) {
__syncthreads();
}
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
constexpr bool kIsFirst = true;
constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o);
}
// 7. Calculate logsumexp
// To make the backward easier, we pad logsumexp with `inf`
// this avoids a few bound checks, and is not more expensive during fwd
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}
static CUTLASS_DEVICE int8_t lane_id() {
return threadIdx.x;
}
static CUTLASS_DEVICE int8_t warp_id() {
return threadIdx.y;
}
static CUTLASS_DEVICE int16_t thread_id() {
return threadIdx.x + threadIdx.y * blockDim.x;
}
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl(typename AK::Params p) {
if (!p.advance_to_block()) {
return;
}
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);
#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \
template <> \
__global__ void __launch_bounds__( \
__VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \
attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
using Kernel = __VA_ARGS__;
#define _ATTENTION_KERNEL_FORWARD_END() }
#ifdef __CUDA_ARCH__
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
if (!p.advance_to_block()) { \
return; \
} \
Kernel::attention_kernel(p); \
_ATTENTION_KERNEL_FORWARD_END();
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
printf( \
"FATAL: this function is for sm%d, but was built for sm%d\n", \
int(ARCH), \
int(__CUDA_ARCH_OR_ZERO__)); \
_ATTENTION_KERNEL_FORWARD_END();
// All kernels are disabled by default
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)
// Enable the right one based on __CUDA_ARCH__
#ifndef __CUDA_ARCH__
#elif __CUDA_ARCH__ < 500
#error "Need cuda arch at least 5.0"
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#endif

File diff suppressed because it is too large Load Diff

View File

@ -120,6 +120,7 @@ foreach(EXAMPLE
38_syr2k_grouped
39_gemm_permute
41_multi_head_attention
42_fused_multi_head_attention
)
add_subdirectory(${EXAMPLE})

View File

@ -574,6 +574,21 @@ using std::is_trivially_copyable;
#endif
//-----------------------------------------------------------------------------
// bit_cast <bit>
//-----------------------------------------------------------------------------
template< class To, class From >
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept;
template <class To, class From>
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept
{
static_assert(sizeof(To) == sizeof(From), "sizes must match");
return reinterpret_cast<To const &>(src);
}
//-----------------------------------------------------------------------------
// Alignment and layout utilities
//-----------------------------------------------------------------------------
@ -865,5 +880,13 @@ struct numeric_limits<uint8_t> {
static constexpr bool is_integer = true;
};
template <>
struct numeric_limits<float> {
CUTLASS_HOST_DEVICE
static constexpr float infinity() noexcept { return bit_cast<float, int32_t>(0x7f800000);}
static constexpr bool is_integer = false;
static constexpr bool has_infinity = true;
};
} // namespace platform
} // namespace cutlass