fMHA: Sync FW with xFormers (#828)

* fMHA: Add support for bias+dropout in FW

* Remove 'getMaximumSharedMemoryPerBlockKb'

* Fix comments

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: dan_the_3rd
Date: 2023-02-23 05:25:31 +01:00 (committed by GitHub)
Parent: 9cdbe33570
Commit: f303889ed9
12 changed files with 999 additions and 254 deletions

View File

@ -50,12 +50,17 @@
#if 1
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
#define PRINT_T0_L0(msg, ...) \
#define PRINT_B0_T0(msg, ...) \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
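// Example usage (illustrative, not part of this commit):
//   PRINT_B0_T0("num_keys=%d", int(p.num_keys));
// prints exactly once per launch, from lane PRINT_LANE_ID of warp
// PRINT_WARP_ID in block (0, 0, 0); PRINT_T0 below prints once per block.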
#define PRINT_T0(msg, ...) \
if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_TX_LX(msg, ...) \
for (int bx = 0; bx < gridDim.x; ++bx) { \
for (int by = 0; by < gridDim.y; ++by) { \
@ -84,7 +89,7 @@
} \
}
#else
#define PRINT_T0_L0
#define PRINT_B0_T0
#define PRINT_TX_LX
#endif
@ -124,7 +129,7 @@ constexpr __string_view __get_type_name() {
// Print a given array
#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \
PRINT_T0_L0( \
PRINT_B0_T0( \
"%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
name, \
int(start), \
@ -141,7 +146,7 @@ constexpr __string_view __get_type_name() {
#define PRINT_FRAG_T0_L0(name, frag) \
{ \
auto typeStr = __get_type_name<decltype(frag)>(); \
PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \
PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \
for (int _start = 0; _start < frag.size(); _start += 8) { \
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
} \
@ -150,7 +155,7 @@ constexpr __string_view __get_type_name() {
}
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \
{ \
PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \
PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \
for (int _start = 0; _start < length; _start += incr) { \
PRINT_ACCUM8_T0_L0_START(" ", array, _start); \
} \
@ -160,7 +165,7 @@ constexpr __string_view __get_type_name() {
// Print a 4x4 matrix
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \
PRINT_T0_L0( \
PRINT_B0_T0( \
"%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \
name, \
int(start_x), \
@ -187,9 +192,43 @@ constexpr __string_view __get_type_name() {
PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
#define PRINT_PROBLEM_SIZE(name, ps) \
PRINT_T0_L0( \
PRINT_B0_T0( \
"%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
name, \
int(ps.m()), \
int(ps.n()), \
int(ps.k()))
template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(
AccumT accum,
LaneOffsetT lane_offset,
int32_t num_rows,
int32_t num_cols) {
bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
for (int row = 0; row < num_rows; ++row) {
for (int col = 0; col < num_cols; ++col) {
if (col % 32 == 0) {
if (is_main) {
printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
}
__syncthreads();
}
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (row == accum_m && col == accum_n &&
(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) {
printf(" %6.1f", float(accum[idx]));
}
},
[&](int accum_m) {});
__syncthreads();
}
if (is_main) {
printf("\n");
}
}
}

View File

@ -50,9 +50,8 @@
#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
#include "find_default_mma.h"
#include "attention_scaling_coefs_updater.h"
#include "mma_from_smem.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -154,10 +153,10 @@ struct DefaultFMHAGrouped {
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
ElementAccumulator,
kWarpSize>::Updater;
kWarpSize>::Iterator;
static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, "");
@ -240,7 +239,8 @@ struct DefaultFMHAGrouped {
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage>;
typename MM0::AccumulatorSharedStorage,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;

View File

@ -48,7 +48,18 @@
#include "fmha_grouped_problem_visitor.h"
#include "gemm_kernel_utils.h"
#include "epilogue_rescale_output.h"
#include "gemm/mma_accum_lambda_iterator.h"
#include "epilogue/epilogue_rescale_output.h"
namespace {
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
}
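The atomicMaxFloat helper added here relies on IEEE-754 ordering: non-negative floats compare like their bit patterns viewed as signed ints, while negative floats compare in reverse, which is why the negative branch uses atomicMin on the unsigned view. Below is a standalone sketch of the same trick, illustrative only and not part of this commit (kernel and buffer names are made up):

#include <cstdio>
#include <cuda_runtime.h>

__device__ float atomic_max_float(float* addr, float value) {
  // non-negative: signed-int ordering matches float ordering
  // negative: unsigned ordering is reversed, so take the min
  return (value >= 0)
      ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}

__global__ void global_max_kernel(const float* in, int n, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomic_max_float(out, in[i]);
  }
}

int main() {
  const int n = 256;
  float h_in[n];
  for (int i = 0; i < n; ++i) h_in[i] = -128.0f + i; // max is 127.0f
  float h_out = -1e30f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &h_out, sizeof(float), cudaMemcpyHostToDevice);
  global_max_kernel<<<(n + 127) / 128, 128>>>(d_in, n, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %f\n", h_out); // expected: 127.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}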
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -128,6 +139,9 @@ public:
static int const kQueriesPerBlock = ThreadblockShape::kM;
static int const kKeysPerBlock = ThreadblockShape::kN;
static constexpr bool kSupportsDropout = false;
static constexpr bool kSupportsBias = false;
/// Warp count (concept: GemmShape)
using WarpCount = typename MM1::WarpCount;
static int const kThreadsPerWarp = 32;
@ -619,10 +633,10 @@ public:
// Mask out last if causal
if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::ScalingCoefsUpdater::iterateRows(
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
@ -641,14 +655,11 @@ public:
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<
kQueriesPerBlock,
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
kIsFirst>(
accum_o,
accum,
mi,
@ -659,7 +670,7 @@ public:
warp_id(),
num_keys - iter_key_start,
iteratorC_tile_offset,
params.scale);
kSupportsBias ? 1.0f : params.scale);
}));
}));
@ -838,6 +849,116 @@ public:
problem_visitor.advance(gridDim.x);
}
}
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
*/
using Fragment = typename WarpIteratorC::Fragment;
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
WarpIteratorC,
accum_t,
kThreadsPerWarp>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (LambdaIterator::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
};
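The iterative_softmax routine above implements the streaming ("online") softmax recurrence described in its leading comment: per key block it updates the running row max mi, rescales the partial sums (m_prime, s_prime) and the partial output by exp(old_max - new_max), then accumulates exp(score - mi). The kernel pre-multiplies scores by kLog2e so it can use exp2f instead of expf. Here is a minimal single-row host sketch of the same recurrence, plain C++, illustrative only and not part of this commit:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> scores = {0.1f, 2.0f, -1.0f, 3.5f, 0.7f, -0.2f};
  std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  const size_t kKeysPerBlock = 2;

  float mi = -INFINITY;  // running row max
  float s_prime = 0.0f;  // running sum of exp(score - mi)
  float acc = 0.0f;      // running (unnormalized) output accumulator

  for (size_t start = 0; start < scores.size(); start += kKeysPerBlock) {
    // (1) update the running max for this row
    float m_prev = mi;
    for (size_t j = start; j < start + kKeysPerBlock; ++j) {
      mi = std::max(mi, scores[j]);
    }
    // (2b) rescale what was accumulated so far by exp(m_prev - mi)
    float correction = std::exp(m_prev - mi);
    s_prime *= correction;
    acc *= correction;
    // (2a) + (2c) accumulate exp(score - mi) and the weighted values
    for (size_t j = start; j < start + kKeysPerBlock; ++j) {
      float p = std::exp(scores[j] - mi);
      s_prime += p;
      acc += p * values[j];
    }
  }
  printf("softmax(scores) . values = %f\n", acc / s_prime);
  return 0;
}

Whatever block size is chosen, acc / s_prime at the end equals the full softmax-weighted sum computed in a single pass, which is what lets the kernel keep only mi, m_prime, s_prime and the partial output across key blocks.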
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -856,7 +856,9 @@ public:
p.head_dim_value = options.head_size_v;
p.num_queries = options.seq_length;
p.num_keys = options.seq_length_kv;
p.causal = options.causal;
if (options.causal) {
p.custom_mask_type = Attention::CausalFromTopLeft;
}
// All tensors are in BMHK shapes
p.q_strideH = options.head_size;
@ -868,6 +870,7 @@ public:
p.q_strideB = p.q_strideM * options.seq_length;
p.k_strideB = p.k_strideM * options.seq_length_kv;
p.v_strideB = p.v_strideM * options.seq_length_kv;
p.o_strideM = p.head_dim_value * p.num_heads;
}
// launch kernel :)
@ -1005,7 +1008,9 @@ int run_attention(Options& options) {
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration
kSingleValueIteration,
false, // Supports dropout
false // Supports bias
>;
//

View File

@ -42,6 +42,8 @@
This is really only for the FastF32 case - aka using TensorCores with fp32.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"

View File

@ -36,137 +36,15 @@
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
#include "gemm_kernel_utils.h"
namespace {
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
} // namespace
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
We have multiple implementations, because each configuration has a different way
of iterating in the accumulators.
/*
TensorCores have different accumulator layouts.
This file provides a class to easily map the accumulator's
i-th element to the corresponding matrix row/col.
*/
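The lambda-iterator classes in this file expose an iterateRows(lane_offset, begin_row, on_element, end_row) protocol: for each accumulator row owned by the calling lane, the first callback runs once, the second runs for each element with its logical (row, col) coordinates and its linear index into the register fragment, and the third runs once at the end of the row. Below is a toy illustration of the callback shape, illustrative only and not part of this commit; the real iterators also take a lane_offset and depend on the architecture-specific accumulator layout:

#include <cstdio>

struct ToyAccumLambdaIterator {
  template <typename FBegin, typename FElem, typename FEnd>
  static void iterateRows(FBegin begin_row, FElem on_element, FEnd end_row) {
    int idx = 0;  // linear index into the register fragment
    for (int m = 0; m < 2; ++m) {
      begin_row(m);
      for (int n = 0; n < 4; ++n) {
        on_element(m, n, idx++);
      }
      end_row(m);
    }
  }
};

int main() {
  float frag[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float row_sum = 0.f;
  ToyAccumLambdaIterator::iterateRows(
      [&](int m) { row_sum = 0.f; },
      [&](int m, int n, int idx) { row_sum += frag[idx]; },
      [&](int m) { printf("row %d sum = %f\n", m, row_sum); });
  return 0;
}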
template <typename BASE, typename T, typename accum_t, int kWarpSize>
struct RegisterOps {
template <
int kQueriesPerBlock,
bool kFullColumns,
bool kIsFirst,
bool kKeepOutputInRF>
CUTLASS_DEVICE static void update(
typename T::Fragment& frag_o, // output so far
typename T::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename T::TensorCoord const& tile_offset,
float scaling) {
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
BASE::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
BASE::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (BASE::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSm80
: RegisterOps<
AttentionScalingCoefsUpdaterSm80<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
struct AccumLambdaIteratorSm80 {
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
@ -239,12 +117,7 @@ struct AttentionScalingCoefsUpdaterSm80
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterVolta
: RegisterOps<
AttentionScalingCoefsUpdaterVolta<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
struct AccumLambdaIteratorSm70 {
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
@ -357,12 +230,7 @@ struct AttentionScalingCoefsUpdaterVolta
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSimt
: RegisterOps<
AttentionScalingCoefsUpdaterSimt<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
struct AccumLambdaIteratorSimt {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
@ -436,11 +304,11 @@ struct AttentionScalingCoefsUpdaterSimt
};
template <typename T, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater;
struct DefaultMmaAccumLambdaIterator;
// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
@ -451,7 +319,7 @@ struct DefaultAttentionScalingCoefsUpdater<
1>,
accum_t,
kWarpSize> {
using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
@ -459,13 +327,12 @@ struct DefaultAttentionScalingCoefsUpdater<
P,
1,
1>;
using Updater =
AttentionScalingCoefsUpdaterSimt<Iterator, accum_t, kWarpSize>;
using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
@ -474,15 +341,14 @@ struct DefaultAttentionScalingCoefsUpdater<
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
using Iterator =
using WarpIterator =
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
using Updater =
AttentionScalingCoefsUpdaterVolta<Iterator, accum_t, kWarpSize>;
using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Sm75+
@ -492,7 +358,7 @@ template <
typename S3,
typename accum_t,
int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
@ -501,13 +367,12 @@ struct DefaultAttentionScalingCoefsUpdater<
S3>,
accum_t,
kWarpSize> {
using Iterator =
using WarpIterator =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>;
using Updater =
AttentionScalingCoefsUpdaterSm80<Iterator, accum_t, kWarpSize>;
using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};

View File

@ -43,22 +43,26 @@
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "attention_scaling_coefs_updater.h"
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
#include "../gemm/mma_accum_lambda_iterator.h"
#include "../gemm_kernel_utils.h"
#include "../iterators/make_residual_last.h"
#include "../iterators/transpose_warp_iterator.h"
#include "../iterators/warp_iterator_from_smem.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
#include "epilogue_thread_apply_logsumexp.h"
#include "gemm_kernel_utils.h"
#include "iterators/make_residual_last.h"
#include "iterators/transpose_warp_iterator.h"
#include "iterators/warp_iterator_from_smem.h"
namespace cutlass {
namespace gemm {
@ -246,6 +250,78 @@ class MmaBaseFromSharedMemory {
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
namespace {
// has necessary trait compliance with WarpIteratorFromSmem but doesn't do
// anything, can be default initialized, and uses a fragment that takes up
// (almost) no space. This warp iterator is selected at compile time when
// elementwise on-the-fly scaling for operand A is disabled, in which case
// operations related to loading scale factors for operand A get wiped out by
// the compiler.
template <typename TensorRef>
class NoOpWarpIteratorScale {
public:
// in pipelined+multistage MMA implementations we keep an array of fragments.
// if we aren't using scaling we don't want to waste registers on fragments
// of scale elements, so ideally this would be sized 0.
// using size 1 is kind of a hack to get around arrays of zero-sized objects
// not being allowed. the compiler is probably smart enough to wipe it out
// anyways.
using Fragment = cutlass::Array<char, 1>;
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale() {}
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale(TensorRef const&, int) {}
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale& add_tile_offset(
typename TensorRef::TensorCoord const&) {
return *this;
}
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale& operator++() {
return *this;
}
CUTLASS_DEVICE
void load(Fragment&) const {}
};
// if scaling is enabled, performs fragment elementwise multiplication between
// fragment and its scaling factor.
template <typename Fragment, typename FragmentScale, bool ScalingEnabled>
class FragmentElementwiseScaler;
// specialization for scaling being enabled.
template <typename Fragment, typename FragmentScale>
class FragmentElementwiseScaler<Fragment, FragmentScale, true> {
public:
// cast scale_frag to correct type then apply elementwise to fragment
CUTLASS_DEVICE
static Fragment apply(Fragment frag, FragmentScale const& scale_frag) {
Fragment converted_scale_frag = cutlass::NumericArrayConverter<
typename Fragment::Element,
typename FragmentScale::Element,
FragmentScale::kElements>()(scale_frag);
return cutlass::multiplies<Fragment>()(frag, converted_scale_frag);
}
};
// specialization for scaling being disabled. doesn't do anything and should
// just get wiped out by the compiler.
template <typename Fragment, typename FragmentScale>
class FragmentElementwiseScaler<Fragment, FragmentScale, false> {
public:
CUTLASS_DEVICE
static Fragment apply(Fragment frag, FragmentScale const&) {
return frag;
}
};
} // namespace
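NoOpWarpIteratorScale and FragmentElementwiseScaler together implement a compile-time policy switch: when ScaleOperandA is false, the MMA classes still declare scale iterators and fragments, but they collapse to (almost) empty types and no-op calls that the compiler removes. A reduced sketch of the same pattern using std::conditional, illustrative only and not part of this commit:

#include <array>
#include <cstdio>
#include <type_traits>

struct RealScaler {
  using Fragment = std::array<float, 8>;
  static float apply(float x, const Fragment& s, int i) { return x * s[i]; }
};

struct NoOpScaler {
  // size-1 fragment: arrays of zero-sized objects are not allowed
  using Fragment = std::array<char, 1>;
  static float apply(float x, const Fragment&, int) { return x; }
};

template <bool ScaleOperandA>
using Scaler =
    typename std::conditional<ScaleOperandA, RealScaler, NoOpScaler>::type;

template <bool ScaleOperandA>
float run(float x) {
  // empty when scaling is off; the real kernels would load actual scale
  // factors here, this sketch just value-initializes the fragment
  typename Scaler<ScaleOperandA>::Fragment scale{};
  return Scaler<ScaleOperandA>::apply(x, scale, 0);
}

int main() {
  // with ScaleOperandA == false the scale fragment and the multiply are
  // compiled away entirely
  printf("%f %f\n", run<false>(2.0f), run<true>(2.0f));
  return 0;
}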
////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
@ -259,6 +335,10 @@ template <
// BEGIN smem
/// Iterates over the intermediate accumulator tile in shared memory
typename WarpIteratorA,
/// whether or not to perform elementwise multiplication of A
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
// END smem
@ -297,6 +377,15 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
static constexpr bool ScaleOperandA = ScaleOperandA_;
///< loads fragments of A_scale from shared memory if operand A scaling is
///< enabled. otherwise no-op.
using WarpIteratorAScale = typename cutlass::platform::conditional<
ScaleOperandA,
WarpIteratorA,
NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
@ -333,8 +422,20 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
private:
using WarpFragmentA = typename Operator::FragmentA;
/// fragment type of OperandA elementwise scaling matrix. (almost) empty
/// if operand A scaling is disabled.
using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;
using WarpFragmentB = typename Operator::FragmentB;
/// applies scaling factor to operand A fragment if operand A scaling is
/// enabled. otherwise no-op.
using FragmentAScaler = FragmentElementwiseScaler<
WarpFragmentA,
WarpFragmentAScale,
ScaleOperandA>;
protected:
// /// Iterator to write threadblock-scoped tile of A operand to shared memory
// SmemIteratorA smem_iterator_A_;
@ -346,7 +447,46 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
/// accumulator tile
WarpIteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of A_scale from intermediate
/// accumulator tile (only used if ScaleOperandA_ is true)
WarpIteratorAScale warp_tile_iterator_A_scale_;
public:
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
// warp iterator over A tile held in shared memory
WarpIteratorA warp_iter_a,
// warp iterator over A_scale tile held in shared memory
WarpIteratorAScale warp_iter_a_scale,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(warp_iter_a),
warp_tile_iterator_A_scale_(warp_iter_a_scale),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_A_scale_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
/// Construct from tensor references
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
@ -429,19 +569,26 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
__syncthreads();
// remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op
// if scaling is disabled.
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentAScale warp_frag_A_scale[2];
WarpFragmentB warp_frag_B[2];
warp_frag_A[0].clear();
warp_frag_A_scale[0].clear();
warp_frag_B[0].clear();
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_A_scale_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
@ -503,9 +650,12 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
(warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_scale_.load(
warp_frag_A_scale[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_A_scale_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
@ -521,7 +671,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
warp_mma(
accum,
warp_frag_A[warp_mma_k % 2],
FragmentAScaler::apply(
warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]),
warp_frag_B[warp_mma_k % 2],
accum);
}
@ -541,6 +692,10 @@ template <
typename Shape1_,
/// Iterates over the intermediate accumulator tile in shared memory
typename WarpIteratorA1_,
/// whether or not to perform elementwise multiplication of A
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
/// Iterates over tiles of B operand in global memory
@ -580,7 +735,14 @@ class MmaMultistageFromSharedMemory
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate
///< accumulator tile in shared memory
static constexpr bool ScaleOperandA = ScaleOperandA_;
///< warp level iterator over A_scale matrix tile kept in shared memory.
///< if elementwise A scaling is disabled then everything this does is no-op.
using WarpIteratorAScale = typename cutlass::platform::conditional<
ScaleOperandA,
WarpIteratorA1,
NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
@ -628,10 +790,20 @@ class MmaMultistageFromSharedMemory
private:
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
/// fragment of OperandA scale matrix. if operand A scaling is disabled this
/// is (almost) empty.
using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
/// applies elementwise scaling to fragment of A. if operand A scaling is
/// disabled this is a no-op.
using FragmentAScaler = FragmentElementwiseScaler<
WarpLoadedFragmentA1,
WarpLoadedFragmentA1Scale,
ScaleOperandA>;
private:
//
// Data members
@ -641,12 +813,54 @@ class MmaMultistageFromSharedMemory
/// accumulator tile
WarpIteratorA1 warp_tile_iterator_A1_;
/// Iterator to load a warp-scoped tile of A1_scale operand from shared memory
/// if operand A scaling is disabled everything this does is a no-op.
WarpIteratorAScale warp_tile_iterator_A1_scale_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB1 smem_iterator_B1_;
bool prologue_done_;
public:
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
// warp level iterator over operand A tile kept in shared memory
WarpIteratorA1 warp_tile_iterator_A1,
// warp level iterator over operand A elementwise scale tile kept in
// shared memory.
WarpIteratorAScale warp_tile_iterator_A1_scale,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(warp_tile_iterator_A1),
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn_1 =
warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
// Add per-warp offsets in units of warp-level tiles
warp_tile_iterator_A1_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
warp_tile_iterator_A1_scale_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
}
/// Construct from tensor references
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
@ -842,9 +1056,13 @@ class MmaMultistageFromSharedMemory
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty
// if scaling is disabled.
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2];
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
@ -854,6 +1072,9 @@ class MmaMultistageFromSharedMemory
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
++warp_tile_iterator_A1_;
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
++warp_tile_iterator_A1_scale_;
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]);
++this->warp_tile_iterator_B_;
@ -864,7 +1085,8 @@ class MmaMultistageFromSharedMemory
warp_mma1.transform(
warp_transformed_frag_A1[0],
warp_transformed_frag_B1[0],
warp_loaded_frag_A1[0],
FragmentAScaler::apply(
warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]),
warp_loaded_frag_B1[0]);
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
@ -909,17 +1131,22 @@ class MmaMultistageFromSharedMemory
warp_mma_k < Base::kWarpGemmIterations1 - 1) {
warp_tile_iterator_A1_.load(
warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
warp_tile_iterator_A1_scale_.load(
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
}
++warp_tile_iterator_A1_;
++warp_tile_iterator_A1_scale_;
++this->warp_tile_iterator_B_;
if (warp_mma_k > 0)
warp_mma1.transform(
warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
warp_loaded_frag_A1[warp_mma_k % 2],
FragmentAScaler::apply(
warp_loaded_frag_A1[warp_mma_k % 2],
warp_loaded_frag_A1_scale[warp_mma_k % 2]),
warp_loaded_frag_B1[warp_mma_k % 2]);
if (platform::is_same<
@ -1009,7 +1236,9 @@ class MmaMultistageFromSharedMemory
warp_mma1.transform(
warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
FragmentAScaler::apply(
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]),
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
}
}
@ -1119,6 +1348,9 @@ struct DefaultWarpIteratorAFromSharedMemory<
template <
typename Mma_,
typename AccumulatorSharedStorage,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
bool kTransposeA = false>
struct DefaultMmaFromSharedMemory;
@ -1151,6 +1383,9 @@ template <
/// Transformation applied to B operand
typename TransformB_,
typename AccumulatorSharedStorage_,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
bool kTransposeA>
struct DefaultMmaFromSharedMemory<
MmaPipelined<
@ -1165,6 +1400,7 @@ struct DefaultMmaFromSharedMemory<
TransformA_,
TransformB_>,
AccumulatorSharedStorage_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
@ -1198,6 +1434,7 @@ struct DefaultMmaFromSharedMemory<
using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
IteratorB,
SmemIteratorB_,
@ -1238,6 +1475,9 @@ template <
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
typename AccumulatorSharedStorage_,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
bool kTransposeA>
struct DefaultMmaFromSharedMemory<
MmaMultistage<
@ -1254,6 +1494,7 @@ struct DefaultMmaFromSharedMemory<
Stages,
SharedMemoryClear>,
AccumulatorSharedStorage_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
@ -1301,6 +1542,7 @@ struct DefaultMmaFromSharedMemory<
typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
IteratorB,
SmemIteratorB_,
@ -1637,18 +1879,17 @@ struct B2bGemm<
// NOTE: accum is attn.T
// TODO: Optimize for each architecture
static constexpr int WarpSize = 32;
using RegistersIter = typename DefaultAttentionScalingCoefsUpdater<
IteratorC,
accum_t,
WarpSize>::Updater;
using AccumLambdaIterator =
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
Iterator;
auto lane_offset =
RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords);
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
RegistersIter::iterateRows(
AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;
@ -1777,18 +2018,17 @@ struct B2bGemm<
// NOTE: accum is attn.T
// TODO: Optimize for each architecture
static constexpr int WarpSize = 32;
using RegistersIter = typename DefaultAttentionScalingCoefsUpdater<
IteratorC,
accum_t,
WarpSize>::Updater;
using AccumLambdaIterator =
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
Iterator;
auto lane_offset =
RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords);
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
RegistersIter::iterateRows(
AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;

View File

@ -29,16 +29,26 @@
*
**************************************************************************************************/
#pragma once
#ifdef HAS_PYTORCH
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#endif
#include <curand_kernel.h>
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "attention_scaling_coefs_updater.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
@ -54,11 +64,12 @@
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "debug_utils.h"
#include "epilogue_pipelined.h"
#include "epilogue_rescale_output.h"
#include "find_default_mma.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
#include "mma_from_smem.h"
#include "transform/tile_smem_loader.h"
#include <inttypes.h>
@ -73,6 +84,12 @@ constexpr int getWarpsPerSm() {
? 16
: 12);
}
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
} // namespace
template <
@ -83,10 +100,20 @@ template <
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
>
int kKeysPerBlock_,
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
// This is quite slower on V100 for some reason
// Set to false if you know at compile-time you will never need dropout
bool kSupportsDropout_ = true,
bool kSupportsBias_ = true>
struct AttentionKernel {
enum CustomMaskType {
NoCustomMask = 0,
CausalFromTopLeft = 1,
CausalFromBottomRight = 2,
NumCustomMaskTypes,
};
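The two causal modes differ only in where the diagonal sits: CausalFromTopLeft lets query i attend to keys [0, i], while CausalFromBottomRight shifts the diagonal right by (num_keys - num_queries) so that the last query sees every key, matching the causal_diagonal_offset adjustment in advance_to_block(). A small sketch printing the visible keys under each mode, illustrative only and not part of this commit:

#include <cstdio>

int main() {
  const int num_queries = 3, num_keys = 5;
  for (int mode = 1; mode <= 2; ++mode) {
    int offset = (mode == 2) ? num_keys - num_queries : 0;
    printf("%s:\n", mode == 1 ? "CausalFromTopLeft" : "CausalFromBottomRight");
    for (int q = 0; q < num_queries; ++q) {
      for (int k = 0; k < num_keys; ++k) {
        // key k is visible to query q iff k <= q + offset
        printf("%c", k <= q + offset ? 'x' : '.');
      }
      printf("\n");
    }
  }
  return 0;
}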
using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float;
@ -95,7 +122,11 @@ struct AttentionKernel {
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
static constexpr bool kSupportsDropout = kSupportsDropout_;
static constexpr bool kSupportsBias = kSupportsBias_;
static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr bool kIsAligned = isAligned_;
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
@ -117,10 +148,15 @@ struct AttentionKernel {
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
int32_t* cu_seqlens_q_ptr = nullptr;
int32_t* cu_seqlens_k_ptr = nullptr;
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
int32_t* seqstart_q_ptr = nullptr;
int32_t* seqstart_k_ptr = nullptr;
int32_t* causal_diagonal_ptr = nullptr;
int32_t* seqlen_k_ptr = nullptr;
uint32_t causal_diagonal_offset = 0;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
@ -137,26 +173,38 @@ struct AttentionKernel {
int32_t num_queries;
int32_t num_keys;
bool causal;
uint8_t custom_mask_type = NoCustomMask;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t bias_strideM = 0;
int32_t o_strideM = 0;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH = 0;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int32_t bias_strideB = 0;
int32_t num_batches;
int32_t num_heads;
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value * num_heads;
}
// dropout
bool use_dropout;
unsigned long long dropout_batch_head_rng_offset;
float dropout_prob;
#ifdef HAS_PYTORCH
at::PhiloxCudaState rng_engine_inputs;
#endif
// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
@ -166,18 +214,33 @@ struct AttentionKernel {
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
if (kSupportsDropout) {
dropout_batch_head_rng_offset =
batch_id * num_heads * num_queries * num_keys +
head_id * num_queries * num_keys;
}
int64_t q_start, k_start;
// Advance to current batch - in case of different sequence lengths
if (cu_seqlens_q_ptr != nullptr) {
assert(cu_seqlens_k_ptr != nullptr);
cu_seqlens_q_ptr += batch_id;
cu_seqlens_k_ptr += batch_id;
q_start = cu_seqlens_q_ptr[0];
k_start = cu_seqlens_k_ptr[0];
int64_t q_next_start = cu_seqlens_q_ptr[1];
int64_t k_next_start = cu_seqlens_k_ptr[1];
if (seqstart_q_ptr != nullptr) {
assert(seqstart_k_ptr != nullptr);
seqstart_q_ptr += batch_id;
q_start = seqstart_q_ptr[0];
int64_t q_next_start = seqstart_q_ptr[1];
int64_t k_end;
seqstart_k_ptr += batch_id;
if (seqlen_k_ptr) {
k_start = seqstart_k_ptr[0];
k_end = k_start + seqlen_k_ptr[batch_id];
} else {
k_start = seqstart_k_ptr[0];
k_end = seqstart_k_ptr[1];
}
num_queries = q_next_start - q_start;
num_keys = k_next_start - k_start;
num_keys = k_end - k_start;
if (query_start >= num_queries) {
return false;
@ -186,9 +249,10 @@ struct AttentionKernel {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += int64_t(batch_id * num_queries) * o_strideM();
output_ptr += int64_t(batch_id * num_queries) * o_strideM;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM();
output_accum_ptr +=
int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
}
q_start = 0;
k_start = 0;
@ -197,42 +261,84 @@ struct AttentionKernel {
// Advance to the current batch / head / query_start
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * head_dim_value;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr +=
int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
if (kSupportsBias && attn_bias_ptr != nullptr) {
attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
}
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
output_accum_ptr +=
int64_t(q_start + query_start) * (head_dim_value * num_heads) +
head_id * head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
}
if (logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
logsumexp_ptr +=
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
}
num_queries -= query_start;
if (causal) {
num_keys = cutlass::fast_min(
int32_t(query_start + kQueriesPerBlock), num_keys);
// Custom masking
if (causal_diagonal_ptr) {
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
}
if (custom_mask_type == CausalFromBottomRight) {
causal_diagonal_offset += num_keys - num_queries;
}
if (custom_mask_type == CausalFromTopLeft ||
custom_mask_type == CausalFromBottomRight) {
// the bottom row of the current block is query_start + kQueriesPerBlock,
// so the last key any of its queries can attend to is
// query_start + causal_diagonal_offset + kQueriesPerBlock; num_keys is
// clamped to the min of the actual num_keys and this value to avoid
// extra computation
num_keys = cutlass::fast_min(
int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
num_keys);
}
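// Worked example (illustrative numbers, not from this commit):
// kQueriesPerBlock = 64, query_start = 128, num_keys = 1024,
// CausalFromTopLeft (causal_diagonal_offset = 0). Queries in this block
// cover rows [128, 192), so no key with index >= 192 is ever visible to
// them, and num_keys is clamped to min(128 + 0 + 64, 1024) = 192,
// skipping the remaining key blocks entirely.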
num_queries -= query_start;
num_batches = 0; // no longer used after
// If num_queries == 1 and there is only one key head, we're wasting
// 15/16th of the tensor core compute. In that case:
// - we only launch kernels for head_id % kQueriesPerBlock == 0
// - we iterate over heads instead of queries (strideM = strideH)
if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
if (head_id % kQueriesPerBlock != 0)
return false;
q_strideM = q_strideH;
num_queries = num_heads;
num_heads = 1; // unused but here for intent
// remove causal since n_query = 1
// otherwise, offset would change with head !
custom_mask_type = NoCustomMask;
o_strideM = head_dim_value;
}
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
if (kSupportsBias) {
attn_bias_ptr = warp_uniform(attn_bias_ptr);
}
output_ptr = warp_uniform(output_ptr);
output_accum_ptr = warp_uniform(output_accum_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
o_strideM = warp_uniform(o_strideM);
custom_mask_type = warp_uniform(custom_mask_type);
return true;
}
@ -242,6 +348,7 @@ struct AttentionKernel {
num_heads,
num_batches);
}
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
@ -296,16 +403,24 @@ struct AttentionKernel {
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Updater;
kWarpSize>::Iterator;
static_assert(
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
MmaCore::WarpCount::kK ==
kNumWarpsPerBlock,
"");
// used for efficient load of bias tile Bij from global to shared memory
using BiasLoader = TileSmemLoader<
scalar_t,
cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
MmaCore::kThreads,
// input restriction: kv_len has to be a multiple of this value
128 / cutlass::sizeof_bits<scalar_t>::value>;
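// e.g. with a 16-bit scalar_t this is 128 / 16 = 8 elements per 128-bit
// access, so the key length must be a multiple of 8 for the bias loader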
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
@ -367,7 +482,8 @@ struct AttentionKernel {
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage>;
typename MM0::AccumulatorSharedStorage,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
@ -404,7 +520,10 @@ struct AttentionKernel {
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
union {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
};
@ -423,7 +542,10 @@ struct AttentionKernel {
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
union {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
@ -448,6 +570,18 @@ struct AttentionKernel {
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
if (kSupportsBias) {
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
XFORMERS_CHECK(
p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
XFORMERS_CHECK(
p.bias_strideM % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
}
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
@ -460,6 +594,12 @@ struct AttentionKernel {
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
XFORMERS_CHECK(
p.custom_mask_type < NumCustomMaskTypes,
"invalid value for `custom_mask_type`");
return true;
}
@ -472,8 +612,8 @@ struct AttentionKernel {
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
@ -488,7 +628,7 @@ struct AttentionKernel {
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
typename OutputTileIterator::Params{(int32_t)p.o_strideM},
p.output_ptr,
typename OutputTileIterator::TensorCoord{
p.num_queries, p.head_dim_value},
@ -500,7 +640,8 @@ struct AttentionKernel {
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
typename OutputTileIteratorAccum::Params{
(int32_t)(p.head_dim_value * p.num_heads)},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{
p.num_queries, p.head_dim_value},
@ -508,6 +649,27 @@ struct AttentionKernel {
{0, col});
};
#ifdef HAS_PYTORCH
curandStatePhilox4_32_10_t curand_state_init;
if (kSupportsDropout && p.use_dropout) {
const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
// each element of the attention matrix P with shape
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single
// offset in the RNG sequence. we initialize the RNG state with an offset that
// starts at the beginning of a (n_queries, n_keys) matrix for this
// block's batch_id and head_id
// initializing rng state is very expensive, so we run once per kernel,
// rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed.
curand_init(
std::get<0>(seeds),
0,
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
&curand_state_init);
}
#endif
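The comment above describes how each attention-matrix element gets its own offset in the Philox sequence: a per-(batch, head) base offset computed once in advance_to_block(), plus a per-element skipahead of q * num_keys + k inside the key loop. A small host sketch of the offset arithmetic, with illustrative values, not part of this commit:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_heads = 8, num_queries = 512, num_keys = 512;
  const int64_t b = 1, h = 3, q = 100, k = 200;
  // dropout_batch_head_rng_offset, computed once per kernel
  int64_t batch_head_offset =
      b * num_heads * num_queries * num_keys + h * num_queries * num_keys;
  // per-element skipahead() amount inside the key loop
  int64_t element_offset = q * num_keys + k;
  printf("rng offset = %lld\n",
         (long long)(batch_head_offset + element_offset));
  // forward and backward recompute the same offset, so they draw the same
  // uniform number for this element without storing the dropout mask
  return 0;
}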
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
@ -600,16 +762,65 @@ struct AttentionKernel {
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// multiply by scaling factor
if (kSupportsBias) {
accum =
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
}
// apply attention bias if applicable
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
// load bias tile Bij into shared memory
typename MM0::BiasLoader::GmemTileIterator bias_iter(
{cutlass::layout::RowMajor(p.bias_strideM)},
// attn_bias_ptr points to a matrix of size (n_queries, n_keys)
// for the relevant batch_id and head_id
p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
{problem_size_0_m, problem_size_0_n},
thread_id());
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
shared_storage.after_mm0.bias.data(),
cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
bias_tensor_ref, thread_id());
MM0::BiasLoader::load(bias_iter, smem_tile_iter);
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
}
},
[&](int accum_m) {});
}
// Mask out last if causal
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
// This is only needed if the upper-right corner of the current query/key
// block intersects the mask. The coordinates of that corner are
// y = query_start, x = min(iter_key_start + kKeysPerBlock, num_keys).
// The first masked element of row y is x = y + offset = query_start + offset,
// so there is an intersection (and we need to mask) if
// min(iter_key_start + kKeysPerBlock, num_keys) >= query_start + offset
if (p.custom_mask_type &&
cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
(query_start + p.causal_diagonal_offset)) {
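// Worked example (illustrative): kQueriesPerBlock = kKeysPerBlock = 64,
// query_start = 128, causal_diagonal_offset = 0, num_keys = 256.
// For iter_key_start = 0: min(0 + 64, 256) = 64 < 128, the whole key block
// lies left of the diagonal and no masking is needed.
// For iter_key_start = 128: min(128 + 64, 256) = 192 >= 128, the block
// crosses the diagonal and the per-element masking below runs.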
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::ScalingCoefsUpdater::iterateRows(
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = query_start + accum_m - iter_key_start;
// last absolute col is (last absolute query + offset)
// last local col is (last absolute query + offset -
// iter_key_start)
last_col = query_start + accum_m + p.causal_diagonal_offset -
iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
@ -625,14 +836,11 @@ struct AttentionKernel {
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<
kQueriesPerBlock,
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
kIsFirst>(
accum_o,
accum,
mi,
@ -643,7 +851,7 @@ struct AttentionKernel {
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
p.scale);
kSupportsBias ? 1.0f : p.scale);
}));
}));
@ -659,6 +867,69 @@ struct AttentionKernel {
__syncthreads();
#ifdef HAS_PYTORCH
// apply dropout (if applicable) after we've written Pij to smem.
// dropout is applied by multiplying each element of Pij by:
// - 0 with probability dropout_p
// - 1 / (1 - dropout_p) with probability 1 - dropout_p
//
// for backward purposes we want to be able to map each element of the
// attention matrix to the same random uniform number as the one we used
// in forward, without needing to use the same iteration order or having
// to store the dropout matrix. it's possible to do this in registers, but
// it ends up being very slow because each thread holds noncontiguous
// strips of the Pij tile, which means we would have to skip around a lot
// and generate a single random number at a time
if (kSupportsDropout && p.use_dropout) {
auto si = shared_storage.after_mm0.si.accum_ref();
// each thread handles a contiguous sequence of elements from Sij, all
// coming from the same row. the reason they have to come from the same
        // row is that sampling random numbers from a contiguous random
// number sequence is much more efficient than jumping around, and the
// linear offset of each element of S (the global matrix) maps to an
// offset in a random number sequence. for S, the end of a row and the
// beginning of the next have adjacent offsets, but for Sij, this is not
// necessarily the case.
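        // Worked example (hypothetical sizes): with num_threads = 128 and a
        // 64x64 Sij tile, threads_per_row = min(128 / 64, 64) = 2 and
        // elts_per_thread = round_nearest(ceil_div(64, 2), 4) = 32, so each
        // row of Sij is covered by 2 threads handling 32 contiguous
        // elements each.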
const int num_threads = blockDim.x * blockDim.y * blockDim.z;
const int threads_per_row =
cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
const int elts_per_thread = cutlass::round_nearest(
cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
const int thread_i = thread_id() / threads_per_row;
const int thread_start_j =
(thread_id() % threads_per_row) * elts_per_thread;
if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(
static_cast<unsigned long long>(
(query_start + thread_i) * p.num_keys +
(iter_key_start + thread_start_j)),
&curand_state);
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
// apply dropout scaling to elements this thread is responsible for,
// in chunks of 4
for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
cutlass::fast_min(thread_start_j + elts_per_thread,
problem_size_0_n);
sij_start_col_idx += 4) {
const float4 rand_uniform_quad = curand_uniform4(&curand_state);
CUTLASS_PRAGMA_UNROLL
for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
si.at({thread_i, sij_start_col_idx + quad_idx}) *=
static_cast<scalar_t>(
dropout_scale *
((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
}
}
}
__syncthreads(); // p.use_dropout should have same value kernel-wide
}
#endif
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
@ -830,6 +1101,116 @@ struct AttentionKernel {
}
}
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
    /* Iterates over the accumulator and the corresponding position in the
      result matrix.
      (1) Update `mi[r]` to the max value of the row `r`
      (2) In a second iteration do the following:
      (a) accum   <- exp(accum - mi)
      (b) m_prime <- exp(m_prime - mi)
      (c) s_prime <- s_prime * m_prime + sum(accum)
      All of this is done in registers, before we store it all to shared
      memory for the next matmul with Value.
    */
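    /* Per-row sketch of the recurrence implemented below (illustrative names
      m_old, m_new, alpha; logits are taken in "scaled" units, i.e. after
      multiplying by `scaling`):
        m_new      = max(m_old, max_j(scaling * frag[j]))   // stored in mi
        alpha      = exp(m_old - m_new)                      // stored in m_prime
        frag[j]    = exp(scaling * frag[j] - m_new)
        s_prime    = s_prime * alpha + sum_j(frag[j])
        frag_o[:] *= alpha        // only when kKeepOutputInRF && !kIsFirst
      The code below computes the same quantities with base-2 exponentials
      (exp2f / kLog2e) and warp-level reductions for efficiency. */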
using Fragment = typename WarpIteratorC::Fragment;
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
WarpIteratorC,
accum_t,
kWarpSize>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (LambdaIterator::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
static CUTLASS_DEVICE int8_t lane_id() {
return threadIdx.x;
}
@ -849,3 +1230,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
}
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);

View File

@ -0,0 +1,88 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
 * 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
template <
typename scalar_t, // scalar type
typename ThreadblockTileShape, // size of tile to load
int Threads, // number of participating threads
int ElementsPerAccess> // thread access width in elements
class TileSmemLoader {
public:
using SmemTile =
cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
cutlass::layout::PitchLinearShape<
ThreadblockTileShape::kColumn, // contiguous
ThreadblockTileShape::kRow>, // strided
Threads, // Threads
ElementsPerAccess>; // ElementsPerAccess
using GmemTileIterator =
cutlass::transform::threadblock::PredicatedTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using Fragment = typename GmemTileIterator::Fragment;
/// load a tile from global memory into shared memory
CUTLASS_DEVICE
static void load(
GmemTileIterator tile_load_iter,
SmemTileIterator tile_store_iter) {
Fragment tb_frag;
tb_frag.clear();
tile_load_iter.load(tb_frag);
tile_store_iter.store(tb_frag);
__syncthreads();
}
};
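// Illustrative usage sketch (hypothetical names Loader, ld, gmem_ptr, rows,
// cols, thread_id, smem_tile), mirroring how the attention kernel loads its
// bias tile into shared memory:
//
//   using Loader = TileSmemLoader<
//       cutlass::half_t,
//       cutlass::MatrixShape<64, 64>, // ThreadblockTileShape
//       128,                          // Threads
//       8>;                           // ElementsPerAccess (128b accesses)
//   __shared__ typename Loader::SmemTile smem_tile;
//   typename Loader::GmemTileIterator gmem_iter(
//       {cutlass::layout::RowMajor(ld)}, gmem_ptr, {rows, cols}, thread_id);
//   typename Loader::SmemTileIterator smem_iter(
//       {smem_tile.data(), cutlass::layout::RowMajor(64)}, thread_id);
//   Loader::load(gmem_iter, smem_iter);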