ex42: Fused MHA imported from xFormers (#662)
* ex42: Fused MHA imported from xFormers
* Remove std:: references
* Support K>128 in the example
* Support causal option
* Support different head size for V, and different seqlength for KV
* Update FLOPS counter
* Remove bit_cast
* fix build: Replace M_LOG2E
* Add doc
* Revert "Remove bit_cast"
This reverts commit 9662fa86bb.
* Explicit casts to int32_t for windows build
Co-authored-by: danthe3rd <danthe3rd>
examples/42_fused_multi_head_attention/CMakeLists.txt (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  42_fused_multi_head_attention
  fused_multihead_attention.cu
)

@@ -0,0 +1,482 @@
#pragma once

#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
#include "gemm_kernel_utils.h"

namespace {

static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
  // source: https://stackoverflow.com/a/51549250
  return (value >= 0)
      ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
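
// Explanatory note (added commentary, not in the original diff): this works
// because non-negative IEEE-754 floats compare the same way as their bit
// patterns viewed as signed ints, so atomicMax on the int view is a float max;
// for negative values the raw-bit ordering is reversed, hence atomicMin on the
// unsigned view.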
} // namespace

/* Iterates over the accumulator and the corresponding position in the result
matrix.

(1) Update `mi[r]` to the max value of row `r`
(2) In a second iteration do the following:
    (a) accum   <- exp(accum - mi)
    (b) m_prime <- exp(m_prime - mi)
    (c) s_prime <- s_prime * m_prime + sum(accum)

All of this is done in registers, before everything is stored to shared
memory for the next matmul with Value.

We have multiple implementations, because each configuration has a different
way of iterating over the accumulators.
*/
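/* Worked example of the update (added commentary, not in the original diff),
   for one row processed in two key chunks, with scaling = 1:

     chunk 1: values {1, 3} -> mi = 3
       accum   = {exp(1-3), exp(3-3)},  s_prime = exp(-2) + 1
     chunk 2: values {5, 2} -> mi becomes 5
       m_prime = exp(3 - 5)                      (correction for the old max)
       s_prime <- s_prime * m_prime + exp(5-5) + exp(2-5)
       frag_o  <- frag_o * m_prime               (when kKeepOutputInRF)

   The final s_prime equals sum_j exp(x_j - max_j x_j) over {1, 3, 5, 2}: the
   standard online-softmax rescaling used by memory-efficient attention. */
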
template <typename BASE, typename T, typename accum_t, int kWarpSize>
struct RegisterOps {
  template <
      int kQueriesPerBlock,
      bool kFullColumns,
      bool kIsFirst,
      bool kKeepOutputInRF>
  CUTLASS_DEVICE static void update(
      typename T::Fragment& frag_o, // output so far
      typename T::Fragment& frag,
      cutlass::Array<accum_t, kQueriesPerBlock>& mi,
      cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
      cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
      int8_t lane_id,
      int8_t thread_id,
      int8_t warp_id,
      int16_t max_col,
      typename T::TensorCoord const& tile_offset,
      float scaling) {
    // Convert to `accum_t` (rather than double)
    constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
    if (!kIsFirst) {
      if (thread_id < kQueriesPerBlock) {
        m_prime[thread_id] = mi[thread_id];
      }
      __syncthreads();
    }

    auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset);

    // First update `mi` to the max per-row
    {
      accum_t max;
      BASE::iterateRows(
          lane_offset,
          [&](int accum_m) {
            max = -cutlass::platform::numeric_limits<accum_t>::infinity();
          },
          [&](int accum_m, int accum_n, int idx) {
            if (kFullColumns || accum_n < max_col) {
              max = cutlass::fast_max(max, frag[idx]);
            }
          },
          [&](int accum_m) {
            // Having 4x atomicMax seems faster than reducing within the warp
            // first...
            atomicMaxFloat(&mi[accum_m], max * scaling);
          });
    }
    frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);

    // Make sure we all share the updated values for `mi`
    __syncthreads();

    if (thread_id < kQueriesPerBlock) {
      auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
      m_prime[thread_id] = m_prime_exp;
      s_prime[thread_id] *= m_prime_exp;
    }
    __syncthreads(); // Update output fragments
    if (kKeepOutputInRF && !kIsFirst) {
      accum_t mp;
      BASE::iterateRows(
          lane_offset,
          [&](int accum_m) { mp = m_prime[accum_m]; },
          [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
          [&](int accum_m) {});
      __syncthreads();
    }
    // Update accum: accum <- exp(accum - mi), then accumulate the row sums
    // into s_prime
    {
      accum_t mi_row, total_row;
      BASE::iterateRows(
          lane_offset,
          [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
            frag[idx] = (kFullColumns || accum_n < max_col)
                ? exp2f(frag[idx] - mi_row)
                : accum_t(0.0);
          },
          [&](int accum_m) {});
      BASE::iterateRows(
          lane_offset,
          [&](int accum_m) { total_row = 0.0; },
          [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
          [&](int accum_m) {
            if (BASE::reduceSameRow(
                    lane_id, total_row, [](accum_t a, accum_t b) {
                      return a + b;
                    })) {
              atomicAdd(&s_prime[accum_m], total_row);
            }
          });
    }
  }
};

template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSm80
    : RegisterOps<
          AttentionScalingCoefsUpdaterSm80<T, accum_t, kWarpSize>,
          T,
          accum_t,
          kWarpSize> {
  static_assert(
      cutlass::platform::
          is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
      "only RowMajor is supported");

  using Policy = typename T::Policy;
  using InstructionShape = typename T::InstructionShape;
  using OpDelta = typename T::OpDelta;
  using Shape = typename T::Shape;
  static int const kElementsPerAccess = InstructionShape::kN / 4;
  static int const kRowsPerTile = 8;
  static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;

  static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
      int8_t lane_id,
      int8_t warp_id,
      typename T::TensorCoord const& tile_offset) {
    int quad = (lane_id >> 2);
    int lane_in_quad = (lane_id & 3);
    return cutlass::MatrixCoord(
        quad + tile_offset.row() * Shape::kRow,
        lane_in_quad * kElementsPerAccess +
            tile_offset.column() * Shape::kColumn);
  }

  template <typename FA, typename FB, typename FC>
  CUTLASS_DEVICE static void iterateRows(
      cutlass::MatrixCoord& lane_offset,
      FA beginRow,
      FB op,
      FC endRow) {
    // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
    CUTLASS_PRAGMA_UNROLL
    for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int row = 0; row < kAccumulatorRows; ++row) {
        int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
            row * kRowsPerTile + lane_offset.row();
        beginRow(accum_m);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
          int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
              (mma_n * Policy::MmaIterations::kRow + mma_m);
          CUTLASS_PRAGMA_UNROLL
          for (int col = 0; col < kElementsPerAccess; ++col) {
            int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
                col + lane_offset.column();
            int idx = mma_accum_start + row * kElementsPerAccess + col;
            op(accum_m, accum_n, idx);
          }
        }

        endRow(accum_m);
      }
    }
  }

  template <typename DT, typename F>
  CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
    // In each warp, 4 threads will work on the same row
    // - the ones with the same `quad`
    auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
    myValue = fn(myValue, otherV);
    otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
    myValue = fn(myValue, otherV);
    int lane_in_quad = (lane_id & 3);
    return lane_in_quad == 0;
  }
};
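
// Note on reduceSameRow above (added commentary, not in the original diff):
// it is a butterfly reduction. With lanes 0..3 of a quad holding {a, b, c, d}
// for one row, the XOR-1 shuffle leaves every lane with a pair sum and the
// XOR-2 shuffle with the full sum a+b+c+d; `lane_in_quad == 0` then elects a
// single writer per row for the atomicAdd into `s_prime`.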

template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterVolta
    : RegisterOps<
          AttentionScalingCoefsUpdaterVolta<T, accum_t, kWarpSize>,
          T,
          accum_t,
          kWarpSize> {
  static_assert(
      cutlass::platform::
          is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
      "only RowMajor is supported");

  using Policy = typename T::Policy;
  using InstructionShape = typename T::InstructionShape;
  using OpDelta = typename T::OpDelta;
  using Shape = typename T::Shape;
  using Element = accum_t;

  static int const kElementsPerPartial = 4;
  using EleShapePerPatial = typename cutlass::platform::conditional<
      cutlass::platform::is_same<Element, float>::value,
      cutlass::MatrixShape<2, 2>,
      cutlass::MatrixShape<1, 4>>::type;
  static int const kElementsPerMma = 8;
  static int const kAccumulatorPatials = 2;
  using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;

  static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
      int8_t lane_id,
      int8_t warp_id,
      typename T::TensorCoord const& tile_offset) {
    int quad = (lane_id >> 2);
    int lane_in_quad = (lane_id & 3);
    int accum_m, accum_n;

    if (cutlass::platform::is_same<Element, float>::value) {
      // (quad[2],quad[0])+lane_in_quad[0]
      accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
      // (quad[1])+lane_in_quad[1]
      accum_n =
          ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
          (lane_in_quad & 2);
    } else {
      accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
          lane_in_quad; // (quad[2],quad[0])
      accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
    }
    return cutlass::MatrixCoord(
        accum_m + tile_offset.row() * Shape::kRow,
        accum_n + tile_offset.column() * Shape::kColumn);
  }

  template <typename DT, typename F>
  CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
    static_assert(
        cutlass::platform::is_same<Element, float>::value,
        "update to support non-float accum");
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
    // T0 & T2 share the same line within a quad
    auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
    myValue = fn(myValue, otherV);
    // quad 0 and quad 2 are on the same lines
    otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
    myValue = fn(myValue, otherV);
    return (lane_id & ((1 << 1) | (1 << 3))) == 0;
  }

  template <typename FA, typename FB, typename FC>
  CUTLASS_DEVICE static void iterateRows(
      cutlass::MatrixCoord& lane_offset,
      FA beginRow,
      FB op,
      FC endRow) {
    CUTLASS_PRAGMA_UNROLL
    for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
          int accum_m = tile_m * Policy::InterleavedTile::kRow +
              mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row();
          beginRow(accum_m);

          CUTLASS_PRAGMA_UNROLL
          for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn;
               ++tile_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
                 ++mma_n) {
              CUTLASS_PRAGMA_UNROLL
              for (int p = 0; p < kAccumulatorPatials; ++p) {
                CUTLASS_PRAGMA_UNROLL
                for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
                  int mma_accum_start =
                      (((tile_n * Policy::TileIterations::kRow + tile_m) *
                            Policy::MmaIterations::kColumn +
                        mma_n) *
                           Policy::MmaIterations::kRow +
                       mma_m) *
                      kElementsPerMma;
                  int accum_n = tile_n * Policy::InterleavedTile::kColumn +
                      mma_n * QuadShapePerPatialMma::kColumn +
                      p * Policy::InterleavedTile::kColumn / 2 + n +
                      lane_offset.column();
                  int idx = mma_accum_start + p * kElementsPerPartial +
                      m * EleShapePerPatial::kColumn + n;
                  op(accum_m, accum_n, idx);
                }
              }
            }
          }
          endRow(accum_m);
        }
      }
    }
  }
};

template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSimt
    : RegisterOps<
          AttentionScalingCoefsUpdaterSimt<T, accum_t, kWarpSize>,
          T,
          accum_t,
          kWarpSize> {
  using Policy = typename T::Policy;
  using Iterations = typename T::Iterations;
  using Element = typename T::Element;
  using Delta = typename T::Delta;
  using Shape = typename T::Shape;
  static_assert(
      cutlass::platform::
          is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
      "only RowMajor is supported");

  template <typename DT, typename F>
  CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
    CUTLASS_PRAGMA_UNROLL
    for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
      auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
      myValue = fn(myValue, otherV);
    }
    return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
  }

  template <typename FA, typename FB, typename FC>
  CUTLASS_DEVICE static void iterateRows(
      cutlass::MatrixCoord& lane_offset,
      FA beginRow,
      FB op,
      FC endRow) {
    CUTLASS_PRAGMA_UNROLL
    for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
        int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
        beginRow(accum_m);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
          int accum_n =
              mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
              lane_offset.column();
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
            int idx = n +
                Policy::LaneMmaShape::kN *
                    (mma_n +
                     Iterations::kColumn *
                         (m + mma_m * Policy::LaneMmaShape::kM));
            op(accum_m, accum_n + n, idx);
          }
        }
        endRow(accum_m);
      }
    }
  }

  static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
      int8_t lane_id,
      int8_t warp_id,
      typename T::TensorCoord const& tile_offset) {
    static_assert(
        cutlass::platform::is_same<
            typename Policy::LaneLayout,
            cutlass::layout::RowMajorInterleaved<1>>::value,
        "");
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
        cutlass::MatrixCoord(Policy::LaneMmaShape::kM,
                             Policy::LaneMmaShape::kN);
    return lane_offset +
        tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
  }
};

template <typename T, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater;

// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
    cutlass::gemm::warp::MmaSimtTileIterator<
        S,
        cutlass::gemm::Operand::kC,
        accum_t,
        cutlass::layout::RowMajor,
        P,
        1,
        1>,
    accum_t,
    kWarpSize> {
  using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
      S,
      cutlass::gemm::Operand::kC,
      accum_t,
      cutlass::layout::RowMajor,
      P,
      1,
      1>;
  using Updater =
      AttentionScalingCoefsUpdaterSimt<Iterator, accum_t, kWarpSize>;
};

// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
    cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
        S1,
        accum_t,
        cutlass::layout::RowMajor,
        S2,
        cutlass::MatrixShape<1, 1>>,
    accum_t,
    kWarpSize> {
  using Iterator =
      typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
          S1,
          accum_t,
          cutlass::layout::RowMajor,
          S2,
          cutlass::MatrixShape<1, 1>>;
  using Updater =
      AttentionScalingCoefsUpdaterVolta<Iterator, accum_t, kWarpSize>;
};

// TensorOp - Sm75+
template <
    typename S1,
    typename S2,
    typename S3,
    typename accum_t,
    int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
    cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
        S1,
        accum_t,
        cutlass::layout::RowMajor,
        S2,
        S3>,
    accum_t,
    kWarpSize> {
  using Iterator =
      typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
          S1,
          accum_t,
          cutlass::layout::RowMajor,
          S2,
          S3>;
  using Updater =
      AttentionScalingCoefsUpdaterSm80<Iterator, accum_t, kWarpSize>;
};
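
A minimal usage sketch (an assumption about how the kernel side, not shown in this excerpt, wires this up): the trait is instantiated from the accumulator tile iterator of the warp-level MMA used for the Q.K matmul, and the resulting `Updater` runs steps (1)-(2) from the header comment above.

// Hypothetical wiring; `Mma` stands in for whichever warp-level MMA is
// configured, so its operand-C iterator selects one of the specializations.
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
    typename Mma::Operator::IteratorC,
    float, // accum_t
    32     // kWarpSize
    >::Updater;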

examples/42_fused_multi_head_attention/debug_utils.h (new file, 128 lines)
@@ -0,0 +1,128 @@
#pragma once
#include <float.h>
#include <stdio.h>
#include <cmath>

////////////////////////////////////////////////////////////////////////////////
// Debugging functions
////////////////////////////////////////////////////////////////////////////////
// NaN & inf detection
#define NANCHECK(frag)                         \
  {                                            \
    for (int _i = 0; _i < frag.size(); ++_i) { \
      assert(std::isfinite(float(frag[_i])));  \
      assert(!std::isnan(float(frag[_i])));    \
    }                                          \
  }

// Print on the first thread of the first block
#if 0
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
#define PRINT_T0_L0(msg, ...)                                         \
  if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&        \
      threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
      threadIdx.z == 0) {                                             \
    printf(msg "\n", __VA_ARGS__);                                    \
  }
struct __string_view {
  char const* data;
  std::size_t size;
};
template <class T>
constexpr __string_view __get_type_name() {
  char const* p = __PRETTY_FUNCTION__;
  while (*p++ != '=')
    ;
  for (; *p == ' '; ++p)
    ;
  char const* p2 = p;
  int count = 1;
  for (;; ++p2) {
    switch (*p2) {
      case '[':
        ++count;
        break;
      case ']':
        --count;
        if (!count)
          return {p, std::size_t(p2 - p)};
    }
  }
  return {};
}
#else
#define PRINT_T0_L0
#endif

// Print a given array
#define PRINT_ACCUM8_T0_L0_START(name, accum, start)  \
  PRINT_T0_L0(                                        \
      "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
      name,                                           \
      int(start),                                     \
      int(start + 8),                                 \
      float(accum[start + 0]),                        \
      float(accum[start + 1]),                        \
      float(accum[start + 2]),                        \
      float(accum[start + 3]),                        \
      float(accum[start + 4]),                        \
      float(accum[start + 5]),                        \
      float(accum[start + 6]),                        \
      float(accum[start + 7]));
#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0)
#define PRINT_FRAG_T0_L0(name, frag)                          \
  {                                                           \
    auto typeStr = __get_type_name<decltype(frag)>();         \
    PRINT_T0_L0("printing %s (%s)", name, typeStr.data);      \
    for (int _start = 0; _start < frag.size(); _start += 8) { \
      PRINT_ACCUM8_T0_L0_START(" ", frag, _start);            \
    }                                                         \
    /*__syncthreads();                                        \
    NANCHECK(frag); */                                        \
  }
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr)   \
  {                                                         \
    PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \
    for (int _start = 0; _start < length; _start += incr) { \
      PRINT_ACCUM8_T0_L0_START(" ", array, _start);         \
    }                                                       \
  }
#define PRINT_ARRAY_T0_L0(name, array, length) \
  PRINT_ARRAY_T0_L0_INCR(name, array, length, 8)

// Print a 4x4 matrix
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y)                               \
  PRINT_T0_L0(                                                                                 \
      "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \
      name,                                                                                    \
      int(start_x),                                                                            \
      int(start_x + 4),                                                                        \
      int(start_y),                                                                            \
      int(start_y + 4),                                                                        \
      float(ref.at({start_x + 0, start_y + 0})),                                               \
      float(ref.at({start_x + 0, start_y + 1})),                                               \
      float(ref.at({start_x + 0, start_y + 2})),                                               \
      float(ref.at({start_x + 0, start_y + 3})),                                               \
      float(ref.at({start_x + 1, start_y + 0})),                                               \
      float(ref.at({start_x + 1, start_y + 1})),                                               \
      float(ref.at({start_x + 1, start_y + 2})),                                               \
      float(ref.at({start_x + 1, start_y + 3})),                                               \
      float(ref.at({start_x + 2, start_y + 0})),                                               \
      float(ref.at({start_x + 2, start_y + 1})),                                               \
      float(ref.at({start_x + 2, start_y + 2})),                                               \
      float(ref.at({start_x + 2, start_y + 3})),                                               \
      float(ref.at({start_x + 3, start_y + 0})),                                               \
      float(ref.at({start_x + 3, start_y + 1})),                                               \
      float(ref.at({start_x + 3, start_y + 2})),                                               \
      float(ref.at({start_x + 3, start_y + 3})));
#define PRINT_TENSOR4x4_T0_L0(name, ref) \
  PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)

#define PRINT_PROBLEM_SIZE(name, ps)            \
  PRINT_T0_L0(                                  \
      "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
      name,                                     \
      int(ps.m()),                              \
      int(ps.n()),                              \
      int(ps.k()))
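
// Example usage (illustrative, not part of the diff): inside a kernel, after
// flipping the `#if 0` above to `#if 1`, fragments and tiles can be dumped as:
//   PRINT_T0_L0("max_col=%d", int(max_col));
//   PRINT_FRAG_T0_L0("attn", frag);
//   PRINT_TENSOR4x4_T0_L0("out", out_ref);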

examples/42_fused_multi_head_attention/epilogue_pipelined.h (new file, 632 lines)
@@ -0,0 +1,632 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  File copied from "cutlass/epilogue/threadblock/epilogue.h"
  then modified to:
  (1) load 2 source fragments at the same time (pipelining)
  (2) support reading from a different dtype
  (3) pass the row id to the OutputOp if it takes it
      (see MemoryEfficientAttentionNormalize)
  Note that, in general, the fragment passed to the OutputOp could span
  multiple rows, but this does not happen with the configurations we have.
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

template <typename Op>
struct ApplyEpilogueOp {
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum,
      typename Op::FragmentOutput const& source) {
    return output_op(accum, source);
  }
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum) {
    return output_op(accum);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator
template <
    typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
    typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
                               ///< gemm::warp::MmaTensorOp)
    int PartitionsK, ///< Number of partitions of the K dimension
    typename OutputTileIterator_, ///< Tile iterator writing output tensors
    typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting
                                           ///< accumulators
    typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
                                ///< accumulators to SMEM
    typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading
                                  ///< from SMEM
    typename OutputOp_, ///< Output operator
    typename Padding_, ///< Padding added to SMEM allocation to avoid bank
                       ///< conflicts (concept: MatrixShape)
    int FragmentsPerPartition =
        1, ///< Used to coarsen the epilogue granularity
    int IterationsUnroll = ///< Used to reduce binary size when epilogue op is
                           ///< large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
    typename OutputTileSourceIterator_ =
        OutputTileIterator_ ///< Tile iterator reading tensors
    >
class EpiloguePipelined : public EpilogueBase<
                              Shape_,
                              typename WarpMmaOperator_::Shape,
                              PartitionsK,
                              AccumulatorFragmentIterator_,
                              WarpTileIterator_,
                              Padding_,
                              FragmentsPerPartition> {
 public:
  using Base = EpilogueBase<
      Shape_,
      typename WarpMmaOperator_::Shape,
      PartitionsK,
      AccumulatorFragmentIterator_,
      WarpTileIterator_,
      Padding_,
      FragmentsPerPartition>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using OutputTileSourceIterator = OutputTileSourceIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;
  using ElementSource = typename OutputTileSourceIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef =
      typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
      typename OutputTileIterator::Element,
      OutputTileIterator::kElementsPerAccess>;
  using SourceAccessType = Array<
      typename OutputTileSourceIterator::Element,
      OutputTileSourceIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<
      typename WarpTileIterator::Element,
      OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
      ? Base::kFragmentsPerIteration
      : kPartitionsK;
  static int constexpr kSmemPointerOffset =
      Base::SharedStorage::StorageShape::kCount / kSmemTiles;

 public:
  static_assert(
      OutputTileSourceIterator::Fragment::kElements ==
          OutputTileIterator::Fragment::kElements,
      "Mismatch between input tile and output tile iterator (kElements)");
  static_assert(
      OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
      "Mismatch between input tile and output tile iterator (kIterations)");
  static_assert(
      SharedLoadIterator::Fragment::kElements ==
          OutputTileIterator::Fragment::kElements,
      "Mismatch between shared load iterator and output tile iterator.");

  static_assert(
      OutputTileIterator::kElementsPerAccess,
      "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(
      !(OutputTileIterator::Fragment::kElements %
        OutputTileIterator::kElementsPerAccess),
      "Divisibility");

 private:
  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

 public:
  /// Constructor
  CUTLASS_DEVICE
  EpiloguePipelined(
      typename Base::SharedStorage& shared_storage, ///< Shared storage object
      int thread_idx, ///< ID of a thread within the threadblock
      int warp_idx, ///< ID of warp within threadblock
      int lane_idx ///< ID of thread within warp
      )
      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
        shared_load_iterator_(shared_storage.reference(), thread_idx) {}

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators, ///< Complete warp-level accumulator tile
      OutputTileSourceIterator
          source_iterator) { ///< Tile iterator for the source (read)

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    } else {
      compute_source_needed_(
          output_op, destination_iterator, accumulators, source_iterator);
    }
  }
  CUTLASS_DEVICE
  void operator()(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators) { ///< Complete warp-level accumulator tile
    compute_source_not_needed_(output_op, destination_iterator, accumulators);
  }

 private:
  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(
        AccumulatorFragmentIterator accum_fragment_iterator,
        WarpTileIterator& warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(
            kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }

    CUTLASS_DEVICE
    static void push(
        size_t pos,
        AccumulatorFragmentIterator const& iterator_begin,
        WarpTileIterator& warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(
               iterator_begin, warp_tile_iterator),
           0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };
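
  // Note on the `int dummy[]` construct above (added commentary, not in the
  // original diff): expanding `(pos == ...) && (helper<...>(...), 0)...` over
  // the index sequence emits one guarded call per iteration index, so the
  // runtime `pos` dispatches to the matching compile-time `helper<Advance>`
  // instantiation. This is the usual C++11 stand-in for a fold expression.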

  static_assert(
      kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
      "One of these must be exactly 1.");

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators ///< Complete warp-level accumulator tile
  ) {
    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

#pragma unroll( \
    IterationsUnroll \
        ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
        : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations;
         iter += Base::kFragmentsPerIteration) {
      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_not_needed<cutlass::make_index_sequence<
          OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
          push(iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename SharedLoadIterator::Fragment
            aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        } else if (kPartitionsK > 1) {
          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(
                aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          shared_load_iterator_.add_pointer_offset(
              (1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Compute the output result
        //

        typename OutputTileIterator::Fragment output_fragment;

        apply_output_operator_source_not_needed_(
            destination_iterator.thread_start_row(),
            output_fragment,
            output_op,
            aligned_accum_fragment[0]);

        //
        // Store the final result
        //

        destination_iterator.store(output_fragment);
        ++destination_iterator;
      }

      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(
            kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }

  template <class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(
        AccumulatorFragmentIterator accum_fragment_iterator,
        WarpTileIterator& warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    CUTLASS_DEVICE
    static void push(
        size_t pos,
        AccumulatorFragmentIterator const& iterator_begin,
        WarpTileIterator& warp_tile_iterator) {
      int dummy[] = {
          (pos == Seq) &&
          (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators, ///< Complete warp-level accumulator tile
      OutputTileSourceIterator
          source_iterator ///< Tile iterator for the source (read)
  ) {
    typename OutputTileSourceIterator::Fragment source_fragment[2];

    source_fragment[0].clear();
    source_iterator.load(source_fragment[0]);
    ++source_iterator;
    source_fragment[1].clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
      if (iter > 0) {
        __syncthreads();
      }
      //
      // Load the source for the next iteration (pipelining)
      //

      if (iter + 1 < OutputTileIterator::kIterations) {
        source_iterator.load(source_fragment[(iter + 1) % 2]);
      }
      ++source_iterator;
      acc2smem_source_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
          push(iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment
          aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1, perform a reduction amongst the
      // k-slices
      if (kPartitionsK > 1) {
        plus<typename SharedLoadIterator::Fragment> add_fragments;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(
              aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_pointer_offset(
            (1 - kPartitionsK) * kSmemPointerOffset);
      }

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment;

      apply_output_operator_(
          destination_iterator.thread_start_row(),
          output_fragment,
          output_op,
          aligned_accum_fragment[0],
          source_fragment[iter % 2]);

      //
      // Store the final result
      //

      destination_iterator.store(output_fragment);
      ++destination_iterator;
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
      int begin_row,
      typename OutputTileIterator::Fragment& output_fragment,
      OutputOp const& output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
      typename OutputTileSourceIterator::Fragment const& source_fragment) {
    OutputAccessType* output_frag_ptr =
        reinterpret_cast<OutputAccessType*>(&output_fragment);

    AccumulatorAccessType const* compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);

    SourceAccessType const* source_frag_ptr =
        reinterpret_cast<SourceAccessType const*>(&source_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
        OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
          output_op,
          begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
          compute_frag_ptr[i],
          source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
      int begin_row,
      typename OutputTileIterator::Fragment& output_fragment,
      OutputOp const& output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const& aligned_accum_fragment) {
    OutputAccessType* output_frag_ptr =
        reinterpret_cast<OutputAccessType*>(&output_fragment);

    AccumulatorAccessType const* compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
        OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
          output_op,
          begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
          compute_frag_ptr[i]);
    }
  }

  // This should be constexpr, but that is only supported from C++14 onwards
  static int CUTLASS_HOST_DEVICE getRowOffset(int i) {
    using ThreadMap = typename OutputTileIterator::ThreadMap;

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));
          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            int frag_idx = ThreadMap::kElementsPerAccess *
                (frag_row_idx * ThreadMap::Iterations::kColumn + column);
            if (i < frag_idx + ThreadMap::kElementsPerAccess) {
              return row_offset;
            }
          }
        }
      }
    }
    return -1;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////

examples/42_fused_multi_head_attention/epilogue_rescale_output.h (new file, 231 lines)
@@ -0,0 +1,231 @@
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory
  to match canonical tensor layouts in global memory. Epilogues support
  conversion and reduction operations.

  This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
  handle "row_id" as a first argument, and uses it to look up the
  corresponding `m_prime` / `s_prime` values to rescale the output.
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "epilogue_pipelined.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements:
//   output <- alpha * accumulator + beta * source
// with:
//   alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
//   beta = alpha * m_prime (renormalizes the output when the max changes)
//   source is the current output
template <
    typename ElementOutput_, ///< Data type used to store tensors
    typename ElementSource_, //< Data type for source (usually matches
                             //< `ElementOutput`)
    int Count, ///< Number of elements computed per operation.
               ///< Usually it is 128/sizeof_bits<ElementOutput_>,
               ///< but we use 64 or 32 sometimes when there is not enough
               ///< data to store
    typename ElementAccumulator_, ///< Accumulator data type
    typename ElementCompute_, ///< Data type used to compute linear combination
    bool isFirst,
    bool isLast,
    typename FragmentAlphaBeta_,
    FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
 public:
  using ElementOutput = ElementOutput_;
  using ElementSource = ElementSource_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentAlphaBeta = FragmentAlphaBeta_;

  static FloatRoundStyle const kRound = Round;

 private:
  //
  // Data members
  //

  FragmentAlphaBeta const& s_prime_;
  FragmentAlphaBeta const& m_prime_;

 public:
  /// Constructs the function object, possibly loading from pointers in host
  /// memory
  CUTLASS_HOST_DEVICE
  MemoryEfficientAttentionNormalize(
      FragmentAlphaBeta const& s_prime,
      FragmentAlphaBeta const& m_prime)
      : s_prime_(s_prime), m_prime_(m_prime) {}

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return !isFirst;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      int row,
      FragmentAccumulator const& accumulator,
      FragmentSource const& source) const {
    assert(!isFirst);

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
        source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
    ElementCompute beta = alpha * m_prime_[row];

    intermediate = mul_add_source(beta, converted_source); // X = beta * C

    intermediate = mul_add_accumulator(
        alpha, converted_accumulator, intermediate); // D = alpha * Accum + X

    return destination_converter(intermediate);
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(int row, FragmentAccumulator const& accumulator)
      const {
    assert(isFirst);

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    ComputeFragment intermediate;
    multiplies<ComputeFragment> mul_accumulator;

    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;

    intermediate = mul_accumulator(
        alpha, converted_accumulator); // D = alpha * Accum

    return destination_converter(intermediate);
  }
};
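
// Numerical sanity check (added commentary, not in the original diff):
// suppose a row's running max rose so that m_prime = exp(m_old - m_new) = 0.5,
// and on the last partial s_prime = 4. Then alpha = 1/4 and
// beta = alpha * m_prime = 1/8, so D = accum/4 + source/8: the fresh partial
// sums are normalized by s_prime, while the previously written output is
// additionally corrected for the max shift.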

} // namespace thread

namespace threadblock {
template <
    typename EO,
    typename ES,
    int Count,
    typename EA,
    typename EC,
    bool F,
    bool L,
    typename FAB,
    FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<
    EO,
    ES,
    Count,
    EA,
    EC,
    F,
    L,
    FAB,
    R>> {
  using Op = thread::
      MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum,
      typename Op::FragmentSource const& source) {
    return output_op(row_id, accum, source);
  }
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum) {
    return output_op(row_id, accum);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,175 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Functor performing linear combination operations used by epilogues.
*/

#pragma once

#include <cuda_fp16.h>

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
      Array<Element, ElementsPerAccess> const& input) const {
    Array<Element, ElementsPerAccess> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      result[i] = expf(input[i]);
    }

    return result;
  }
};

template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
      Array<half_t, ElementsPerAccess> const& input) const {
    Array<half_t, ElementsPerAccess> result;

    int const kVectorCount = ElementsPerAccess / 2;

    __half2 const* input_ptr =
        reinterpret_cast<__half2 const*>(input.raw_data());
    __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = h2exp(input_ptr[i]);
    }

    return result;
  }
};
|
||||
} // namespace detail
|
||||

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies:
/// output <- (input - lse).exp()
template <
    typename ElementOutput_, // output
    typename ElementLSE_, // accumulator from LSE
    typename ElementAccumulator_, // accumulator from matmul
    typename ElementCompute_, // intermediate compute (and exp calculation)
    int ElementsPerAccess>
class ApplyLogSumExp {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementLSE = ElementLSE_;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;
  static const ScaleType::Kind kScale =
      cutlass::epilogue::thread::ScaleType::NoBetaScaling;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
  using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h

 public:
  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  ApplyLogSumExp() {}

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return true;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      FragmentAccumulator const& AB,
      FragmentLSE const& scale_unused,
      // bias used as LSE
      FragmentLSE const& bias) const {
    FragmentCompute frag_AB = NumericArrayConverter<
        ElementCompute,
        ElementAccumulator,
        kElementsPerAccess>()(AB);
    FragmentCompute frag_lse_compute =
        NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(
            bias);
    FragmentCompute frag_compute;

    minus<FragmentCompute> minus_lse;
    detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
    frag_compute = minus_lse(frag_AB, frag_lse_compute);
    frag_compute = apply_exp(frag_compute);

    return NumericArrayConverter<
        ElementOutput,
        ElementCompute,
        kElementsPerAccess>()(frag_compute);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
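For intuition: with a row's precomputed lse = log(sum_j exp(ab_j)), the functor above computes exp(ab - lse) = exp(ab) / sum_j exp(ab_j) per element, i.e. the softmax value, with no separate normalization pass. Scalar C++ equivalent (illustrative, not part of the commit):

#include <cmath>

// Scalar analogue of ApplyLogSumExp::operator() for one element.
float apply_logsumexp(float ab, float lse) {
  return std::exp(ab - lse); // output <- (input - lse).exp()
}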
158
examples/42_fused_multi_head_attention/find_default_mma.h
Normal file
@ -0,0 +1,158 @@
/*! \file
  \brief CUTLASS provides helper template functions to figure out the right
  data structures to instantiate to run a GEMM with various parameters (see
  `cutlass/gemm/threadblock/default_mma.h`). However, due to template
  instantiation priority rules, it will only create an MmaMultistage with
  kStages=3 (otherwise it creates an MmaPipelined - which is not compatible
  with FastF32). kStages=3 uses too much shared memory and we want to use
  kStages=2, so we just copy-pasted some code from `default_mma.h` and
  `default_mma_core.h` and wrapped this template to allow our use case.

  This is really only for the FastF32 case - i.e. using TensorCores with fp32.
*/

#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"

namespace cutlass {
namespace gemm {
namespace threadblock {

template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Layout type for C and D matrix operand
    typename LayoutC,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    typename Enable_ = void>
struct FindDefaultMma {
  static constexpr bool AccumulatorsInRowMajor = false;
  static constexpr SharedMemoryClearOption SharedMemoryClear =
      SharedMemoryClearOption::kNone;
  using DefaultMma = cutlass::gemm::threadblock::DefaultMma<
      ElementA,
      LayoutA,
      kAlignmentA,
      ElementB,
      LayoutB,
      kAlignmentB,
      ElementAccumulator,
      LayoutC,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      Stages,
      Operator,
      AccumulatorsInRowMajor,
      SharedMemoryClear>;
};

/// Specialization for sm80 / FastF32 / multistage with kStages=2
template <
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    int kStages,
    typename Operator>
struct FindDefaultMma<
    ElementA_,
    LayoutA_,
    kAlignmentA,
    ElementB_,
    LayoutB_,
    kAlignmentB,
    ElementAccumulator,
    layout::RowMajor,
    arch::OpClassTensorOp,
    arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    kStages,
    Operator,
    typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
  using LayoutC = layout::RowMajor;
  using OperatorClass = arch::OpClassTensorOp;
  using ArchTag = arch::Sm80;

  using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<
      ElementA_,
      LayoutA_,
      kAlignmentA,
      ElementB_,
      LayoutB_,
      kAlignmentB,
      ElementAccumulator,
      LayoutC,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      3,
      Operator>;
  struct DefaultMma : DefaultMma_ {
    using MmaCore_ = typename DefaultMma_::MmaCore;
    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<
        typename MmaCore_::Shape,
        typename DefaultMma_::IteratorA,
        typename MmaCore_::SmemIteratorA,
        MmaCore_::kCacheOpA,
        typename DefaultMma_::IteratorB,
        typename MmaCore_::SmemIteratorB,
        MmaCore_::kCacheOpB,
        ElementAccumulator,
        LayoutC,
        typename MmaCore_::MmaPolicy,
        kStages>;
  };
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
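A hypothetical instantiation (tile shapes and alignments are illustrative, not from the commit) of the case this file exists for - fp32 data on SM80 TensorCores via FastF32 with kStages=2; kAlignmentA > 1 selects the specialization above:

using ThreadblockMma = typename cutlass::gemm::threadblock::FindDefaultMma<
    float, cutlass::layout::RowMajor, 4, // A: element, layout, alignment
    float, cutlass::layout::RowMajor, 4, // B: element, layout, alignment
    float,                               // accumulator element
    cutlass::layout::RowMajor,           // C/D layout
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>, // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
    cutlass::gemm::GemmShape<16, 8, 8>,     // tf32 MMA instruction tile
    2,                                      // kStages = 2
    cutlass::arch::OpMultiplyAddFastF32>::DefaultMma::ThreadblockMma;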
1092
examples/42_fused_multi_head_attention/fused_multihead_attention.cu
Normal file
File diff suppressed because it is too large
93
examples/42_fused_multi_head_attention/gemm/custom_mma.h
Normal file
@ -0,0 +1,93 @@
#pragma once

#include "custom_mma_multistage.h"
#include "custom_mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"

template <typename Mma, int kMaxK>
struct MakeCustomMma;

template <
    typename Shape,
    typename IteratorA,
    typename SmemIteratorA,
    cutlass::arch::CacheOperation::Kind CacheOpA,
    typename IteratorB,
    typename SmemIteratorB,
    cutlass::arch::CacheOperation::Kind CacheOpB,
    typename ElementC,
    typename LayoutC,
    typename Policy,
    int Stages,
    cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
    int kMaxK>
struct MakeCustomMma<
    cutlass::gemm::threadblock::MmaMultistage<
        Shape,
        IteratorA,
        SmemIteratorA,
        CacheOpA,
        IteratorB,
        SmemIteratorB,
        CacheOpB,
        ElementC,
        LayoutC,
        Policy,
        Stages,
        SharedMemoryClear>,
    kMaxK> {
  // Reduce the number of stages if we don't need that many
  static int constexpr kStages =
      kMaxK == cutlass::platform::numeric_limits<int>::max()
      ? Stages
      : cutlass::const_min(
            Stages,
            (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
  using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
      Shape,
      IteratorA,
      SmemIteratorA,
      CacheOpA,
      IteratorB,
      SmemIteratorB,
      CacheOpB,
      ElementC,
      LayoutC,
      Policy,
      kStages,
      SharedMemoryClear,
      kMaxK>;
};

template <
    typename Shape,
    typename IteratorA,
    typename SmemIteratorA,
    typename IteratorB,
    typename SmemIteratorB,
    typename ElementC,
    typename LayoutC,
    typename Policy,
    int kMaxK>
struct MakeCustomMma<
    cutlass::gemm::threadblock::MmaPipelined<
        Shape,
        IteratorA,
        SmemIteratorA,
        IteratorB,
        SmemIteratorB,
        ElementC,
        LayoutC,
        Policy>,
    kMaxK> {
  using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<
      Shape,
      IteratorA,
      SmemIteratorA,
      IteratorB,
      SmemIteratorB,
      ElementC,
      LayoutC,
      Policy>;
};
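Worked example of the kStages clamp above, with concrete (hypothetical) numbers: for Shape::kK = 32 and a known bound kMaxK = 64 on the K extent, ceil(64 / 32) = 2 stages already hold every K-tile, so a kernel configured with Stages = 4 is instantiated with kStages = 2 and the shared-memory footprint shrinks accordingly:

// Mirrors the expression above with plain ints (illustrative values).
constexpr int Stages = 4, kMaxK = 64, ShapeK = 32;
constexpr int kNeeded = (kMaxK + ShapeK - 1) / ShapeK; // ceil division = 2
constexpr int kStages = Stages < kNeeded ? Stages : kNeeded;
static_assert(kStages == 2, "4 configured stages reduced to 2");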
183
examples/42_fused_multi_head_attention/gemm/custom_mma_base.h
Normal file
@ -0,0 +1,183 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaBase {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<
      Shape::kM / WarpGemm::kM,
      Shape::kN / WarpGemm::kN,
      Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  template <typename Element, typename OperandShape, typename OperandLayout>
  struct OperandSharedStorage {
    AlignedBuffer<Element, OperandShape::kCount> buffer;
    using TensorRef = TensorRef<Element, OperandLayout>;

    CUTLASS_DEVICE
    static OperandLayout Layout() {
      return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
    }

    /// Returns a TensorRef to the operand
    CUTLASS_HOST_DEVICE
    TensorRef ref() {
      return TensorRef{buffer.data(), Layout()};
    }
  };

  /// Shape of the A matrix operand in shared memory
  using ShapeA = MatrixShape<
      Shape::kM + Policy::SmemPaddingA::kRow,
      Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

  /// Shape of the B matrix operand in shared memory
  using ShapeB = MatrixShape<
      Shape::kK * kStages + Policy::SmemPaddingB::kRow,
      Shape::kN + Policy::SmemPaddingB::kColumn>;

  using SharedStorageA = OperandSharedStorage<
      typename Operator::ElementA,
      ShapeA,
      typename Operator::LayoutA>;
  using SharedStorageB = OperandSharedStorage<
      typename Operator::ElementB,
      ShapeB,
      typename Operator::LayoutB>;
  using TensorRefA = typename SharedStorageA::TensorRef;
  using TensorRefB = typename SharedStorageB::TensorRef;

  struct SharedStorage {
    /// Buffer for A operand
    SharedStorageA operand_A;

    /// Buffer for B operand
    SharedStorageB operand_B;
  };

 protected:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorageA& shared_storageA,
      SharedStorageB& shared_storageB,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
        warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
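Worked example of the shape bookkeeping above (illustrative numbers, not from the commit): a 128x128x32 threadblock tile computed by 64x64x32 warp tiles yields WarpCount = GemmShape<2, 2, 1>, i.e. four warps per threadblock, and with a warp-level MMA whose MmaShape::kK is 16, each shared-memory stage takes kWarpGemmIterations = 32 / 16 = 2 warp-level GEMMs:

static_assert(128 / 64 == 2 && 32 / 32 == 1, "WarpCount = <2, 2, 1>");
static_assert(32 / 16 == 2, "two warp-level GEMM iterations per stage");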
@ -0,0 +1,767 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "custom_mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Upper bound on the K dimension
    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
 public:
  ///< Base class
  using Base = CustomMmaBase<Shape_, Policy_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB_;
  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;
  ///< Policy describing tuning details
  using Policy = Policy_;

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  /// Internal structure exposed for introspection.
  struct Detail {
    static_assert(
        Base::kWarpGemmIterations > 1,
        "The pipelined structure requires at least two warp-level "
        "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B
    static int const AsyncCopyIterationsPerStageB =
        IteratorB::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand A
    static int const kAccessesPerGroupA =
        (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB =
        (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;
  };

  static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
  static constexpr int kNumStagesConcurrentLoad =
      kSmemContainsEntireMat ? Stages : Stages - 1;
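  // Note on the two constants above (illustrative numbers): when kMaxK
  // guarantees that every K-tile fits in shared memory at once
  // (kMaxK <= Shape::kK * Stages), the staging buffer is not reused
  // circularly, so all `Stages` loads can be in flight concurrently.
  // Otherwise the classic multistage pipeline keeps one stage as the write
  // target and overlaps Stages - 1 loads with compute. For example,
  // Shape::kK = 32 with Stages = 4 covers any K extent up to 128.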

 private:
  using WarpLoadedFragmentA = typename Operator::FragmentA;
  using WarpLoadedFragmentB = typename Operator::FragmentB;
  using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
  using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;

 private:
  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

  bool prologue_done_;

  // Set to `true` to ensure the accumulator will be zero outside the GEMM
  // footprint
  bool zero_outside_bounds_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaMultistage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storageA.ref(), thread_idx),
        smem_iterator_B_(shared_storageB.ref(), thread_idx),
        prologue_done_(false),
        zero_outside_bounds_(false) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }
  CUTLASS_DEVICE
  CustomMmaMultistage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage& st,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : CustomMmaMultistage(
            st.operand_A,
            st.operand_B,
            thread_idx,
            warp_idx,
            lane_idx) {}

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    prologue_done_ = value;
  }

  CUTLASS_DEVICE
  void set_zero_outside_bounds(bool value) {
    zero_outside_bounds_ = value;
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorage& shared_storage,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    prologue<kLoadA, kLoadB>(
        shared_storage.operand_A,
        shared_storage.operand_B,
        iterator_A,
        iterator_B,
        thread_idx,
        problem_size_k);
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
    SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
    int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
    _prologue<kLoadA, kLoadB>(
        iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
  }

  CUTLASS_DEVICE
  void copy_tiles_and_advance(
      IteratorA& iterator_A,
      IteratorB& iterator_B,
      int group_start_A = 0,
      int group_start_B = 0) {
    iterator_A.set_iteration_index(
        group_start_A * IteratorA::kAccessesPerVector);
    this->smem_iterator_A_.set_iteration_index(group_start_A);

    // Async copy for operand A
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
      if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                this->smem_iterator_A_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
            IteratorA::ThreadMap::kElementsPerAccess /
            IteratorA::kAccessesPerVector / 8;
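        // Worked example (illustrative numbers): 16-bit elements with
        // 8 elements per access and kAccessesPerVector == 1 give
        // 16 * 8 / 1 / 8 = 16-byte cp.async transfers.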

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_A.get();

          if (zero_outside_bounds_ ||
              SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, gmem_ptr, iterator_A.valid());
          } else {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
                dst_ptr + v, gmem_ptr, iterator_A.valid());
          }

          ++iterator_A;
        }

        ++this->smem_iterator_A_;
      }
    }

    iterator_B.set_iteration_index(
        group_start_B * IteratorB::kAccessesPerVector);
    this->smem_iterator_B_.set_iteration_index(group_start_B);

    // Async copy for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
      if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                this->smem_iterator_B_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
            IteratorB::ThreadMap::kElementsPerAccess /
            IteratorB::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B.get();

          if (zero_outside_bounds_ ||
              SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
          } else {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
          }

          ++iterator_B;
        }
        ++this->smem_iterator_B_;
      }
    }
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void _prologue(
      IteratorA& iterator_A,
      IteratorB& iterator_B,
      int32_t& gemm_k_iterations,
      SmemIteratorA& smem_iterator_A_,
      SmemIteratorB& smem_iterator_B_) {
    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < kNumStagesConcurrentLoad;
         ++stage, --gemm_k_iterations) {
      iterator_A.clear_mask(gemm_k_iterations == 0);
      iterator_B.clear_mask(gemm_k_iterations == 0);

      iterator_A.set_iteration_index(0);
      smem_iterator_A_.set_iteration_index(0);

      // Async copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                smem_iterator_A_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorA::Element>::value *
              IteratorA::ThreadMap::kElementsPerAccess /
              IteratorA::kAccessesPerVector / 8;

          if (kLoadA) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, iterator_A.get(), iterator_A.valid());
          }

          ++iterator_A;
        }

        ++smem_iterator_A_;
      }

      iterator_B.set_iteration_index(0);
      smem_iterator_B_.set_iteration_index(0);

      // Async copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                smem_iterator_B_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorB::Element>::value *
              IteratorB::ThreadMap::kElementsPerAccess /
              IteratorB::kAccessesPerVector / 8;

          if (kLoadB) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, iterator_B.get(), iterator_B.valid());
          }

          ++iterator_B;
        }

        ++smem_iterator_B_;
      }

      // Move to the next stage
      iterator_A.add_tile_offset({0, 1});
      iterator_B.add_tile_offset({1, 0});

      smem_iterator_A_.add_tile_offset({0, 1});
      smem_iterator_B_.add_tile_offset({1, 0});

      // Defines the boundary of a stage of cp.async.
      cutlass::arch::cp_async_fence();
    }
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations,
      ///< destination accumulator tile
      FragmentC& accum,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      ///< initial value of accumulator
      FragmentC const& src_accum) {
    //
    // Prologue
    //

    if (!prologue_done_) {
      _prologue<true, true>(
          iterator_A,
          iterator_B,
          gemm_k_iterations,
          smem_iterator_A_,
          smem_iterator_B_);
    } else if (!kSmemContainsEntireMat) {
      _prologue<false, false>(
          iterator_A,
          iterator_B,
          gemm_k_iterations,
          smem_iterator_A_,
          smem_iterator_B_);
    } else {
      gemm_k_iterations -= kNumStagesConcurrentLoad;
    }

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    //
    // Clear the remaining tiles of SMEM. This is a functional requirement for
    // some kernels so that all accumulator elements outside the GEMM footprint
    // are zero.
    //

    if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
      /// Iterator to write threadblock-scoped tile of A operand to shared
      /// memory
      SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

      typename IteratorA::AccessType zero_A;
      zero_A.clear();

      last_smem_iterator_A.set_iteration_index(0);

      // Async copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                last_smem_iterator_A.get());

        *dst_ptr = zero_A;

        ++last_smem_iterator_A;
      }

      /// Iterator to write threadblock-scoped tile of B operand to shared
      /// memory
      SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
      typename IteratorB::AccessType zero_B;

      zero_B.clear();
      last_smem_iterator_B.set_iteration_index(0);

      // Async copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                last_smem_iterator_B.get());

        *dst_ptr = zero_B;

        ++last_smem_iterator_B;
      }
    }

    // Waits until kStages-2 stages have committed.
    cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA warp_loaded_frag_A[2];
    WarpLoadedFragmentB warp_loaded_frag_B[2];
    WarpTransformedFragmentA warp_transformed_frag_A[2];
    WarpTransformedFragmentB warp_transformed_frag_B[2];

    Operator warp_mma;

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    iterator_A.clear_mask(gemm_k_iterations == 0);
    iterator_B.clear_mask(gemm_k_iterations == 0);

    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

    warp_mma.transform(
        warp_transformed_frag_A[0],
        warp_transformed_frag_B[0],
        warp_loaded_frag_A[0],
        warp_loaded_frag_B[0]);

    // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
    // accumulator and this temporary accumulator is added to the final
    // accumulator once in every mainloop iteration.
    plus<FragmentC> plus_accum;

    FragmentC tmp_accum;

    if (platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      tmp_accum.clear();
    }

    //
    // Mainloop
    //

    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);

        // In case of a non-circular buffer ("kSmemContainsEntireMat"),
        // make sure we don't load out-of-bounds data.
        if (!kSmemContainsEntireMat ||
            gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
            warp_mma_k < Base::kWarpGemmIterations - 1) {
          this->warp_tile_iterator_A_.load(
              warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_B_.load(
              warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
        }

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k > 0)
          warp_mma.transform(
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              warp_loaded_frag_A[warp_mma_k % 2],
              warp_loaded_frag_B[warp_mma_k % 2]);

        if (platform::is_same<
                typename Operator::MathOperator,
                arch::OpMultiplyAddFastF32>::value ||
            platform::is_same<
                typename Operator::MathOperator,
                arch::OpMultiplyAddComplexFastF32>::value) {
          warp_mma(
              tmp_accum,
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              tmp_accum);

          if (warp_mma_k == 0) {
            accum = plus_accum(accum, tmp_accum);
            tmp_accum.clear();
          }
        } else {
          warp_mma(
              accum,
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              accum);
        }

        // Issue global->shared copies for this stage
        if (!kSmemContainsEntireMat &&
            warp_mma_k < Base::kWarpGemmIterations - 1) {
          int group_start_iteration_A, group_start_iteration_B;

          group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
          group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

          copy_tiles_and_advance(
              iterator_A,
              iterator_B,
              group_start_iteration_A,
              group_start_iteration_B);
        }

        if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
          if (!kSmemContainsEntireMat) {
            int group_start_iteration_A, group_start_iteration_B;
            group_start_iteration_A =
                (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
            group_start_iteration_B =
                (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

            copy_tiles_and_advance(
                iterator_A,
                iterator_B,
                group_start_iteration_A,
                group_start_iteration_B);
          }

          // Inserts a memory fence between stages of cp.async instructions.
          cutlass::arch::cp_async_fence();

          // Waits until kStages-2 stages have committed.
          cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
          __syncthreads();

          // Move to the next stage
          iterator_A.add_tile_offset({0, 1});
          iterator_B.add_tile_offset({1, 0});

          this->smem_iterator_A_.add_tile_offset({0, 1});
          this->smem_iterator_B_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (!kSmemContainsEntireMat &&
              smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
            smem_read_stage_idx = 0;
          } else {
            ++smem_read_stage_idx;
          }

          --gemm_k_iterations;
          iterator_A.clear_mask(gemm_k_iterations == 0);
          iterator_B.clear_mask(gemm_k_iterations == 0);
        }

        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations)
          warp_mma.transform(
              warp_transformed_frag_A[(warp_mma_k + 1) % 2],
              warp_transformed_frag_B[(warp_mma_k + 1) % 2],
              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
              warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
      }
    }

    if (platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      accum = plus_accum(accum, tmp_accum);
    }

    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
      // Commit and drain all pending and predicated LDGSTS instructions from
      // the GEMM mainloop
      cutlass::arch::cp_async_fence();
      cutlass::arch::cp_async_wait<0>();
      __syncthreads();
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
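The mainloop above overlaps cp.async copies for the next stage with tensor-core math on the current one, fencing once per K-tile and waiting with cp_async_wait until the data it is about to read has landed. A self-contained toy CUDA sketch of the same double-buffered schedule, using the raw CUDA pipeline intrinsics rather than CUTLASS iterators (assumes CUDA 11+ and sm_80; kernel name, tile size, and the per-element "compute" are illustrative):

#include <cuda_pipeline.h>

constexpr int kTile = 128;

// Sums num_tiles tiles of `in` (num_tiles >= 1); launch with kTile threads.
__global__ void double_buffered_sum(const float* in, float* out,
                                    int num_tiles) {
  __shared__ float buf[2][kTile];
  int t = threadIdx.x;
  // Prologue: start the async copy of tile 0 (mirrors _prologue above).
  __pipeline_memcpy_async(&buf[0][t], &in[t], sizeof(float));
  __pipeline_commit();
  float acc = 0.f;
  for (int tile = 0; tile < num_tiles; ++tile) {
    int cur = tile % 2, nxt = (tile + 1) % 2;
    if (tile + 1 < num_tiles) { // issue the next stage before waiting
      __pipeline_memcpy_async(&buf[nxt][t], &in[(tile + 1) * kTile + t],
                              sizeof(float));
      __pipeline_commit();
    }
    // Wait until the current tile's copy has completed while the next
    // stage (if any) stays in flight - the cp_async_wait<N-1> pattern.
    __pipeline_wait_prior(tile + 1 < num_tiles ? 1 : 0);
    __syncthreads();
    acc += buf[cur][t]; // stand-in for the warp-level MMA on this stage
    __syncthreads();
  }
  out[t] = acc;
}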
@ -0,0 +1,401 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "custom_mma_base.h"
#include "cutlass/gemm/gemm.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_ = NumericArrayConverter<
        typename SmemIteratorA_::Element,
        typename IteratorA_::Element,
        IteratorA_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
 public:
  ///< Base class
  using Base = CustomMmaBase<Shape_, Policy_, 2>;

  using Shape =
      Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA =
      IteratorA_; ///< Iterates over tiles of A operand in global memory
  using IteratorB =
      IteratorB_; ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_; ///< Data type of accumulator matrix
  using LayoutC = LayoutC_; ///< Layout of accumulator matrix
  using Policy = Policy_; ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  using TransformA = TransformA_;
  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // Statically assert kStages for MmaPipelined is two (double-buffered
  // pipeline)
  static_assert(
      (Base::kStages == 2),
      "MmaPipelined requires kStages set to value 2");

  static bool const kSmemContainsEntireMat = false;

 private:
  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

 protected:
  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaPipelined(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      int thread_idx, ///< ID within the threadblock
      int warp_idx, ///< ID of warp
      int lane_idx ///< ID of each thread within a warp
      )
      : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storageA.ref(), thread_idx),
        smem_iterator_B_(shared_storageB.ref(), thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }
  CUTLASS_DEVICE
  CustomMmaPipelined(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage& st,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : CustomMmaPipelined(
            st.operand_A,
            st.operand_B,
            thread_idx,
            warp_idx,
            lane_idx) {}

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    // NOT IMPLEMENTED FOR PIPELINED
  }

  CUTLASS_DEVICE
  void set_zero_outside_bounds(bool value) {
    // NOT NEEDED FOR PIPELINED
    // shared memory will always be zero-filled
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorage& shared_storage,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    prologue<kLoadA, kLoadB>(
        shared_storage.operand_A,
        shared_storage.operand_B,
        iterator_A,
        iterator_B,
        thread_idx,
        problem_size_k);
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    // NOT IMPLEMENTED FOR PIPELINED
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
FragmentC const& src_accum, ///< source accumulator tile
|
||||
TransformA transform_A =
|
||||
TransformA(), ///< transformation applied to A fragment
|
||||
TransformB transform_B =
|
||||
TransformB()) { ///< transformation applied to B fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A(tb_frag_A));
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER*
|
||||
// issuing shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
||||
++warp_mma_k) {
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
} else {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0,
|
||||
-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(
|
||||
(warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_frag_A[warp_mma_k % 2],
|
||||
warp_frag_B[warp_mma_k % 2],
|
||||
accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
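The mainloop above is a classic two-stage software pipeline: the prologue stages tile 0 into shared memory, and each iteration overlaps the global-memory load of tile t+1 with the math on tile t. The toy CUDA kernel below (not part of this commit; `pipelined_sum` and `TILE` are invented for illustration) isolates that same double-buffering pattern with none of the iterator machinery:

#include <cuda_runtime.h>

// Each block reduces `num_tiles` chunks of TILE floats, staging the next
// chunk into one half of a two-stage smem buffer while summing the other.
template <int TILE>
__global__ void pipelined_sum(const float* in, float* out, int num_tiles) {
  __shared__ float buf[2][TILE]; // two stages, like kStages == 2 above
  int tid = threadIdx.x;

  // Prologue: stage tile 0 before the mainloop, as CustomMmaPipelined does.
  buf[0][tid] = in[tid];
  __syncthreads();

  float acc = 0.f;
  for (int t = 0; t < num_tiles; ++t) {
    int cur = t & 1, nxt = cur ^ 1;
    // Issue the global load for tile t+1 into the other stage...
    if (t + 1 < num_tiles)
      buf[nxt][tid] = in[(t + 1) * TILE + tid];
    // ...then do the math on the tile staged one iteration ago.
    acc += buf[cur][tid];
    __syncthreads(); // make stage `nxt` visible before it becomes `cur`
  }
  atomicAdd(out, acc);
}

// Example launch (host side): pipelined_sum<256><<<1, 256>>>(d_in, d_out, n);
// with d_in holding n * 256 floats and *d_out zero-initialized.

The real class adds what the toy omits: per-warp tile iterators, a circular smem buffer rewound with negative tile offsets, and predicate masks cleared near the end so the last loads never read out of bounds.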
264
examples/42_fused_multi_head_attention/gemm_kernel_utils.h
Normal file
@ -0,0 +1,264 @@
#pragma once

#include "cutlass/arch/mma.h"

////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func)                                          \
  {                                                                           \
    if (tensor.scalar_type() == at::ScalarType::Float) {                      \
      using scalar_t = float;                                                 \
      func();                                                                 \
    } else if (tensor.scalar_type() == at::ScalarType::Half) {                \
      using scalar_t = cutlass::half_t;                                       \
      func();                                                                 \
    } else if (tensor.scalar_type() == at::ScalarType::BFloat16) {            \
      using scalar_t = cutlass::bfloat16_t;                                   \
      func();                                                                 \
    } else {                                                                  \
      TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment");   \
    }                                                                         \
  }

#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
  {                                         \
    if (BOOL_V) {                           \
      constexpr bool BOOL_NAME = true;      \
      F();                                  \
    } else {                                \
      constexpr bool BOOL_NAME = false;     \
      F();                                  \
    }                                       \
  }
#define DISPATCH_ARCHTAG(CC, func)                                        \
  {                                                                       \
    if (CC >= 80) {                                                       \
      using ArchTag = cutlass::arch::Sm80;                                \
      func();                                                             \
    } else if (CC >= 75) {                                                \
      using ArchTag = cutlass::arch::Sm75;                                \
      func();                                                             \
    } else if (CC >= 70) {                                                \
      using ArchTag = cutlass::arch::Sm70;                                \
      func();                                                             \
    } else if (CC >= 50) {                                                \
      using ArchTag = cutlass::arch::Sm50;                                \
      func();                                                             \
    } else {                                                              \
      TORCH_CHECK(                                                        \
          false,                                                          \
          "Your device is too old. We require compute capability >= 50"); \
    }                                                                     \
  }

#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                            \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");        \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor");    \
  TORCH_CHECK(TENSOR.is_contiguous());

#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                        \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");        \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor");    \
  TORCH_CHECK(                                                            \
      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");

#ifdef HAS_PYTORCH
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define XFORMERS_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT)   \
  if (!(uint64_t(PTR) % ALIGNMENT == 0)) {  \
    return false;                           \
  }
#define XFORMERS_CHECK(COND, ERR) \
  if (!(COND)) {                  \
    return false;                 \
  }
#else
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT)             \
  if (!(uint64_t(PTR) % ALIGNMENT == 0)) {            \
    std::cerr << #PTR " is not correctly aligned\n";  \
    return false;                                     \
  }
#define XFORMERS_CHECK(COND, ERR)      \
  if (!(COND)) {                       \
    std::cerr << #COND " failed\n";    \
    return false;                      \
  }
#endif

#define ASSIGN_CHECK_OVERFLOW(A, B)                                    \
  {                                                                    \
    A = B;                                                             \
    TORCH_CHECK(                                                       \
        B < cutlass::platform::numeric_limits<decltype(A)>::max(),     \
        #B " overflows");                                              \
  }
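A sketch of how these dispatch macros compose (`launchKernel` is hypothetical, and a PyTorch build providing TORCH_CHECK is assumed; this is not code from the commit). Because the final macro argument is expanded textually inside each branch, the lambda body sees `ArchTag` as a type and `kIsAligned` as a constexpr bool:

template <typename ArchTag, bool kIsAligned>
void launchKernel(); // hypothetical, defined elsewhere

void dispatch_example(int computeCapability, bool isAligned) {
  DISPATCH_ARCHTAG(computeCapability, ([&] {
    DISPATCH_BOOL(isAligned, kIsAligned, ([&] {
      // Both names are compile-time entities here, so they can be
      // used directly as template arguments.
      launchKernel<ArchTag, kIsAligned>();
    }));
  }));
}

The extra parentheses around each lambda protect the commas in its body from macro-argument splitting.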

namespace gemm_kernel_utils {

#ifdef HAS_PYTORCH
template <typename scalar_t>
struct TypeTraits;

template <>
struct TypeTraits<cutlass::half_t> {
  using scalar_t = cutlass::half_t;

  static constexpr __host__ at::ScalarType atScalarType() {
    return at::ScalarType::Half;
  }
  template <int nDim>
  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
      at::Tensor const& tensor) {
    return at::PackedTensorAccessor32<scalar_t, nDim>(
        (scalar_t*)(tensor.data_ptr()),
        tensor.sizes().data(),
        tensor.strides().data());
  }
};

template <>
struct TypeTraits<cutlass::bfloat16_t> {
  using scalar_t = cutlass::bfloat16_t;

  static constexpr __host__ at::ScalarType atScalarType() {
    return at::ScalarType::BFloat16;
  }
  template <int nDim>
  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
      at::Tensor const& tensor) {
    return at::PackedTensorAccessor32<scalar_t, nDim>(
        (scalar_t*)(tensor.data_ptr()),
        tensor.sizes().data(),
        tensor.strides().data());
  }
};

template <>
struct TypeTraits<float> {
  using scalar_t = float;

  static constexpr __host__ at::ScalarType atScalarType() {
    return at::ScalarType::Float;
  }
  template <int nDim>
  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
      at::Tensor const& tensor) {
    return tensor.packed_accessor32<scalar_t, nDim>();
  }
};
#endif

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////

// Fallback to Simt (FMA on cuda cores) if not in a special case below
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
  static constexpr int ThreadK = 8;
  static constexpr int WarpK = 8;
  static constexpr int kMinimumAlignment = 1;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using OpClass = cutlass::arch::OpClassSimt;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 80>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAddFastF32;
};

// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
    ArchTag,
    scalar_t,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 75 &&
        cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 2;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Makes it possible to write
// `auto x = kCondition ? fa(arg) : fb(arg)`
// even when `fa` and `fb` have different return types
template <bool kVal, typename TA, typename TB>
struct call_conditional;

template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(ta(arg)) {
    return ta(arg);
  }
};

template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(tb(arg)) {
    return tb(arg);
  }
};
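A minimal host-side sketch, not from the commit, of what `call_conditional` buys you: only the selected callable is ever instantiated, so the two branches may have unrelated return types (in the kernel they are different iterator types). The lambdas here are invented:

#include <cstdio>

int call_conditional_demo() {
  auto fa = [](int x) -> int { return x + 1; };
  auto fb = [](int x) -> double { return x * 0.5; };
  // The first template argument is a compile-time bool; with `true` only
  // `fa` is called and `v` has fa's return type (int). With `false`, `v`
  // would be a double instead.
  auto v = gemm_kernel_utils::call_conditional<
      true, decltype(fa), decltype(fb)>::apply(fa, fb, 41);
  std::printf("%d\n", v); // prints 42
  return v;
}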

////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
  return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}

template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
  struct {
    union {
      T* ptr;
      uint32_t asInt[2];
    };
  } p;
  p.ptr = ptr;
  p.asInt[0] = warp_uniform(p.asInt[0]);
  p.asInt[1] = warp_uniform(p.asInt[1]);
  return p.ptr;
}
} // namespace gemm_kernel_utils
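A hedged usage sketch (the kernel below is illustrative, not from the commit): values like per-block base pointers are mathematically identical across all 32 lanes of a warp, but the compiler cannot prove it. Broadcasting from lane 0 makes the uniformity visible, which is exactly how `advance_to_block` treats its pointers and extents:

__global__ void row_copy(const float* base, int row_stride, float* out) {
  // `row` is already the same on every lane of the warp; warp_uniform()
  // lets the compiler rely on that and generate cheaper addressing.
  const float* row = base + blockIdx.x * row_stride;
  row = gemm_kernel_utils::warp_uniform(row);
  out[blockIdx.x * row_stride + threadIdx.x] = row[threadIdx.x];
}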
@ -0,0 +1,752 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Epilogue iterator that supports prefetching

    Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
*/

#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////

namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load and store output tile from global memory in
/// epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator |
/// ForwardTileIterator
///
template <
    typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
    typename Element_, ///< Element data type
    bool ScatterD = false, ///< Scatter D operand or not
    bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
 public:
  using ThreadMap = ThreadMap_;
  using Shape = typename ThreadMap::Shape;

  using Element = Element_;

  using Layout = layout::RowMajor;
  using TensorRef = TensorRef<Element, Layout>;
  using ConstTensorRef = typename TensorRef::ConstTensorRef;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  using TensorCoord = MatrixCoord;

  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
  static int const kThreads = ThreadMap::kThreads;
  static int const kIterations = ThreadMap::Count::kTile;

  static_assert(
      ThreadMap::Iterations::kRow > 0,
      "ThreadMap::Iterations::kRow must be > 0");
  static_assert(
      ThreadMap::Iterations::kGroup > 0,
      "ThreadMap::Iterations::kGroup must be > 0");
  static_assert(
      ThreadMap::Iterations::kCluster > 0,
      "ThreadMap::Iterations::kCluster must be > 0");
  static_assert(
      ThreadMap::Iterations::kColumn > 0,
      "ThreadMap::Iterations::kColumn must be > 0");

  /// Fragment object
  using Fragment = Array<
      Element,
      ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
          ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
          ThreadMap::kElementsPerAccess>;

  /// Memory access size
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;

  //
  // Parameters struct
  //

  /// Uses a non-template class
  struct Params : PredicatedTileIteratorParams {
    using Base = PredicatedTileIteratorParams;

    CUTLASS_HOST_DEVICE
    Params() {}

    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : PredicatedTileIteratorParams(
              layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
              make_OutputTileThreadMapDesc<ThreadMap>()) {}

    CUTLASS_HOST_DEVICE
    Params(Base const& base) : Base(base) {}
  };

  /// Mask object
  struct Mask {
    static int const kCount = ThreadMap::Iterations::kColumn;

    /// Predicate state
    bool predicates[kCount];

    //
    // Mask
    //
    CUTLASS_HOST_DEVICE
    Mask() {
      enable();
    }

    ///< Efficiently disables all accesses guarded by mask
    CUTLASS_HOST_DEVICE void clear() {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kCount; ++i) {
        predicates[i] = false;
      }
    }

    ///< Efficiently enables all accesses guarded by mask
    CUTLASS_DEVICE void enable() {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kCount; ++i) {
        predicates[i] = true;
      }
    }
  };

 private:
  //
  // Data members
  //

  /// Parameters structure containing reference and precomputed state.
  PredicatedTileIteratorParams params_;

  /// Byte-level pointer
  uint8_t* byte_pointer_;

  /// Array of boolean values to contain steady-state predicates
  Mask mask_;

  /// Extent of the matrix tile in rows
  Index extent_row_;

  /// Extent of the matrix tile in columns
  Index extent_column_;

  /// A thread's starting row position (assuming steady-state predicates have
  /// been computed)
  Index thread_start_row_;

  /// A thread's starting column
  Index thread_start_column_;

  /// Internal state counter
  int state_[3];

  /// Scatter indices
  int const* indices_;

  //
  // Static asserts about internal strides
  //

  static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
  static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
  static_assert(
      sizeof(PredicatedTileIteratorParams::stride) == 8,
      "Expected 64b strides");

 private:
  //
  // Methods
  //

 public:
  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE
  PredicatedTileIteratorPrefetch(
      PredicatedTileIteratorParams const& params,
      Element* pointer,
      TensorCoord extent,
      int thread_idx,
      TensorCoord threadblock_offset = TensorCoord(),
      int const* indices = nullptr)
      : params_(params), indices_(indices) {
    TensorCoord thread_offset =
        ThreadMap::initial_offset(thread_idx) + threadblock_offset;

    extent_row_ = extent.row();
    extent_column_ = extent.column();

    thread_start_row_ = thread_offset.row();
    thread_start_column_ = thread_offset.column();

    // Initialize predicates
    CUTLASS_PRAGMA_UNROLL
    for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
      mask_.predicates[c] =
          ((thread_offset.column() + ThreadMap::Delta::kColumn * c) <
           extent.column());
    }

    // Null pointer performs no accesses
    if (!pointer) {
      mask_.clear();
    }

    if (ScatterD && !indices) {
      mask_.clear();
    }

    // Initialize pointer
    byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
        LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
        LongIndex(thread_offset.column()) * sizeof(AccessType) /
            kElementsPerAccess;

    if (ScatterD) {
      byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
          LongIndex(thread_offset.column()) * sizeof(AccessType) /
              kElementsPerAccess;
    }

    // Initialize internal state counter
    state_[0] = state_[1] = state_[2] = 0;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_DEVICE
  void prefetch_all() {
    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < kIterations; ++iter) {
      prefetch();
      ++(*this);
    }
  }

  CUTLASS_DEVICE
  void prefetch() {
    uint8_t* byte_pointer = byte_pointer_;

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            // on windows using unsigned long here gives the error
            //   error: asm operand type size(4) does not match
            //   type/size implied by constraint 'l'
            uint64_t addr = (uint64_t)(
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess]);
            asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          if (ScatterD && row_guard) {
            assert(indices_);

            memory_pointer = reinterpret_cast<AccessType*>(
                byte_pointer + byte_offset +
                LongIndex(indices_[row_offset + thread_start_row_]) *
                    LongIndex(params_.stride));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) const {
    load_with_byte_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          if (ScatterD && row_guard) {
            assert(indices_);

            memory_pointer = reinterpret_cast<AccessType*>(
                byte_pointer + byte_offset +
                LongIndex(indices_[row_offset + thread_start_row_]) *
                    LongIndex(params_.stride));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            if (UseCUDAStore) {
              if (guard) {
                memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
                        frag_ptr
                            [frag_row_idx * ThreadMap::Iterations::kColumn +
                             column];
              }
            } else {
              cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
                  frag_ptr
                      [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                  (void*)&memory_pointer
                      [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                  guard);
            }
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void downsample_load_with_byte_offset(
      Fragment& frag,
      int64_t byte_offset,
      int convolution_P,
      int convolution_Q,
      int add_P,
      int add_Q,
      int problem_N) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          int output_row = row_offset + thread_start_row_;
          int output_N = output_row / (convolution_P * convolution_Q);
          int output_PQ = output_row % (convolution_P * convolution_Q);
          int output_P = output_PQ / convolution_Q;
          int output_Q = output_PQ % convolution_Q;

          int input_row = output_N * 2 * convolution_P * 2 * convolution_Q +
              (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q +
              add_Q;

          // Note: this local shadows the `byte_offset` parameter; the
          // remapped row offset replaces the caller-provided one.
          int64_t byte_offset =
              (input_row - output_row) * problem_N * sizeof(float);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            byte_pointer += params_.increment_row;
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void upsample_load_with_byte_offset(
      Fragment& frag,
      int64_t byte_offset,
      int convolution_P,
      int convolution_Q,
      int add_P,
      int add_Q,
      int problem_N) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          int output_row = row_offset + thread_start_row_;
          int output_N = output_row / (convolution_P * convolution_Q);
          int output_PQ = output_row % (convolution_P * convolution_Q);
          int output_P = output_PQ / convolution_Q;
          int output_Q = output_PQ % convolution_Q;
          int row_add_P = add_P;
          int row_add_Q = add_Q;
          if (output_P > convolution_P - 2)
            row_add_P = 0;
          if (output_Q > convolution_Q - 2)
            row_add_Q = 0;

          int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) +
              ((output_P + row_add_P) / 2) * (convolution_Q / 2) +
              (output_Q + row_add_Q) / 2;

          // Note: this local shadows the `byte_offset` parameter; the
          // remapped row offset replaces the caller-provided one.
          int64_t byte_offset =
              (input_row - output_row) * problem_N * sizeof(float);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            byte_pointer += params_.increment_row;
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  CUTLASS_DEVICE
  MatrixCoord thread_start() const {
    return MatrixCoord(thread_start_row_, thread_start_column_);
  }

  /// Need to get the thread start row from the tile iterator
  CUTLASS_DEVICE
  int32_t thread_start_row() const {
    return thread_start_row_;
  }

  /// Need to get the thread start column from the tile iterator
  CUTLASS_DEVICE
  int32_t thread_start_column() const {
    return thread_start_column_;
  }

  /// Extent of the matrix in rows
  CUTLASS_DEVICE
  Index extent_row() const {
    return extent_row_;
  }

  /// Extent of the matrix in columns
  CUTLASS_DEVICE
  Index extent_column() const {
    return extent_column_;
  }

  /// Advances to the next position to load or store
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorPrefetch& operator++() {
    ++state_[0];

    if (!ScatterD) {
      byte_pointer_ += params_.advance_row;
    }

    thread_start_row_ += ThreadMap::Shape::kRow;

    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      byte_pointer_ += params_.advance_group;

      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
          ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        byte_pointer_ += params_.advance_cluster;

        thread_start_row_ += ThreadMap::Count::kGroup *
            ThreadMap::Shape::kGroup * ThreadMap::Count::kRow *
            ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
          byte_pointer_ += params_.advance_tile;
        }
      }
    }

    return *this;
  }

  ///< Efficiently disables all accesses guarded by mask
  CUTLASS_DEVICE void clear_mask() {
    mask_.clear();
  }

  ///< Efficiently enables all accesses guarded by mask
  CUTLASS_DEVICE void enable_mask() {
    mask_.enable();
  }

  ///< Gets the mask
  CUTLASS_DEVICE void get_mask(Mask& mask) const {
    mask = mask_;
  }

  ///< Sets the mask
  CUTLASS_DEVICE void set_mask(Mask const& mask) {
    mask_ = mask;
  }
};

template <typename IT>
struct MakePrefetchableIterator {
  using Iterator = PredicatedTileIteratorPrefetch<
      typename IT::ThreadMap,
      typename IT::Element>;
};

///////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
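A usage sketch under stated assumptions (the wrapper function and its `BaseOutputIterator` parameter are hypothetical, not from the commit). Because `prefetch_all()` walks and advances the iterator state, the L1 warm-up is done on a dedicated instance that is then discarded, before the real load/store traversal begins:

template <typename BaseOutputIterator>
CUTLASS_DEVICE void warm_l1_for_epilogue(
    typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
        BaseOutputIterator>::Iterator::Params const& params,
    typename BaseOutputIterator::Element* ptr,
    cutlass::MatrixCoord extent,
    int thread_idx) {
  using Prefetcher = typename cutlass::epilogue::threadblock::
      MakePrefetchableIterator<BaseOutputIterator>::Iterator;
  Prefetcher it(params, ptr, extent, thread_idx);
  it.prefetch_all(); // issues `prefetch.global.L1` for every future access
}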
@ -0,0 +1,66 @@
#pragma once

#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"

namespace cutlass {
namespace transform {
namespace threadblock {

template <typename BaseIterator>
struct MakeIteratorResidualLast;

template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    int AccessSize,
    bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<
    Shape,
    Element,
    Layout,
    AdvanceRank,
    ThreadMap,
    AccessSize,
    Gather>> {
  using Iterator = PredicatedTileIteratorResidualLast<
      Shape,
      Element,
      Layout,
      AdvanceRank,
      ThreadMap,
      AccessSize,
      Gather>;
};

template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    typename AccessType,
    bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<
    Shape,
    Element,
    Layout,
    AdvanceRank,
    ThreadMap,
    AccessType,
    Gather>> {
  using Iterator = PredicatedTileAccessIteratorResidualLast<
      Shape,
      Element,
      Layout,
      AdvanceRank,
      ThreadMap,
      AccessType,
      Gather>;
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass
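These specializations are compile-time type functions: given an existing predicated iterator type, they name the "residual-last" variant that carries the same template arguments and handles the partial, masked k-block on the final iteration rather than the first. A hedged one-line application, with `SomeIterator` standing in for any matching iterator type:

template <typename SomeIterator>
using ResidualLastOf = typename cutlass::transform::threadblock::
    MakeIteratorResidualLast<SomeIterator>::Iterator;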
File diff suppressed because it is too large
File diff suppressed because it is too large
916
examples/42_fused_multi_head_attention/kernel_forward.h
Normal file
@ -0,0 +1,916 @@
#ifdef HAS_PYTORCH
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#endif

#include <cmath>
#include <vector>

#include "cutlass/bfloat16.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"

#include "attention_scaling_coefs_updater.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "debug_utils.h"
#include "epilogue_pipelined.h"
#include "epilogue_rescale_output.h"
#include "find_default_mma.h"
#include "gemm_kernel_utils.h"
#include "mma_from_smem.h"

#include <inttypes.h>

using namespace gemm_kernel_utils;

namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
  return (
      Arch::kMinComputeCapability >= 80 &&
              !cutlass::platform::is_same<scalar_t, float>::value
          ? 16
          : 12);
}
} // namespace
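getWarpsPerSm feeds the launch-bounds hint of the kernel below. A worked example under assumed tile sizes (64x64 with half precision on Sm80; the numbers are purely illustrative):

constexpr int kQueriesPerBlockEx = 64;
constexpr int kKeysPerBlockEx = 64;
// One warp computes a 32x32 tile of the QK^T block, as in AttentionKernel:
constexpr int kNumWarpsEx =
    kQueriesPerBlockEx * kKeysPerBlockEx / (32 * 32);  // = 4 warps
constexpr int kNumThreadsEx = 32 * kNumWarpsEx;        // = 128 threads
// getWarpsPerSm<half_t, Sm80>() == 16, so the occupancy hint becomes:
constexpr int kMinBlocksPerSmEx = 16 / kNumWarpsEx;    // = 4 blocks per SM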

template <
    // The datatype of Q/K/V
    typename scalar_t_,
    // Architecture we are targeting (eg `cutlass::arch::Sm80`)
    typename ArchTag,
    // If Q/K/V are correctly aligned in memory and we can run a fast kernel
    bool isAligned_,
    int kQueriesPerBlock,
    int kKeysPerBlock,
    bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
    >
struct AttentionKernel {
  using scalar_t = scalar_t_;
  using accum_t = float;
  using lse_scalar_t = float;
  using output_t = scalar_t;
  // Accumulator between 2 iterations
  // Using `accum_t` improves perf on f16 at the cost of
  // numerical errors
  using output_accum_t = accum_t;
  static constexpr bool kIsAligned = isAligned_;
  static constexpr int32_t kAlignLSE = 32; // block size of backward
  static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
      cutlass::sizeof_bits<scalar_t>::value == 16;
  static constexpr bool kKeepOutputInRF = kSingleValueIteration;
  static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
      !cutlass::platform::is_same<output_accum_t, output_t>::value;

  static_assert(kQueriesPerBlock % 32 == 0, "");
  static_assert(kKeysPerBlock % 32 == 0, "");
  static constexpr int kNumWarpsPerBlock =
      kQueriesPerBlock * kKeysPerBlock / (32 * 32);
  static constexpr int kWarpSize = 32;

  // Launch bounds
  static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
  static constexpr int kMinBlocksPerSm =
      getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;

  struct Params {
    // Input tensors
    scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
    scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
    scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
    int32_t* cu_seqlens_q_ptr = nullptr;
    int32_t* cu_seqlens_k_ptr = nullptr;

    // Output tensors
    output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
    output_accum_t*
        output_accum_ptr; // [num_queries, num_heads, head_dim_value]
    lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null

    // Dimensions/strides
    int32_t head_dim;
    int32_t head_dim_value;
    int32_t num_queries;
    int32_t num_keys;

    bool causal;

    int32_t q_strideM;
    int32_t k_strideM;
    int32_t v_strideM;

    // Everything below is only used in `advance_to_block`
    // and shouldn't use registers
    int32_t q_strideH;
    int32_t k_strideH;
    int32_t v_strideH;
    int32_t o_strideH;
    int64_t q_strideB;
    int64_t k_strideB;
    int64_t v_strideB;
    int64_t o_strideB;
    int32_t num_batches;
    int32_t num_heads;

    CUTLASS_HOST_DEVICE int32_t o_strideM() const {
      return head_dim_value;
    }
    // Moves pointers to what we should process
    // Returns "false" if there is no work to do
    CUTLASS_DEVICE bool advance_to_block() {
      auto batch_id = blockIdx.z;
      auto head_id = blockIdx.y;
      auto query_start = blockIdx.x * kQueriesPerBlock;

      auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;

      int64_t q_start, k_start;
      // Advance to current batch - in case of different sequence lengths
      if (cu_seqlens_q_ptr != nullptr) {
        assert(cu_seqlens_k_ptr != nullptr);
        cu_seqlens_q_ptr += batch_id;
        cu_seqlens_k_ptr += batch_id;
        q_start = cu_seqlens_q_ptr[0];
        k_start = cu_seqlens_k_ptr[0];
        int64_t q_next_start = cu_seqlens_q_ptr[1];
        int64_t k_next_start = cu_seqlens_k_ptr[1];
        num_queries = q_next_start - q_start;
        num_keys = k_next_start - k_start;

        if (query_start >= num_queries) {
          return false;
        }
      } else {
        query_ptr += batch_id * q_strideB;
        key_ptr += batch_id * k_strideB;
        value_ptr += batch_id * v_strideB;
        output_ptr += batch_id * o_strideB;
        if (output_accum_ptr != nullptr) {
          output_accum_ptr += batch_id * o_strideB;
        }
        q_start = 0;
        k_start = 0;
      }

      // Advance to the current batch / head / query_start
      query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
      key_ptr += k_start * k_strideM + head_id * k_strideH;
      value_ptr += k_start * v_strideM + head_id * v_strideH;
      output_ptr += int64_t(q_start + query_start) * o_strideM() +
          head_id * o_strideH;

      if (output_accum_ptr != nullptr) {
        output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
            head_id * o_strideH;
      } else {
        // Accumulate directly in the destination buffer (eg for f32)
        output_accum_ptr = (accum_t*)output_ptr;
      }
      if (logsumexp_ptr != nullptr) {
        // lse[batch_id, head_id, query_start]
        logsumexp_ptr +=
            batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
      }

      num_queries -= query_start;
      if (causal) {
        num_keys = cutlass::fast_min(
            int32_t(query_start + kQueriesPerBlock), num_keys);
      }
      num_batches = 0; // no longer used after

      // Make sure the compiler knows these variables are the same on all
      // the threads of the warp.
      query_ptr = warp_uniform(query_ptr);
      key_ptr = warp_uniform(key_ptr);
      value_ptr = warp_uniform(value_ptr);
      output_ptr = warp_uniform(output_ptr);
      output_accum_ptr = warp_uniform(output_accum_ptr);
      logsumexp_ptr = warp_uniform(logsumexp_ptr);
      num_queries = warp_uniform(num_queries);
      num_keys = warp_uniform(num_keys);
      head_dim = warp_uniform(head_dim);
      head_dim_value = warp_uniform(head_dim_value);
      return true;
    }

    __host__ dim3 getBlocksGrid() const {
      return dim3(
          ceil_div(num_queries, (int32_t)kQueriesPerBlock),
          num_heads,
          num_batches);
    }
    __host__ dim3 getThreadsGrid() const {
      return dim3(kWarpSize, kNumWarpsPerBlock, 1);
    }
  };
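A host-side launch sketch using the helpers above. The wrapper kernel is hypothetical (the commit's real launcher lives in fused_multihead_attention.cu); it only illustrates how `Params`, the grid helpers, and the launch bounds fit together:

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_sketch(typename AK::Params p) {
  if (!p.advance_to_block()) {
    return; // this block has no work (e.g. past this batch's seqlen)
  }
  // ... MM0 (Q @ K.T + softmax bookkeeping) and MM1 (attn @ V) go here ...
}

template <typename AK>
void launch_sketch(typename AK::Params& p, int smem_bytes, cudaStream_t s) {
  attention_kernel_sketch<AK>
      <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, s>>>(p);
}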

  struct MM0 {
    /*
      In this first matmul, we compute a block of `Q @ K.T`.
      While the calculation result is still hot in registers, we update
      `mi`, `m_prime` and `s_prime` in shared memory, and then store the block
      into shared memory ("AccumulatorSharedStorage"), where it is later used
      as operand A for the second matmul (see MM1).
    */
    using GemmType = DefaultGemmType<ArchTag, scalar_t>;

    using OpClass = typename GemmType::OpClass;
    using DefaultConfig =
        typename cutlass::gemm::device::DefaultGemmConfiguration<
            OpClass,
            ArchTag,
            scalar_t,
            scalar_t,
            scalar_t, // ElementC
            accum_t // ElementAccumulator
            >;
    static constexpr int kAlignmentA =
        kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
    static constexpr int kAlignmentB =
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
    using ThreadblockShape = cutlass::gemm::
        GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        kAlignmentA,
        scalar_t, // ElementB,
        cutlass::layout::ColumnMajor, // LayoutB,
        kAlignmentB,
        accum_t,
        cutlass::layout::RowMajor, // LayoutC,
        OpClass,
        ArchTag, // ArchTag
        ThreadblockShape, // ThreadblockShape
        WarpShape, // WarpShape
        typename GemmType::InstructionShape, // InstructionShape
        DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
                                // uses too much smem
        typename GemmType::Operator // Operator
        >::DefaultMma;
    using MmaCore = typename DefaultMma::MmaCore;
    using IteratorA = typename DefaultMma::IteratorA;
    using IteratorB = typename DefaultMma::IteratorB;
    using Mma = typename DefaultMma::ThreadblockMma;
    using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
        typename Mma::Operator::IteratorC,
        accum_t,
        kWarpSize>::Updater;
    static_assert(
        MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
                MmaCore::WarpCount::kK ==
            kNumWarpsPerBlock,
        "");

    // Epilogue to store to shared-memory in a format that we can use later for
    // the second matmul
    using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
        typename Mma::Operator::IteratorC,
        typename Mma::Operator,
        scalar_t,
        WarpShape,
        ThreadblockShape>;
    using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  };
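The streaming-softmax bookkeeping MM0's comment describes, reduced to scalar C++ for clarity (a sketch, not the actual ScalingCoefsUpdater): after each key block, the row max `mi` may grow, so the running exponential sum must be rescaled before the new block's terms are added. This keeps every `exp()` argument non-positive and numerically safe:

#include <cmath>

// Per attention row: fold one block of n raw scores into the running
// (max, sum) pair and write the shifted exponentials back for MM1.
void update_row(float* mi, float* s_prime, const float* scores, int n,
                float* attn_out /* receives exp(score - mi) */) {
  float m_old = *mi;
  float m_new = m_old;
  for (int i = 0; i < n; ++i)
    m_new = fmaxf(m_new, scores[i]);
  // Rescale the sum that was accumulated with the previous max...
  *s_prime *= expf(m_old - m_new);
  // ...then accumulate this block's exponentials, shifted by the new max.
  for (int i = 0; i < n; ++i) {
    attn_out[i] = expf(scores[i] - m_new);
    *s_prime += attn_out[i];
  }
  *mi = m_new;
}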

  struct MM1 {
    /**
      Second matmul: perform `attn @ V` where `attn` is the attention (not
      normalized) and stored in shared memory
    */
    using GemmType = DefaultGemmType<ArchTag, scalar_t>;

    using OpClass = typename GemmType::OpClass;
    using DefaultConfig =
        typename cutlass::gemm::device::DefaultGemmConfiguration<
            OpClass,
            ArchTag,
            scalar_t,
            scalar_t,
            output_accum_t, // ElementC
            accum_t // ElementAccumulator
            >;
    static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
    static constexpr int kAlignmentB =
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
    using ThreadblockShape = cutlass::gemm::
        GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using InstructionShape = typename GemmType::InstructionShape;

    using LayoutB = cutlass::layout::RowMajor;
    using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        kAlignmentA,
        scalar_t, // ElementB,
        LayoutB, // LayoutB,
        kAlignmentB,
        output_accum_t,
        cutlass::layout::RowMajor, // LayoutC,
        accum_t,
        OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        typename DefaultConfig::EpilogueOutputOp,
        void, // ThreadblockSwizzle - not used
        DefaultConfig::kStages,
        false, // SplitKSerial
        typename GemmType::Operator>;

    using DefaultMmaFromSmem =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            typename MM0::AccumulatorSharedStorage>;
    using Mma = typename DefaultMmaFromSmem::Mma;
    using IteratorB = typename Mma::IteratorB;
    using WarpCount = typename Mma::WarpCount;
    static_assert(
        WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
        "");

    using DefaultEpilogue = typename DefaultGemm::Epilogue;
    using OutputTileIterator =
        typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_t>;
    using OutputTileIteratorAccum =
        typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_accum_t>;

    struct SharedStorageMM1 {
      typename Mma::SharedStorage mm;
    };
  };
|
||||
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
|
||||
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
|
||||
static constexpr int64_t kAlignmentV = 1;
|
||||
|
||||
// Shared storage - depends on kernel params
|
||||
struct ScalingCoefs {
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> mi;
|
||||
};
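
// Per query row of the current block (each array of length kQueriesPerBlock),
// as used by `ScalingCoefsUpdater::update` (see the sketch in MM0):
//   mi      - running row maximum of `Q @ K.T * scale` over the keys
//             processed so far
//   m_prime - the maximum the current partial results were last scaled with
//   s_prime - running softmax denominator, relative to that maximum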

struct SharedStorageEpilogueAtEnd : ScalingCoefs {
  struct SharedStorageAfterMM0 {
    // Everything here might be overwritten during MM0
    typename MM0::AccumulatorSharedStorage si;
    typename MM1::SharedStorageMM1 mm1;
  };

  union {
    typename MM0::Mma::SharedStorage mm0;
    SharedStorageAfterMM0 after_mm0;
    typename MM1::DefaultEpilogue::SharedStorage epilogue;
  };

  CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
  epilogue_shared_storage() {
    return epilogue;
  }
};

struct SharedStorageEpilogueInLoop : ScalingCoefs {
  struct SharedStorageAfterMM0 {
    // Everything here might be overwritten during MM0
    typename MM0::AccumulatorSharedStorage si;
    typename MM1::SharedStorageMM1 mm1;
    typename MM1::DefaultEpilogue::SharedStorage epilogue;
  };

  union {
    typename MM0::Mma::SharedStorage mm0;
    SharedStorageAfterMM0 after_mm0;
  };

  CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
  epilogue_shared_storage() {
    return after_mm0.epilogue;
  }
};

using SharedStorage = typename cutlass::platform::conditional<
    kSingleValueIteration || kKeepOutputInRF,
    SharedStorageEpilogueAtEnd,
    SharedStorageEpilogueInLoop>::type;
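
// Why two layouts: when the epilogue only runs once, after the last MM1
// iteration (kSingleValueIteration / kKeepOutputInRF), its shared memory can
// alias the MM0/MM1 storage in the union above, which keeps the smem
// footprint down. When the epilogue runs inside the key loop, it must
// coexist with `si` / `mm1`, so it lives inside `SharedStorageAfterMM0`
// instead.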

static bool __host__ check_supported(Params const& p) {
  CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
  CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
  CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
  XFORMERS_CHECK(
      p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
  XFORMERS_CHECK(
      p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
  XFORMERS_CHECK(
      p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
  XFORMERS_CHECK(
      p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
  XFORMERS_CHECK(
      p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
  XFORMERS_CHECK(
      p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
  return true;
}
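
// Intended host-side usage (a sketch - the example's runner performs the
// equivalent validation before launching):
//
//   using Kernel = AttentionKernel<...>; // hypothetical instantiation
//   if (!Kernel::check_supported(params)) {
//     // fall back to an unaligned/slower variant, or report the error
//   }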

static void CUTLASS_DEVICE attention_kernel(Params& p) {
  // In this block, we will only ever:
  // - read query[query_start:query_end, :]
  // - write to output[query_start:query_end, :]

  extern __shared__ char smem_buffer[];
  SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
  auto& m_prime = shared_storage.m_prime;
  auto& s_prime = shared_storage.s_prime;
  auto& si = shared_storage.after_mm0.si;
  auto& mi = shared_storage.mi;

  static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
  if (thread_id() < kQueriesPerBlock) {
    s_prime[thread_id()] = accum_t(0);
    m_prime[thread_id()] =
        -cutlass::platform::numeric_limits<accum_t>::infinity();
    mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
  }
  typename MM1::Mma::FragmentC accum_o;
  accum_o.clear();

  auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
    using OutputTileIterator = typename MM1::OutputTileIterator;
    return OutputTileIterator(
        typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
        p.output_ptr,
        typename OutputTileIterator::TensorCoord{
            p.num_queries, p.head_dim_value},
        thread_id(),
        {0, col});
  };

  auto createOutputAccumIter = [&](int col) ->
      typename MM1::OutputTileIteratorAccum {
    using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
    return OutputTileIteratorAccum(
        typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
        p.output_accum_ptr,
        typename OutputTileIteratorAccum::TensorCoord{
            p.num_queries, p.head_dim_value},
        thread_id(),
        {0, col});
  };

  // Iterate through keys
  for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
       iter_key_start += kKeysPerBlock) {
    int32_t problem_size_0_m =
        cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
    int32_t problem_size_0_n = cutlass::fast_min(
        int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
    int32_t const& problem_size_0_k = p.head_dim;
    int32_t const& problem_size_1_n = p.head_dim_value;
    int32_t const& problem_size_1_k = problem_size_0_n;
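
    // Problem shapes per iteration, in (M, N, K) GEMM convention:
    //   MM0 (Q @ K.T):  M = queries in this block, N = keys in this block,
    //                   K = head_dim
    //   MM1 (attn @ V): M = queries in this block, N = head_dim_value,
    //                   K = keys in this block (= problem_size_0_n)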

    auto prologueV = [&](int blockN) {
      typename MM1::Mma::IteratorB iterator_V(
          typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
          p.value_ptr + iter_key_start * p.v_strideM,
          {problem_size_1_k, problem_size_1_n},
          thread_id(),
          cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
      MM1::Mma::prologue(
          shared_storage.after_mm0.mm1.mm,
          iterator_V,
          thread_id(),
          problem_size_1_k);
    };
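
    // When kPreloadV is set, `prologueV` is invoked early (right after MM0's
    // mainloop, and one tile ahead inside the MM1 loop below) so the
    // global-memory loads of V overlap with the softmax update and the
    // attn @ V compute instead of serializing behind them.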

    __syncthreads(); // Need to have shared memory initialized, and `m_prime`
                     // updated from end of prev iter

    //
    // MATMUL: Q.K_t
    //
    // Computes the block-matrix product of:
    // (a) query[query_start:query_end, :]
    // with
    // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
    // and stores that into `si` (`shared_storage.after_mm0.si`)
    //

    // Compute threadblock location
    cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};

    cutlass::MatrixCoord tb_offset_A{
        tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};

    cutlass::MatrixCoord tb_offset_B{
        tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};

    // Construct iterators to A and B operands
    typename MM0::IteratorA iterator_A(
        typename MM0::IteratorA::Params(
            typename MM0::MmaCore::LayoutA(p.q_strideM)),
        p.query_ptr,
        {problem_size_0_m, problem_size_0_k},
        thread_id(),
        tb_offset_A);

    typename MM0::IteratorB iterator_B(
        typename MM0::IteratorB::Params(
            typename MM0::MmaCore::LayoutB(p.k_strideM)),
        p.key_ptr + iter_key_start * p.k_strideM,
        {problem_size_0_k, problem_size_0_n},
        thread_id(),
        tb_offset_B);

    auto my_warp_id = warp_id();
    auto my_lane_id = lane_id();

    // Construct thread-scoped matrix multiply
    typename MM0::Mma mma(
        shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);

    typename MM0::Mma::FragmentC accum;

    accum.clear();

    auto gemm_k_iterations =
        (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
    __syncthreads();

    if (kPreloadV) {
      prologueV(0);
    }

    typename MM0::Mma::Operator::IteratorC::TensorCoord
        iteratorC_tile_offset = {
            (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
                (my_warp_id % MM0::Mma::WarpCount::kM),
            (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
                (my_warp_id / MM0::Mma::WarpCount::kM)};

    // Mask out last if causal
    if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
      auto query_start = blockIdx.x * kQueriesPerBlock;
      auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
          lane_id(), warp_id(), iteratorC_tile_offset);
      int32_t last_col;
      MM0::ScalingCoefsUpdater::iterateRows(
          lane_offset,
          [&](int accum_m) {
            last_col = query_start + accum_m - iter_key_start;
          },
          [&](int accum_m, int accum_n, int idx) {
            if (accum_n > last_col) {
              accum[idx] =
                  -cutlass::platform::numeric_limits<accum_t>::infinity();
            }
          },
          [&](int accum_m) {});
    }
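
    // Worked example of the causal mask above: with kQueriesPerBlock = 32,
    // blockIdx.x = 1 and iter_key_start = 32 (hypothetical numbers), the
    // tile row `accum_m = 5` corresponds to absolute query 37, so
    // `last_col = 32 + 5 - 32 = 5`: within this key block, columns with
    // `accum_n > 5` (absolute keys > 37) are set to -inf and vanish after
    // the exp() in the softmax update.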
    DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
                    DISPATCH_BOOL(
                        p.num_keys - iter_key_start >= kKeysPerBlock,
                        kFullColumns,
                        ([&] {
                          // Update `mi` from accum stored in registers
                          // Also updates `accum` with
                          //   accum[i] <- exp(accum[i] * scale - mi)
                          MM0::ScalingCoefsUpdater::update<
                              kQueriesPerBlock,
                              kFullColumns,
                              kIsFirst,
                              kKeepOutputInRF>(
                              accum_o,
                              accum,
                              mi,
                              m_prime,
                              s_prime,
                              lane_id(),
                              thread_id(),
                              warp_id(),
                              p.num_keys - iter_key_start,
                              iteratorC_tile_offset,
                              1.0f / cutlass::fast_sqrt(float(p.head_dim)));
                        }));
                  }));
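
    // `DISPATCH_BOOL` lifts the two runtime conditions (first key block?
    // full columns?) into compile-time template parameters of `update`, so
    // each variant compiles without the corresponding branches. The last
    // argument is the softmax scaling 1/sqrt(head_dim), i.e. the update
    // computes rows of softmax(Q @ K.T / sqrt(head_dim)) incrementally.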

    // Output results to shared-memory
    int warp_idx_mn_0 = my_warp_id %
        (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
    auto output_tile_coords = cutlass::MatrixCoord{
        warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
        warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

    MM0::B2bGemm::accumToSmem(
        shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);

    __syncthreads();

    //
    // MATMUL: Attn . V
    // Run the matmul `attn @ V` for a block of attn and V.
    // `attn` is read from shared memory (`si`, i.e.
    // `shared_storage.after_mm0.si`)
    // `V` is read from global memory (with `iterator_V`)
    //

    const int64_t nBlockN = kSingleValueIteration
        ? 1
        : ceil_div(
              (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
    for (int blockN = 0; blockN < nBlockN; ++blockN) {
      int gemm_k_iterations =
          (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add and store it in accum
      // (in registers)
      if (!kPreloadV) {
        __syncthreads(); // we share shmem between mma and epilogue
      }

      typename MM1::Mma::IteratorB iterator_V(
          typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
          p.value_ptr + iter_key_start * p.v_strideM,
          {problem_size_1_k, problem_size_1_n},
          thread_id(),
          cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
      typename MM1::Mma mma_pv(
          shared_storage.after_mm0.mm1.mm,
          shared_storage.after_mm0.si,
          (int)thread_id(),
          (int)warp_id(),
          (int)lane_id(),
          (int)problem_size_1_k);
      mma_pv.set_prologue_done(kPreloadV);
      if (!kKeepOutputInRF) {
        accum_o.clear();
      }
      mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
      __syncthreads();

      if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
        prologueV(blockN + 1);
      }

      if (!kKeepOutputInRF) {
        DISPATCH_BOOL(
            iter_key_start == 0, kIsFirst, ([&] {
              DISPATCH_BOOL(
                  (iter_key_start + kKeysPerBlock) >= p.num_keys,
                  kIsLast,
                  ([&] {
                    using DefaultEpilogue = typename MM1::DefaultEpilogue;
                    using DefaultOp =
                        typename MM1::DefaultConfig::EpilogueOutputOp;
                    using ElementCompute = typename DefaultOp::ElementCompute;
                    using EpilogueOutputOp = typename cutlass::epilogue::
                        thread::MemoryEfficientAttentionNormalize<
                            typename cutlass::platform::conditional<
                                kIsLast,
                                output_t,
                                output_accum_t>::type,
                            output_accum_t,
                            DefaultOp::kCount,
                            typename DefaultOp::ElementAccumulator,
                            ElementCompute,
                            kIsFirst,
                            kIsLast,
                            cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                    using Epilogue = typename cutlass::epilogue::threadblock::
                        EpiloguePipelined<
                            typename DefaultEpilogue::Shape,
                            typename MM1::Mma::Operator,
                            DefaultEpilogue::kPartitionsK,
                            typename cutlass::platform::conditional<
                                kIsLast,
                                typename MM1::OutputTileIterator,
                                typename MM1::OutputTileIteratorAccum>::type,
                            typename DefaultEpilogue::
                                AccumulatorFragmentIterator,
                            typename DefaultEpilogue::WarpTileIterator,
                            typename DefaultEpilogue::SharedLoadIterator,
                            EpilogueOutputOp,
                            typename DefaultEpilogue::Padding,
                            DefaultEpilogue::kFragmentsPerIteration,
                            true, // IterationsUnroll
                            typename MM1::OutputTileIteratorAccum // Read
                                                                  // iterator
                            >;

                    int col = blockN * MM1::Mma::Shape::kN;
                    auto source_iter = createOutputAccumIter(col);
                    auto dest_iter = call_conditional<
                        kIsLast,
                        decltype(createOutputIter),
                        decltype(createOutputAccumIter)>::
                        apply(createOutputIter, createOutputAccumIter, col);
                    EpilogueOutputOp rescale(s_prime, m_prime);
                    Epilogue epilogue(
                        shared_storage.epilogue_shared_storage(),
                        thread_id(),
                        warp_id(),
                        lane_id());
                    epilogue(rescale, dest_iter, accum_o, source_iter);
                  }));
            }));
        if (!kSingleValueIteration) {
          __syncthreads();
        }
      }
    }
    __syncthreads(); // we modify `m_prime` after
  }

  if (kKeepOutputInRF) {
    constexpr bool kIsFirst = true;
    constexpr bool kIsLast = true;
    using DefaultEpilogue = typename MM1::DefaultEpilogue;
    using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
    using ElementCompute = typename DefaultOp::ElementCompute;
    using EpilogueOutputOp =
        typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
            output_t, // output
            output_accum_t, // source
            DefaultOp::kCount,
            typename DefaultOp::ElementAccumulator, // accum
            output_accum_t, // compute
            kIsFirst,
            kIsLast,
            cutlass::Array<ElementCompute, kQueriesPerBlock>>;
    using Epilogue =
        typename cutlass::epilogue::threadblock::EpiloguePipelined<
            typename DefaultEpilogue::Shape,
            typename MM1::Mma::Operator,
            DefaultEpilogue::kPartitionsK,
            typename MM1::OutputTileIterator, // destination
            typename DefaultEpilogue::AccumulatorFragmentIterator,
            typename DefaultEpilogue::WarpTileIterator,
            typename DefaultEpilogue::SharedLoadIterator,
            EpilogueOutputOp,
            typename DefaultEpilogue::Padding,
            DefaultEpilogue::kFragmentsPerIteration,
            true, // IterationsUnroll
            typename MM1::OutputTileIteratorAccum // source tile
            >;
    auto dest_iter = createOutputIter(0);
    EpilogueOutputOp rescale(s_prime, m_prime);
    Epilogue epilogue(
        shared_storage.epilogue_shared_storage(),
        thread_id(),
        warp_id(),
        lane_id());
    epilogue(rescale, dest_iter, accum_o);
  }

  // 7. Calculate logsumexp
  // To make the backward easier, we pad logsumexp with `inf`:
  // this avoids a few bound checks, and is not more expensive during fwd
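  // Since `s_prime` was accumulated with the running maximum `mi` already
  // subtracted inside the exp, the per-row log-sum-exp recombines as
  //   lse[q] = mi[q] + log(s_prime[q])
  //          = log( sum_k exp(q . k * scale) )
  // which is exactly what the code below stores.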
  static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
  if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
    auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
    if (thread_id() < p.num_queries) {
      p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
          cutlass::fast_log(accum_t(s_prime[thread_id()]));
    } else if (thread_id() < lse_dim) {
      p.logsumexp_ptr[thread_id()] =
          cutlass::platform::numeric_limits<accum_t>::infinity();
    }
  }
}

static CUTLASS_DEVICE int8_t lane_id() {
  return threadIdx.x;
}
static CUTLASS_DEVICE int8_t warp_id() {
  return threadIdx.y;
}
static CUTLASS_DEVICE int16_t thread_id() {
  return threadIdx.x + threadIdx.y * blockDim.x;
}
};

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_batched_impl(typename AK::Params p) {
  if (!p.advance_to_block()) {
    return;
  }
  AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_batched(typename AK::Params params);

#define _ATTENTION_KERNEL_FORWARD_BEGIN(...)                                  \
  template <>                                                                 \
  __global__ void __launch_bounds__(                                          \
      __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm)                 \
      attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
    using Kernel = __VA_ARGS__;
#define _ATTENTION_KERNEL_FORWARD_END() }

#ifdef __CUDA_ARCH__
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif

#define INSTANTIATE_ATTENTION_KERNEL_FORWARD(          \
    ARCH,                                              \
    SCALAR_T,                                          \
    IS_ALIGNED,                                        \
    QUERIES_PER_BLOCK,                                 \
    KEYS_PER_BLOCK,                                    \
    SINGLE_VALUE_ITER)                                 \
  _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel<     \
                                  SCALAR_T,            \
                                  cutlass::arch::Sm##ARCH, \
                                  IS_ALIGNED,          \
                                  QUERIES_PER_BLOCK,   \
                                  KEYS_PER_BLOCK,      \
                                  SINGLE_VALUE_ITER>)  \
  if (!p.advance_to_block()) {                         \
    return;                                            \
  }                                                    \
  Kernel::attention_kernel(p);                         \
  _ATTENTION_KERNEL_FORWARD_END();

#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \
    ARCH,                                              \
    SCALAR_T,                                          \
    IS_ALIGNED,                                        \
    QUERIES_PER_BLOCK,                                 \
    KEYS_PER_BLOCK,                                    \
    SINGLE_VALUE_ITER)                                 \
  _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel<     \
                                  SCALAR_T,            \
                                  cutlass::arch::Sm##ARCH, \
                                  IS_ALIGNED,          \
                                  QUERIES_PER_BLOCK,   \
                                  KEYS_PER_BLOCK,      \
                                  SINGLE_VALUE_ITER>)  \
  printf(                                              \
      "FATAL: this function is for sm%d, but was built for sm%d\n", \
      int(ARCH),                                       \
      int(__CUDA_ARCH_OR_ZERO__));                     \
  _ATTENTION_KERNEL_FORWARD_END();

// All kernels are disabled by default
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)

// Enable the right one based on __CUDA_ARCH__
#ifndef __CUDA_ARCH__
#elif __CUDA_ARCH__ < 500
#error "Need cuda arch at least 5.0"
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#endif
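
// Example use (hypothetical parameter values - the real instantiations live
// in the example's .cu translation units): in code compiled for sm_80,
//
//   INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(
//       cutlass::half_t, /*IS_ALIGNED=*/true, 64, 64, true)
//
// expands to an explicit specialization of `attention_kernel_batched` for
// AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, true>;
// compiled for any other arch, the _DISABLED variant is selected and the
// specialization prints a fatal error instead of running.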

1780
examples/42_fused_multi_head_attention/mma_from_smem.h
Normal file
File diff suppressed because it is too large
@@ -120,6 +120,7 @@ foreach(EXAMPLE
  38_syr2k_grouped
  39_gemm_permute
  41_multi_head_attention
  42_fused_multi_head_attention
)

  add_subdirectory(${EXAMPLE})
@@ -574,6 +574,21 @@ using std::is_trivially_copyable;

#endif

//-----------------------------------------------------------------------------
// bit_cast <bit>
//-----------------------------------------------------------------------------

template <class To, class From>
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from) noexcept;

template <class To, class From>
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept {
  static_assert(sizeof(To) == sizeof(From), "sizes must match");
  return reinterpret_cast<To const&>(src);
}
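
// Example (runtime check; 0x3f800000 is the IEEE-754 single-precision bit
// pattern of 1.0f):
//
//   assert(bit_cast<int32_t>(1.0f) == 0x3f800000);
//
// Unlike C++20's std::bit_cast, this implementation goes through
// reinterpret_cast, so despite the `constexpr` it cannot actually be
// evaluated in a constant expression; that is sufficient for the runtime
// uses in this header.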

//-----------------------------------------------------------------------------
// Alignment and layout utilities
//-----------------------------------------------------------------------------

@@ -865,5 +880,13 @@ struct numeric_limits<uint8_t> {
  static constexpr bool is_integer = true;
};

template <>
struct numeric_limits<float> {
  CUTLASS_HOST_DEVICE
  static constexpr float infinity() noexcept {
    return bit_cast<float, int32_t>(0x7f800000);
  }
  static constexpr bool is_integer = false;
  static constexpr bool has_infinity = true;
};
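
// 0x7f800000 is the IEEE-754 binary32 encoding of +infinity (sign 0,
// exponent all ones, zero mantissa), so negating the returned value - as the
// attention kernel does when initializing `m_prime` / `mi` - yields -infinity.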

} // namespace platform
} // namespace cutlass