fMHA: Sync FW with xFormers (#828)
* fMHA: Add support for bias+dropout in FW * Remove 'getMaximumSharedMemoryPerBlockKb' * fix comments --------- Co-authored-by: danthe3rd <danthe3rd> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -50,12 +50,17 @@
|
||||
#if 1
|
||||
#define PRINT_WARP_ID 0
|
||||
#define PRINT_LANE_ID 0
|
||||
#define PRINT_T0_L0(msg, ...) \
|
||||
#define PRINT_B0_T0(msg, ...) \
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
|
||||
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
|
||||
threadIdx.z == 0) { \
|
||||
printf(msg "\n", ##__VA_ARGS__); \
|
||||
}
|
||||
#define PRINT_T0(msg, ...) \
|
||||
if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
|
||||
threadIdx.z == 0) { \
|
||||
printf(msg "\n", ##__VA_ARGS__); \
|
||||
}
|
||||
#define PRINT_TX_LX(msg, ...) \
|
||||
for (int bx = 0; bx < gridDim.x; ++bx) { \
|
||||
for (int by = 0; by < gridDim.y; ++by) { \
|
||||
@ -84,7 +89,7 @@
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define PRINT_T0_L0
|
||||
#define PRINT_B0_T0
|
||||
#define PRINT_TX_LX
|
||||
#endif
|
||||
|
||||
@ -124,7 +129,7 @@ constexpr __string_view __get_type_name() {
|
||||
|
||||
// Print a given array
|
||||
#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \
|
||||
PRINT_T0_L0( \
|
||||
PRINT_B0_T0( \
|
||||
"%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
|
||||
name, \
|
||||
int(start), \
|
||||
@ -141,7 +146,7 @@ constexpr __string_view __get_type_name() {
|
||||
#define PRINT_FRAG_T0_L0(name, frag) \
|
||||
{ \
|
||||
auto typeStr = __get_type_name<decltype(frag)>(); \
|
||||
PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \
|
||||
PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \
|
||||
for (int _start = 0; _start < frag.size(); _start += 8) { \
|
||||
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
|
||||
} \
|
||||
@ -150,7 +155,7 @@ constexpr __string_view __get_type_name() {
|
||||
}
|
||||
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \
|
||||
{ \
|
||||
PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \
|
||||
PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \
|
||||
for (int _start = 0; _start < length; _start += incr) { \
|
||||
PRINT_ACCUM8_T0_L0_START(" ", array, _start); \
|
||||
} \
|
||||
@ -160,7 +165,7 @@ constexpr __string_view __get_type_name() {
|
||||
|
||||
// Print a 4x4 matrix
|
||||
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \
|
||||
PRINT_T0_L0( \
|
||||
PRINT_B0_T0( \
|
||||
"%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \
|
||||
name, \
|
||||
int(start_x), \
|
||||
@ -187,9 +192,43 @@ constexpr __string_view __get_type_name() {
|
||||
PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
|
||||
|
||||
#define PRINT_PROBLEM_SIZE(name, ps) \
|
||||
PRINT_T0_L0( \
|
||||
PRINT_B0_T0( \
|
||||
"%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
|
||||
name, \
|
||||
int(ps.m()), \
|
||||
int(ps.n()), \
|
||||
int(ps.k()))
|
||||
|
||||
template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
|
||||
CUTLASS_DEVICE void print_warp_accum(
|
||||
AccumT accum,
|
||||
LaneOffsetT lane_offset,
|
||||
int32_t num_rows,
|
||||
int32_t num_cols) {
|
||||
bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
|
||||
threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
|
||||
for (int row = 0; row < num_rows; ++row) {
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
if (col % 32 == 0) {
|
||||
if (is_main) {
|
||||
printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (row == accum_m && col == accum_n &&
|
||||
(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) {
|
||||
printf(" %6.1f", float(accum[idx]));
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
if (is_main) {
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,9 +50,8 @@
|
||||
|
||||
#include "fmha_grouped.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "find_default_mma.h"
|
||||
#include "attention_scaling_coefs_updater.h"
|
||||
#include "mma_from_smem.h"
|
||||
#include "gemm/find_default_mma.h"
|
||||
#include "gemm/mma_from_smem.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -154,10 +153,10 @@ struct DefaultFMHAGrouped {
|
||||
using IteratorA = typename DefaultMma::IteratorA;
|
||||
using IteratorB = typename DefaultMma::IteratorB;
|
||||
using Mma = typename DefaultMma::ThreadblockMma;
|
||||
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
|
||||
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
||||
typename Mma::Operator::IteratorC,
|
||||
ElementAccumulator,
|
||||
kWarpSize>::Updater;
|
||||
kWarpSize>::Iterator;
|
||||
|
||||
static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, "");
|
||||
|
||||
@ -240,7 +239,8 @@ struct DefaultFMHAGrouped {
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage>;
|
||||
typename MM0::AccumulatorSharedStorage,
|
||||
false>; // kScaleOperandA
|
||||
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
@ -48,7 +48,18 @@
|
||||
|
||||
#include "fmha_grouped_problem_visitor.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "epilogue_rescale_output.h"
|
||||
#include "gemm/mma_accum_lambda_iterator.h"
|
||||
#include "epilogue/epilogue_rescale_output.h"
|
||||
|
||||
|
||||
namespace {
|
||||
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
|
||||
// source: https://stackoverflow.com/a/51549250
|
||||
return (value >= 0)
|
||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -128,6 +139,9 @@ public:
|
||||
static int const kQueriesPerBlock = ThreadblockShape::kM;
|
||||
static int const kKeysPerBlock = ThreadblockShape::kN;
|
||||
|
||||
static constexpr bool kSupportsDropout = false;
|
||||
static constexpr bool kSupportsBias = false;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename MM1::WarpCount;
|
||||
static int const kThreadsPerWarp = 32;
|
||||
@ -619,10 +633,10 @@ public:
|
||||
|
||||
// Mask out last if causal
|
||||
if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
|
||||
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
|
||||
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
int32_t last_col;
|
||||
MM0::ScalingCoefsUpdater::iterateRows(
|
||||
MM0::AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
|
||||
@ -641,14 +655,11 @@ public:
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also updates `accum` with accum[i] <-
|
||||
// exp(accum[i] * scale
|
||||
// - mi)
|
||||
MM0::ScalingCoefsUpdater::update<
|
||||
kQueriesPerBlock,
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<
|
||||
typename MM0::Mma::Operator::IteratorC,
|
||||
kFullColumns,
|
||||
kIsFirst,
|
||||
kKeepOutputInRF>(
|
||||
kIsFirst>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
@ -659,7 +670,7 @@ public:
|
||||
warp_id(),
|
||||
num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
params.scale);
|
||||
kSupportsBias ? 1.0f : params.scale);
|
||||
}));
|
||||
}));
|
||||
|
||||
@ -838,6 +849,116 @@ public:
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename WarpIteratorC,
|
||||
bool kFullColumns,
|
||||
bool kIsFirst>
|
||||
CUTLASS_DEVICE static void iterative_softmax(
|
||||
typename WarpIteratorC::Fragment& frag_o, // output so far
|
||||
typename WarpIteratorC::Fragment& frag,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
||||
int8_t lane_id,
|
||||
int8_t thread_id,
|
||||
int8_t warp_id,
|
||||
int16_t max_col,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
float scaling) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
|
||||
(1) Update `mi[r]` to the max value of the row `r`
|
||||
(2) In a second iteration do the following:
|
||||
(a) accum <- exp(accum - mi)
|
||||
(b) m_prime <- exp(m_prime - mi)
|
||||
(c) s_prime <- s_prime * m_prime + sum(accum)
|
||||
|
||||
All of this is done on registers, before we store all of this
|
||||
on shared memory for the next matmul with Value.
|
||||
*/
|
||||
using Fragment = typename WarpIteratorC::Fragment;
|
||||
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
||||
WarpIteratorC,
|
||||
accum_t,
|
||||
kThreadsPerWarp>::Iterator;
|
||||
// Convert to `accum_t` (rather than double)
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (!kIsFirst) {
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
m_prime[thread_id] = mi[thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
auto lane_offset =
|
||||
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
|
||||
|
||||
// First update `mi` to the max per-row
|
||||
{
|
||||
accum_t max;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
max = cutlass::fast_max(max, frag[idx]);
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {
|
||||
// Having 4x atomicMax seems faster than reduce within warp
|
||||
// first...
|
||||
atomicMaxFloat(&mi[accum_m], max * scaling);
|
||||
});
|
||||
}
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
// Make sure we all share the update values for `mi`
|
||||
__syncthreads();
|
||||
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
|
||||
m_prime[thread_id] = m_prime_exp;
|
||||
s_prime[thread_id] *= m_prime_exp;
|
||||
}
|
||||
__syncthreads(); // Update output fragments
|
||||
if (kKeepOutputInRF && !kIsFirst) {
|
||||
accum_t mp;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mp = m_prime[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
// Update accum_m, accum_n, ...
|
||||
{
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] = (kFullColumns || accum_n < max_col)
|
||||
? exp2f(frag[idx] - mi_row)
|
||||
: accum_t(0.0);
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { total_row = 0.0; },
|
||||
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
|
||||
[&](int accum_m) {
|
||||
if (LambdaIterator::reduceSameRow(
|
||||
lane_id, total_row, [](accum_t a, accum_t b) {
|
||||
return a + b;
|
||||
})) {
|
||||
atomicAdd(&s_prime[accum_m], total_row);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -856,7 +856,9 @@ public:
|
||||
p.head_dim_value = options.head_size_v;
|
||||
p.num_queries = options.seq_length;
|
||||
p.num_keys = options.seq_length_kv;
|
||||
p.causal = options.causal;
|
||||
if (options.causal) {
|
||||
p.custom_mask_type = Attention::CausalFromTopLeft;
|
||||
}
|
||||
|
||||
// All tensors are in BMHK shapes
|
||||
p.q_strideH = options.head_size;
|
||||
@ -868,6 +870,7 @@ public:
|
||||
p.q_strideB = p.q_strideM * options.seq_length;
|
||||
p.k_strideB = p.k_strideM * options.seq_length_kv;
|
||||
p.v_strideB = p.v_strideM * options.seq_length_kv;
|
||||
p.o_strideM = p.head_dim_value * p.num_heads;
|
||||
}
|
||||
|
||||
// launch kernel :)
|
||||
@ -1005,7 +1008,9 @@ int run_attention(Options& options) {
|
||||
true, // Memory is aligned
|
||||
kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration
|
||||
kSingleValueIteration,
|
||||
false, // Supports dropout
|
||||
false // Supports bias
|
||||
>;
|
||||
|
||||
//
|
||||
|
||||
@ -42,6 +42,8 @@
|
||||
This is really only for the FastF32 case - aka using TensorCores with fp32.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
@ -36,137 +36,15 @@
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
|
||||
// source: https://stackoverflow.com/a/51549250
|
||||
return (value >= 0)
|
||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
|
||||
(1) Update `mi[r]` to the max value of the row `r`
|
||||
(2) In a second iteration do the following:
|
||||
(a) accum <- exp(accum - mi)
|
||||
(b) m_prime <- exp(m_prime - mi)
|
||||
(c) s_prime <- s_prime * m_prime + sum(accum)
|
||||
|
||||
All of this is done on registers, before we store all of this
|
||||
on shared memory for the next matmul with Value.
|
||||
|
||||
We have multiple implementations, because each configuration has a different way
|
||||
of iterating in the accumulators.
|
||||
/*
|
||||
TensorCores have different accumulator layouts.
|
||||
This file provides a class to easily map the accumulator
|
||||
i-th element with the corresponding matrix row/col.
|
||||
*/
|
||||
|
||||
template <typename BASE, typename T, typename accum_t, int kWarpSize>
|
||||
struct RegisterOps {
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
bool kFullColumns,
|
||||
bool kIsFirst,
|
||||
bool kKeepOutputInRF>
|
||||
CUTLASS_DEVICE static void update(
|
||||
typename T::Fragment& frag_o, // output so far
|
||||
typename T::Fragment& frag,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
||||
int8_t lane_id,
|
||||
int8_t thread_id,
|
||||
int8_t warp_id,
|
||||
int16_t max_col,
|
||||
typename T::TensorCoord const& tile_offset,
|
||||
float scaling) {
|
||||
// Convert to `accum_t` (rather than double)
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (!kIsFirst) {
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
m_prime[thread_id] = mi[thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset);
|
||||
|
||||
// First update `mi` to the max per-row
|
||||
{
|
||||
accum_t max;
|
||||
BASE::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
max = cutlass::fast_max(max, frag[idx]);
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {
|
||||
// Having 4x atomicMax seems faster than reduce within warp
|
||||
// first...
|
||||
atomicMaxFloat(&mi[accum_m], max * scaling);
|
||||
});
|
||||
}
|
||||
frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
// Make sure we all share the update values for `mi`
|
||||
__syncthreads();
|
||||
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
|
||||
m_prime[thread_id] = m_prime_exp;
|
||||
s_prime[thread_id] *= m_prime_exp;
|
||||
}
|
||||
__syncthreads(); // Update output fragments
|
||||
if (kKeepOutputInRF && !kIsFirst) {
|
||||
accum_t mp;
|
||||
BASE::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mp = m_prime[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
// Update accum_m, accum_n, ...
|
||||
{
|
||||
accum_t mi_row, total_row;
|
||||
BASE::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] = (kFullColumns || accum_n < max_col)
|
||||
? exp2f(frag[idx] - mi_row)
|
||||
: accum_t(0.0);
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
BASE::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { total_row = 0.0; },
|
||||
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
|
||||
[&](int accum_m) {
|
||||
if (BASE::reduceSameRow(
|
||||
lane_id, total_row, [](accum_t a, accum_t b) {
|
||||
return a + b;
|
||||
})) {
|
||||
atomicAdd(&s_prime[accum_m], total_row);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename accum_t, int kWarpSize>
|
||||
struct AttentionScalingCoefsUpdaterSm80
|
||||
: RegisterOps<
|
||||
AttentionScalingCoefsUpdaterSm80<T, accum_t, kWarpSize>,
|
||||
T,
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
struct AccumLambdaIteratorSm80 {
|
||||
static_assert(
|
||||
cutlass::platform::
|
||||
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
|
||||
@ -239,12 +117,7 @@ struct AttentionScalingCoefsUpdaterSm80
|
||||
};
|
||||
|
||||
template <typename T, typename accum_t, int kWarpSize>
|
||||
struct AttentionScalingCoefsUpdaterVolta
|
||||
: RegisterOps<
|
||||
AttentionScalingCoefsUpdaterVolta<T, accum_t, kWarpSize>,
|
||||
T,
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
struct AccumLambdaIteratorSm70 {
|
||||
static_assert(
|
||||
cutlass::platform::
|
||||
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
|
||||
@ -357,12 +230,7 @@ struct AttentionScalingCoefsUpdaterVolta
|
||||
};
|
||||
|
||||
template <typename T, typename accum_t, int kWarpSize>
|
||||
struct AttentionScalingCoefsUpdaterSimt
|
||||
: RegisterOps<
|
||||
AttentionScalingCoefsUpdaterSimt<T, accum_t, kWarpSize>,
|
||||
T,
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
struct AccumLambdaIteratorSimt {
|
||||
using Policy = typename T::Policy;
|
||||
using Iterations = typename T::Iterations;
|
||||
using Element = typename T::Element;
|
||||
@ -436,11 +304,11 @@ struct AttentionScalingCoefsUpdaterSimt
|
||||
};
|
||||
|
||||
template <typename T, typename accum_t, int kWarpSize>
|
||||
struct DefaultAttentionScalingCoefsUpdater;
|
||||
struct DefaultMmaAccumLambdaIterator;
|
||||
|
||||
// Simt
|
||||
template <typename S, typename P, typename accum_t, int kWarpSize>
|
||||
struct DefaultAttentionScalingCoefsUpdater<
|
||||
struct DefaultMmaAccumLambdaIterator<
|
||||
cutlass::gemm::warp::MmaSimtTileIterator<
|
||||
S,
|
||||
cutlass::gemm::Operand::kC,
|
||||
@ -451,7 +319,7 @@ struct DefaultAttentionScalingCoefsUpdater<
|
||||
1>,
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
|
||||
using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
|
||||
S,
|
||||
cutlass::gemm::Operand::kC,
|
||||
accum_t,
|
||||
@ -459,13 +327,12 @@ struct DefaultAttentionScalingCoefsUpdater<
|
||||
P,
|
||||
1,
|
||||
1>;
|
||||
using Updater =
|
||||
AttentionScalingCoefsUpdaterSimt<Iterator, accum_t, kWarpSize>;
|
||||
using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
|
||||
};
|
||||
|
||||
// TensorOp - Volta
|
||||
template <typename S1, typename S2, typename accum_t, int kWarpSize>
|
||||
struct DefaultAttentionScalingCoefsUpdater<
|
||||
struct DefaultMmaAccumLambdaIterator<
|
||||
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
|
||||
S1,
|
||||
accum_t,
|
||||
@ -474,15 +341,14 @@ struct DefaultAttentionScalingCoefsUpdater<
|
||||
cutlass::MatrixShape<1, 1>>,
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
using Iterator =
|
||||
using WarpIterator =
|
||||
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
|
||||
S1,
|
||||
accum_t,
|
||||
cutlass::layout::RowMajor,
|
||||
S2,
|
||||
cutlass::MatrixShape<1, 1>>;
|
||||
using Updater =
|
||||
AttentionScalingCoefsUpdaterVolta<Iterator, accum_t, kWarpSize>;
|
||||
using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
|
||||
};
|
||||
|
||||
// TensorOp - Sm75+
|
||||
@ -492,7 +358,7 @@ template <
|
||||
typename S3,
|
||||
typename accum_t,
|
||||
int kWarpSize>
|
||||
struct DefaultAttentionScalingCoefsUpdater<
|
||||
struct DefaultMmaAccumLambdaIterator<
|
||||
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
|
||||
S1,
|
||||
accum_t,
|
||||
@ -501,13 +367,12 @@ struct DefaultAttentionScalingCoefsUpdater<
|
||||
S3>,
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
using Iterator =
|
||||
using WarpIterator =
|
||||
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
|
||||
S1,
|
||||
accum_t,
|
||||
cutlass::layout::RowMajor,
|
||||
S2,
|
||||
S3>;
|
||||
using Updater =
|
||||
AttentionScalingCoefsUpdaterSm80<Iterator, accum_t, kWarpSize>;
|
||||
using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
|
||||
};
|
||||
@ -43,22 +43,26 @@
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
|
||||
#include "attention_scaling_coefs_updater.h"
|
||||
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
|
||||
#include "../gemm/mma_accum_lambda_iterator.h"
|
||||
#include "../gemm_kernel_utils.h"
|
||||
#include "../iterators/make_residual_last.h"
|
||||
#include "../iterators/transpose_warp_iterator.h"
|
||||
#include "../iterators/warp_iterator_from_smem.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
#include "cutlass/gemm/threadblock/mma_multistage.h"
|
||||
#include "cutlass/gemm/threadblock/mma_pipelined.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
|
||||
#include "epilogue_thread_apply_logsumexp.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "iterators/make_residual_last.h"
|
||||
#include "iterators/transpose_warp_iterator.h"
|
||||
#include "iterators/warp_iterator_from_smem.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -246,6 +250,78 @@ class MmaBaseFromSharedMemory {
|
||||
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// has necessary trait compliance with WarpIteratorFromSmem but doesn't do
|
||||
// anything, can be default initialized, and uses fragment that takes up
|
||||
// (almost) no space. this warp iterator is selected at compile time when
|
||||
// elementwise on-the-fly scaling for operand A is disabled, in which case
|
||||
// operations related to loading scale factors for operand A get wiped out by
|
||||
// the compiler.
|
||||
template <typename TensorRef>
|
||||
class NoOpWarpIteratorScale {
|
||||
public:
|
||||
// in pipelined+multistage MMA implementations we keep an array of fragments.
|
||||
// if we aren't using scaling we don't want to waste registers on fragments
|
||||
// of scale elements, so ideally this would be sized 0.
|
||||
// using size 1 is kind of a hack to get around arrays of zero-sized objects
|
||||
// not being allowed. the compiler is probably smart enough to wipe it out
|
||||
// anyways.
|
||||
using Fragment = cutlass::Array<char, 1>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
NoOpWarpIteratorScale() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
NoOpWarpIteratorScale(TensorRef const&, int) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
NoOpWarpIteratorScale& add_tile_offset(
|
||||
typename TensorRef::TensorCoord const&) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
NoOpWarpIteratorScale& operator++() {
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment&) const {}
|
||||
};
|
||||
|
||||
// if scaling is enabled, performs fragment elementwise multiplication between
|
||||
// fragment and its scaling factor.
|
||||
template <typename Fragment, typename FragmentScale, bool ScalingEnabled>
|
||||
class FragmentElementwiseScaler;
|
||||
|
||||
// specialization for scaling being enabled.
|
||||
template <typename Fragment, typename FragmentScale>
|
||||
class FragmentElementwiseScaler<Fragment, FragmentScale, true> {
|
||||
public:
|
||||
// cast scale_frag to correct type then apply elementwise to fragment
|
||||
CUTLASS_DEVICE
|
||||
static Fragment apply(Fragment frag, FragmentScale const& scale_frag) {
|
||||
Fragment converted_scale_frag = cutlass::NumericArrayConverter<
|
||||
typename Fragment::Element,
|
||||
typename FragmentScale::Element,
|
||||
FragmentScale::kElements>()(scale_frag);
|
||||
return cutlass::multiplies<Fragment>()(frag, converted_scale_frag);
|
||||
}
|
||||
};
|
||||
|
||||
// specialization for scaling being disabled. doesn't do anything and should
|
||||
// just get wiped out by the compiler.
|
||||
template <typename Fragment, typename FragmentScale>
|
||||
class FragmentElementwiseScaler<Fragment, FragmentScale, false> {
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static Fragment apply(Fragment frag, FragmentScale const&) {
|
||||
return frag;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Taken from
|
||||
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
|
||||
@ -259,6 +335,10 @@ template <
|
||||
// BEGIN smem
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA,
|
||||
/// whether or not to perform elementwise multiplication of A
|
||||
// by another matrix (A_scale) that is also kept in shared memory prior
|
||||
// to matmul A @ B
|
||||
bool ScaleOperandA_,
|
||||
// Accumulator type
|
||||
typename AccumulatorSharedStorage,
|
||||
// END smem
|
||||
@ -297,6 +377,15 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
|
||||
using Shape =
|
||||
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
static constexpr bool ScaleOperandA = ScaleOperandA_;
|
||||
|
||||
///< loads fragments of A_scale from shared memory if operand A scaling is
|
||||
///< enabled. otherwise no-op.
|
||||
using WarpIteratorAScale = typename cutlass::platform::conditional<
|
||||
ScaleOperandA,
|
||||
WarpIteratorA,
|
||||
NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;
|
||||
|
||||
using IteratorB =
|
||||
IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
@ -333,8 +422,20 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
|
||||
/// fragment type of OperandA elementwise scaling matrix. (almost) empty
|
||||
/// if operand A scaling is disabled.
|
||||
using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;
|
||||
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
|
||||
/// applies scaling factor to operand A fragment if operand A scaling is
|
||||
/// enabled. otherwise no-op.
|
||||
using FragmentAScaler = FragmentElementwiseScaler<
|
||||
WarpFragmentA,
|
||||
WarpFragmentAScale,
|
||||
ScaleOperandA>;
|
||||
|
||||
protected:
|
||||
// /// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
// SmemIteratorA smem_iterator_A_;
|
||||
@ -346,7 +447,46 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
/// accumulator tile
|
||||
WarpIteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A_scale from intermediate
|
||||
/// accumulator tile (only used if ScaleOperandA_ is true)
|
||||
WarpIteratorAScale warp_tile_iterator_A_scale_;
|
||||
|
||||
public:
|
||||
/// constructor for MMA with operand A scaling enabled.
|
||||
CUTLASS_DEVICE
|
||||
MmaPipelinedFromSharedMemory(
|
||||
// shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
// warp iterator over A tile held in shared memory
|
||||
WarpIteratorA warp_iter_a,
|
||||
// warp iterator over A_scale tile held in shared memory
|
||||
WarpIteratorAScale warp_iter_a_scale,
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(warp_iter_a),
|
||||
warp_tile_iterator_A_scale_(warp_iter_a_scale),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_A_scale_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaPipelinedFromSharedMemory(
|
||||
@ -429,19 +569,26 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op
|
||||
// if scaling is disabled.
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentAScale warp_frag_A_scale[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
warp_frag_A[0].clear();
|
||||
warp_frag_A_scale[0].clear();
|
||||
warp_frag_B[0].clear();
|
||||
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_A_scale_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
Operator warp_mma;
|
||||
@ -503,9 +650,12 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
(warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_A_scale_.load(
|
||||
warp_frag_A_scale[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_A_scale_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
@ -521,7 +671,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_frag_A[warp_mma_k % 2],
|
||||
FragmentAScaler::apply(
|
||||
warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]),
|
||||
warp_frag_B[warp_mma_k % 2],
|
||||
accum);
|
||||
}
|
||||
@ -541,6 +692,10 @@ template <
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// whether or not to perform elementwise multiplication of A
|
||||
// by another matrix (A_scale) that is also kept in shared memory prior
|
||||
// to matmul A @ B
|
||||
bool ScaleOperandA_,
|
||||
// Accumulator type
|
||||
typename AccumulatorSharedStorage,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
@ -580,7 +735,14 @@ class MmaMultistageFromSharedMemory
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate
|
||||
///< accumulator tile in shared memory
|
||||
static constexpr bool ScaleOperandA = ScaleOperandA_;
|
||||
|
||||
///< warp level iterator over A_scale matrix tile kept in shared memory.
|
||||
///< if elementwise A scaling is disabled then everything this does is no-op.
|
||||
using WarpIteratorAScale = typename cutlass::platform::conditional<
|
||||
ScaleOperandA,
|
||||
WarpIteratorA1,
|
||||
NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
@ -628,10 +790,20 @@ class MmaMultistageFromSharedMemory
|
||||
|
||||
private:
|
||||
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
|
||||
/// fragment of OperandA scale matrix. if operand A scaling is disabled this
|
||||
/// is (almost) empty.
|
||||
using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
/// applies elementwise scaling to fragment of A. if operand A scaling is
|
||||
/// disabled this is a no-op.
|
||||
using FragmentAScaler = FragmentElementwiseScaler<
|
||||
WarpLoadedFragmentA1,
|
||||
WarpLoadedFragmentA1Scale,
|
||||
ScaleOperandA>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
@ -641,12 +813,54 @@ class MmaMultistageFromSharedMemory
|
||||
/// accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1_scale operand from shared memory
|
||||
/// if operand A scaling is disabled everything this does is a no-op.
|
||||
WarpIteratorAScale warp_tile_iterator_A1_scale_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
bool prologue_done_;
|
||||
|
||||
public:
|
||||
/// constructor for MMA with operand A scaling enabled.
|
||||
CUTLASS_DEVICE
|
||||
MmaMultistageFromSharedMemory(
|
||||
// shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
// warp level iterator over operand A tile kept in shared memory
|
||||
WarpIteratorA1 warp_tile_iterator_A1,
|
||||
// warp level iterator over operand A elementwise scale tile kept in
|
||||
// shared memory.
|
||||
WarpIteratorAScale warp_tile_iterator_A1_scale,
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(warp_tile_iterator_A1),
|
||||
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
|
||||
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
|
||||
prologue_done_(false) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
int warp_idx_mn_1 =
|
||||
warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
warp_tile_iterator_A1_.add_tile_offset(
|
||||
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
|
||||
warp_tile_iterator_A1_scale_.add_tile_offset(
|
||||
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
|
||||
}
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaMultistageFromSharedMemory(
|
||||
@ -842,9 +1056,13 @@ class MmaMultistageFromSharedMemory
|
||||
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
|
||||
__syncthreads();
|
||||
|
||||
// remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty
|
||||
// if scaling is disabled.
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
@ -854,6 +1072,9 @@ class MmaMultistageFromSharedMemory
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
@ -864,7 +1085,8 @@ class MmaMultistageFromSharedMemory
|
||||
warp_mma1.transform(
|
||||
warp_transformed_frag_A1[0],
|
||||
warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0],
|
||||
FragmentAScaler::apply(
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]),
|
||||
warp_loaded_frag_B1[0]);
|
||||
|
||||
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
|
||||
@ -909,17 +1131,22 @@ class MmaMultistageFromSharedMemory
|
||||
warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
warp_tile_iterator_A1_.load(
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
warp_tile_iterator_A1_scale_.load(
|
||||
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B_.load(
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
++warp_tile_iterator_A1_;
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
FragmentAScaler::apply(
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1_scale[warp_mma_k % 2]),
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
if (platform::is_same<
|
||||
@ -1009,7 +1236,9 @@ class MmaMultistageFromSharedMemory
|
||||
warp_mma1.transform(
|
||||
warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
FragmentAScaler::apply(
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]),
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
@ -1119,6 +1348,9 @@ struct DefaultWarpIteratorAFromSharedMemory<
|
||||
template <
|
||||
typename Mma_,
|
||||
typename AccumulatorSharedStorage,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
bool kTransposeA = false>
|
||||
struct DefaultMmaFromSharedMemory;
|
||||
|
||||
@ -1151,6 +1383,9 @@ template <
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB_,
|
||||
typename AccumulatorSharedStorage_,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
bool kTransposeA>
|
||||
struct DefaultMmaFromSharedMemory<
|
||||
MmaPipelined<
|
||||
@ -1165,6 +1400,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
TransformA_,
|
||||
TransformB_>,
|
||||
AccumulatorSharedStorage_,
|
||||
kScaleOperandA,
|
||||
kTransposeA> {
|
||||
static constexpr int kWarpSize = 32;
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
@ -1198,6 +1434,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory<
|
||||
Shape_,
|
||||
WarpIteratorA,
|
||||
kScaleOperandA,
|
||||
AccumulatorSharedStorage_,
|
||||
IteratorB,
|
||||
SmemIteratorB_,
|
||||
@ -1238,6 +1475,9 @@ template <
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
typename AccumulatorSharedStorage_,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
bool kTransposeA>
|
||||
struct DefaultMmaFromSharedMemory<
|
||||
MmaMultistage<
|
||||
@ -1254,6 +1494,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
Stages,
|
||||
SharedMemoryClear>,
|
||||
AccumulatorSharedStorage_,
|
||||
kScaleOperandA,
|
||||
kTransposeA> {
|
||||
static constexpr int kWarpSize = 32;
|
||||
|
||||
@ -1301,6 +1542,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
|
||||
Shape_,
|
||||
WarpIteratorA,
|
||||
kScaleOperandA,
|
||||
AccumulatorSharedStorage_,
|
||||
IteratorB,
|
||||
SmemIteratorB_,
|
||||
@ -1637,18 +1879,17 @@ struct B2bGemm<
|
||||
// NOTE: accum is attn.T
|
||||
// TODO: Optimize for each architecture
|
||||
static constexpr int WarpSize = 32;
|
||||
using RegistersIter = typename DefaultAttentionScalingCoefsUpdater<
|
||||
IteratorC,
|
||||
accum_t,
|
||||
WarpSize>::Updater;
|
||||
using AccumLambdaIterator =
|
||||
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
|
||||
Iterator;
|
||||
auto lane_offset =
|
||||
RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords);
|
||||
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
|
||||
|
||||
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
|
||||
lse_prefetched.clear();
|
||||
int rowIdx = 0;
|
||||
int colIdx = 0;
|
||||
RegistersIter::iterateRows(
|
||||
AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
++rowIdx;
|
||||
@ -1777,18 +2018,17 @@ struct B2bGemm<
|
||||
// NOTE: accum is attn.T
|
||||
// TODO: Optimize for each architecture
|
||||
static constexpr int WarpSize = 32;
|
||||
using RegistersIter = typename DefaultAttentionScalingCoefsUpdater<
|
||||
IteratorC,
|
||||
accum_t,
|
||||
WarpSize>::Updater;
|
||||
using AccumLambdaIterator =
|
||||
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
|
||||
Iterator;
|
||||
auto lane_offset =
|
||||
RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords);
|
||||
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
|
||||
|
||||
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
|
||||
lse_prefetched.clear();
|
||||
int rowIdx = 0;
|
||||
int colIdx = 0;
|
||||
RegistersIter::iterateRows(
|
||||
AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
++rowIdx;
|
||||
@ -29,16 +29,26 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
#endif
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "attention_scaling_coefs_updater.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
@ -54,11 +64,12 @@
|
||||
#include "cutlass/platform/platform.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "debug_utils.h"
|
||||
#include "epilogue_pipelined.h"
|
||||
#include "epilogue_rescale_output.h"
|
||||
#include "find_default_mma.h"
|
||||
#include "epilogue/epilogue_pipelined.h"
|
||||
#include "epilogue/epilogue_rescale_output.h"
|
||||
#include "gemm/find_default_mma.h"
|
||||
#include "gemm/mma_from_smem.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "mma_from_smem.h"
|
||||
#include "transform/tile_smem_loader.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
@ -73,6 +84,12 @@ constexpr int getWarpsPerSm() {
|
||||
? 16
|
||||
: 12);
|
||||
}
|
||||
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
|
||||
// source: https://stackoverflow.com/a/51549250
|
||||
return (value >= 0)
|
||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <
|
||||
@ -83,10 +100,20 @@ template <
|
||||
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
|
||||
bool isAligned_,
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
|
||||
>
|
||||
int kKeysPerBlock_,
|
||||
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
|
||||
// This is quite slower on V100 for some reason
|
||||
// Set to false if you know at compile-time you will never need dropout
|
||||
bool kSupportsDropout_ = true,
|
||||
bool kSupportsBias_ = true>
|
||||
struct AttentionKernel {
|
||||
enum CustomMaskType {
|
||||
NoCustomMask = 0,
|
||||
CausalFromTopLeft = 1,
|
||||
CausalFromBottomRight = 2,
|
||||
NumCustomMaskTypes,
|
||||
};
|
||||
|
||||
using scalar_t = scalar_t_;
|
||||
using accum_t = float;
|
||||
using lse_scalar_t = float;
|
||||
@ -95,7 +122,11 @@ struct AttentionKernel {
|
||||
// Using `accum_t` improves perf on f16 at the cost of
|
||||
// numerical errors
|
||||
using output_accum_t = accum_t;
|
||||
static constexpr bool kSupportsDropout = kSupportsDropout_;
|
||||
static constexpr bool kSupportsBias = kSupportsBias_;
|
||||
static constexpr int kKeysPerBlock = kKeysPerBlock_;
|
||||
static constexpr bool kIsAligned = isAligned_;
|
||||
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
|
||||
static constexpr int32_t kAlignLSE = 32; // block size of backward
|
||||
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
|
||||
cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
@ -117,10 +148,15 @@ struct AttentionKernel {
|
||||
struct Params {
|
||||
// Input tensors
|
||||
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
|
||||
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
|
||||
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
|
||||
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
|
||||
int32_t* cu_seqlens_q_ptr = nullptr;
|
||||
int32_t* cu_seqlens_k_ptr = nullptr;
|
||||
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
|
||||
int32_t* seqstart_q_ptr = nullptr;
|
||||
int32_t* seqstart_k_ptr = nullptr;
|
||||
|
||||
int32_t* causal_diagonal_ptr = nullptr;
|
||||
int32_t* seqlen_k_ptr = nullptr;
|
||||
uint32_t causal_diagonal_offset = 0;
|
||||
|
||||
// Output tensors
|
||||
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
@ -137,26 +173,38 @@ struct AttentionKernel {
|
||||
int32_t num_queries;
|
||||
int32_t num_keys;
|
||||
|
||||
bool causal;
|
||||
uint8_t custom_mask_type = NoCustomMask;
|
||||
|
||||
int32_t q_strideM;
|
||||
int32_t k_strideM;
|
||||
int32_t v_strideM;
|
||||
int32_t bias_strideM = 0;
|
||||
|
||||
int32_t o_strideM = 0;
|
||||
|
||||
// Everything below is only used in `advance_to_block`
|
||||
// and shouldn't use registers
|
||||
int32_t q_strideH;
|
||||
int32_t k_strideH;
|
||||
int32_t v_strideH;
|
||||
int32_t bias_strideH = 0;
|
||||
|
||||
int64_t q_strideB;
|
||||
int64_t k_strideB;
|
||||
int64_t v_strideB;
|
||||
int32_t bias_strideB = 0;
|
||||
|
||||
int32_t num_batches;
|
||||
int32_t num_heads;
|
||||
|
||||
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
|
||||
return head_dim_value * num_heads;
|
||||
}
|
||||
// dropout
|
||||
bool use_dropout;
|
||||
unsigned long long dropout_batch_head_rng_offset;
|
||||
float dropout_prob;
|
||||
#ifdef HAS_PYTORCH
|
||||
at::PhiloxCudaState rng_engine_inputs;
|
||||
#endif
|
||||
|
||||
// Moves pointers to what we should process
|
||||
// Returns "false" if there is no work to do
|
||||
CUTLASS_DEVICE bool advance_to_block() {
|
||||
@ -166,18 +214,33 @@ struct AttentionKernel {
|
||||
|
||||
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
|
||||
|
||||
if (kSupportsDropout) {
|
||||
dropout_batch_head_rng_offset =
|
||||
batch_id * num_heads * num_queries * num_keys +
|
||||
head_id * num_queries * num_keys;
|
||||
}
|
||||
|
||||
int64_t q_start, k_start;
|
||||
// Advance to current batch - in case of different sequence lengths
|
||||
if (cu_seqlens_q_ptr != nullptr) {
|
||||
assert(cu_seqlens_k_ptr != nullptr);
|
||||
cu_seqlens_q_ptr += batch_id;
|
||||
cu_seqlens_k_ptr += batch_id;
|
||||
q_start = cu_seqlens_q_ptr[0];
|
||||
k_start = cu_seqlens_k_ptr[0];
|
||||
int64_t q_next_start = cu_seqlens_q_ptr[1];
|
||||
int64_t k_next_start = cu_seqlens_k_ptr[1];
|
||||
if (seqstart_q_ptr != nullptr) {
|
||||
assert(seqstart_k_ptr != nullptr);
|
||||
seqstart_q_ptr += batch_id;
|
||||
|
||||
q_start = seqstart_q_ptr[0];
|
||||
int64_t q_next_start = seqstart_q_ptr[1];
|
||||
int64_t k_end;
|
||||
seqstart_k_ptr += batch_id;
|
||||
|
||||
if (seqlen_k_ptr) {
|
||||
k_start = seqstart_k_ptr[0];
|
||||
k_end = k_start + seqlen_k_ptr[batch_id];
|
||||
} else {
|
||||
k_start = seqstart_k_ptr[0];
|
||||
k_end = seqstart_k_ptr[1];
|
||||
}
|
||||
|
||||
num_queries = q_next_start - q_start;
|
||||
num_keys = k_next_start - k_start;
|
||||
num_keys = k_end - k_start;
|
||||
|
||||
if (query_start >= num_queries) {
|
||||
return false;
|
||||
@ -186,9 +249,10 @@ struct AttentionKernel {
|
||||
query_ptr += batch_id * q_strideB;
|
||||
key_ptr += batch_id * k_strideB;
|
||||
value_ptr += batch_id * v_strideB;
|
||||
output_ptr += int64_t(batch_id * num_queries) * o_strideM();
|
||||
output_ptr += int64_t(batch_id * num_queries) * o_strideM;
|
||||
if (output_accum_ptr != nullptr) {
|
||||
output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM();
|
||||
output_accum_ptr +=
|
||||
int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
|
||||
}
|
||||
q_start = 0;
|
||||
k_start = 0;
|
||||
@ -197,42 +261,84 @@ struct AttentionKernel {
|
||||
// Advance to the current batch / head / query_start
|
||||
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
|
||||
key_ptr += k_start * k_strideM + head_id * k_strideH;
|
||||
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
||||
output_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||
head_id * head_dim_value;
|
||||
|
||||
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
||||
output_ptr +=
|
||||
int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
|
||||
|
||||
if (kSupportsBias && attn_bias_ptr != nullptr) {
|
||||
attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
|
||||
}
|
||||
if (output_accum_ptr != nullptr) {
|
||||
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||
output_accum_ptr +=
|
||||
int64_t(q_start + query_start) * (head_dim_value * num_heads) +
|
||||
head_id * head_dim_value;
|
||||
} else {
|
||||
// Accumulate directly in the destination buffer (eg for f32)
|
||||
output_accum_ptr = (accum_t*)output_ptr;
|
||||
}
|
||||
|
||||
if (logsumexp_ptr != nullptr) {
|
||||
// lse[batch_id, head_id, query_start]
|
||||
logsumexp_ptr +=
|
||||
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
|
||||
}
|
||||
|
||||
num_queries -= query_start;
|
||||
if (causal) {
|
||||
num_keys = cutlass::fast_min(
|
||||
int32_t(query_start + kQueriesPerBlock), num_keys);
|
||||
// Custom masking
|
||||
if (causal_diagonal_ptr) {
|
||||
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
|
||||
}
|
||||
if (custom_mask_type == CausalFromBottomRight) {
|
||||
causal_diagonal_offset += num_keys - num_queries;
|
||||
}
|
||||
if (custom_mask_type == CausalFromTopLeft ||
|
||||
custom_mask_type == CausalFromBottomRight) {
|
||||
// the bottom row of the current block is query_start + kQueriesPerBlock
|
||||
// the last active key is then query_start + causal_diagonal_offset +
|
||||
// kQueriesPerBlock so num_keys is the min between actual num_keys and
|
||||
// this to avoid extra computations
|
||||
num_keys = cutlass::fast_min(
|
||||
int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
|
||||
num_keys);
|
||||
}
|
||||
|
||||
num_queries -= query_start;
|
||||
num_batches = 0; // no longer used after
|
||||
|
||||
// If num_queries == 1, and there is only one key head we're wasting
|
||||
// 15/16th of tensor core compute In that case :
|
||||
// - we only launch kernels for head_id % kQueriesPerBlock == 0
|
||||
// - we iterate over heads instead of queries (strideM = strideH)
|
||||
if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
|
||||
if (head_id % kQueriesPerBlock != 0)
|
||||
return false;
|
||||
q_strideM = q_strideH;
|
||||
num_queries = num_heads;
|
||||
num_heads = 1; // unused but here for intent
|
||||
// remove causal since n_query = 1
|
||||
// otherwise, offset would change with head !
|
||||
custom_mask_type = NoCustomMask;
|
||||
o_strideM = head_dim_value;
|
||||
}
|
||||
|
||||
// Make sure the compiler knows these variables are the same on all
|
||||
// the threads of the warp.
|
||||
query_ptr = warp_uniform(query_ptr);
|
||||
key_ptr = warp_uniform(key_ptr);
|
||||
value_ptr = warp_uniform(value_ptr);
|
||||
if (kSupportsBias) {
|
||||
attn_bias_ptr = warp_uniform(attn_bias_ptr);
|
||||
}
|
||||
output_ptr = warp_uniform(output_ptr);
|
||||
output_accum_ptr = warp_uniform(output_accum_ptr);
|
||||
logsumexp_ptr = warp_uniform(logsumexp_ptr);
|
||||
num_queries = warp_uniform(num_queries);
|
||||
num_keys = warp_uniform(num_keys);
|
||||
num_heads = warp_uniform(num_heads);
|
||||
head_dim = warp_uniform(head_dim);
|
||||
head_dim_value = warp_uniform(head_dim_value);
|
||||
o_strideM = warp_uniform(o_strideM);
|
||||
custom_mask_type = warp_uniform(custom_mask_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -242,6 +348,7 @@ struct AttentionKernel {
|
||||
num_heads,
|
||||
num_batches);
|
||||
}
|
||||
|
||||
__host__ dim3 getThreadsGrid() const {
|
||||
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
|
||||
}
|
||||
@ -296,16 +403,24 @@ struct AttentionKernel {
|
||||
using IteratorA = typename DefaultMma::IteratorA;
|
||||
using IteratorB = typename DefaultMma::IteratorB;
|
||||
using Mma = typename DefaultMma::ThreadblockMma;
|
||||
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
|
||||
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
||||
typename Mma::Operator::IteratorC,
|
||||
accum_t,
|
||||
kWarpSize>::Updater;
|
||||
kWarpSize>::Iterator;
|
||||
static_assert(
|
||||
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
|
||||
MmaCore::WarpCount::kK ==
|
||||
kNumWarpsPerBlock,
|
||||
"");
|
||||
|
||||
// used for efficient load of bias tile Bij from global to shared memory
|
||||
using BiasLoader = TileSmemLoader<
|
||||
scalar_t,
|
||||
cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
|
||||
MmaCore::kThreads,
|
||||
// input restriction: kv_len has to be a multiple of this value
|
||||
128 / cutlass::sizeof_bits<scalar_t>::value>;
|
||||
|
||||
// Epilogue to store to shared-memory in a format that we can use later for
|
||||
// the second matmul
|
||||
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
|
||||
@ -367,7 +482,8 @@ struct AttentionKernel {
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage>;
|
||||
typename MM0::AccumulatorSharedStorage,
|
||||
false>; // kScaleOperandA
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
@ -404,7 +520,10 @@ struct AttentionKernel {
|
||||
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
union {
|
||||
typename MM0::BiasLoader::SmemTile bias;
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
};
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
};
|
||||
|
||||
@ -423,7 +542,10 @@ struct AttentionKernel {
|
||||
struct SharedStorageEpilogueInLoop : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
union {
|
||||
typename MM0::BiasLoader::SmemTile bias;
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
};
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
@ -448,6 +570,18 @@ struct AttentionKernel {
|
||||
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
|
||||
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
|
||||
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
|
||||
if (kSupportsBias) {
|
||||
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
|
||||
XFORMERS_CHECK(
|
||||
p.bias_strideB % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.bias_strideH % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.bias_strideM % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned");
|
||||
}
|
||||
XFORMERS_CHECK(
|
||||
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
@ -460,6 +594,12 @@ struct AttentionKernel {
|
||||
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
|
||||
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
|
||||
XFORMERS_CHECK(
|
||||
p.custom_mask_type < NumCustomMaskTypes,
|
||||
"invalid value for `custom_mask_type`");
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -472,8 +612,8 @@ struct AttentionKernel {
|
||||
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
|
||||
auto& mi = shared_storage.mi;
|
||||
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
|
||||
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
@ -488,7 +628,7 @@ struct AttentionKernel {
|
||||
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
|
||||
using OutputTileIterator = typename MM1::OutputTileIterator;
|
||||
return OutputTileIterator(
|
||||
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
|
||||
typename OutputTileIterator::Params{(int32_t)p.o_strideM},
|
||||
p.output_ptr,
|
||||
typename OutputTileIterator::TensorCoord{
|
||||
p.num_queries, p.head_dim_value},
|
||||
@ -500,7 +640,8 @@ struct AttentionKernel {
|
||||
typename MM1::OutputTileIteratorAccum {
|
||||
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
|
||||
return OutputTileIteratorAccum(
|
||||
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
|
||||
typename OutputTileIteratorAccum::Params{
|
||||
(int32_t)(p.head_dim_value * p.num_heads)},
|
||||
p.output_accum_ptr,
|
||||
typename OutputTileIteratorAccum::TensorCoord{
|
||||
p.num_queries, p.head_dim_value},
|
||||
@ -508,6 +649,27 @@ struct AttentionKernel {
|
||||
{0, col});
|
||||
};
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
curandStatePhilox4_32_10_t curand_state_init;
|
||||
if (kSupportsDropout && p.use_dropout) {
|
||||
const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
|
||||
|
||||
// each element of the attention matrix P with shape
|
||||
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single
|
||||
// offset in RNG sequence. we initialize the RNG state with offset that
|
||||
// starts at the beginning of a (n_queries, n_keys) matrix for this
|
||||
// block's batch_id and head_id
|
||||
// initializing rng state is very expensive, so we run once per kernel,
|
||||
// rather than once per iteration. each iteration takes a copy of the
|
||||
// initialized RNG state and offsets it as needed.
|
||||
curand_init(
|
||||
std::get<0>(seeds),
|
||||
0,
|
||||
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
|
||||
&curand_state_init);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Iterate through keys
|
||||
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
|
||||
iter_key_start += kKeysPerBlock) {
|
||||
@ -600,16 +762,65 @@ struct AttentionKernel {
|
||||
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
|
||||
(my_warp_id / MM0::Mma::WarpCount::kM)};
|
||||
|
||||
// multiply by scaling factor
|
||||
if (kSupportsBias) {
|
||||
accum =
|
||||
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
|
||||
}
|
||||
|
||||
// apply attention bias if applicable
|
||||
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
|
||||
// load bias tile Bij into shared memory
|
||||
typename MM0::BiasLoader::GmemTileIterator bias_iter(
|
||||
{cutlass::layout::RowMajor(p.bias_strideM)},
|
||||
// attn_bias_pointer points to matrix of size (n_queries, n_keys)
|
||||
// for the relevant batch_id and head_id
|
||||
p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
|
||||
{problem_size_0_m, problem_size_0_n},
|
||||
thread_id());
|
||||
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
|
||||
shared_storage.after_mm0.bias.data(),
|
||||
cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
|
||||
typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
|
||||
bias_tensor_ref, thread_id());
|
||||
MM0::BiasLoader::load(bias_iter, smem_tile_iter);
|
||||
|
||||
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
|
||||
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
MM0::AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
|
||||
accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
|
||||
// Mask out last if causal
|
||||
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
|
||||
// This is only needed if upper-right corner of current query / key block
|
||||
// intersects the mask Coordinates of upper-right corner of current block
|
||||
// is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The
|
||||
// first masked element is x = y + offset -> query_start + offset There is
|
||||
// intersection (and we need to mask) if min(iter_key_start +
|
||||
// kKeysPerBlock, num_keys)) >= query_start + offset
|
||||
if (p.custom_mask_type &&
|
||||
cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
|
||||
(query_start + p.causal_diagonal_offset)) {
|
||||
auto query_start = blockIdx.x * kQueriesPerBlock;
|
||||
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
|
||||
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
int32_t last_col;
|
||||
MM0::ScalingCoefsUpdater::iterateRows(
|
||||
MM0::AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
last_col = query_start + accum_m - iter_key_start;
|
||||
// last absolute col is (last absolute query + offset)
|
||||
// last local col is (last absolute query + offset -
|
||||
// iter_key_start)
|
||||
last_col = query_start + accum_m + p.causal_diagonal_offset -
|
||||
iter_key_start;
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (accum_n > last_col) {
|
||||
@ -625,14 +836,11 @@ struct AttentionKernel {
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also updates `accum` with accum[i] <-
|
||||
// exp(accum[i] * scale
|
||||
// - mi)
|
||||
MM0::ScalingCoefsUpdater::update<
|
||||
kQueriesPerBlock,
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<
|
||||
typename MM0::Mma::Operator::IteratorC,
|
||||
kFullColumns,
|
||||
kIsFirst,
|
||||
kKeepOutputInRF>(
|
||||
kIsFirst>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
@ -643,7 +851,7 @@ struct AttentionKernel {
|
||||
warp_id(),
|
||||
p.num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
p.scale);
|
||||
kSupportsBias ? 1.0f : p.scale);
|
||||
}));
|
||||
}));
|
||||
|
||||
@ -659,6 +867,69 @@ struct AttentionKernel {
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
// apply dropout (if applicable) after we've written Pij to smem.
|
||||
// dropout is applied by multiplying each element of Pij by:
|
||||
// - 0 with probability dropout_p
|
||||
// - 1 / (1 - dropout_p) with probability 1 - dropout_p
|
||||
//
|
||||
// for backward purposes we want to be able to map each element of the
|
||||
// attention matrix to the same random uniform number as the one we used
|
||||
// in forward, without needing to use the same iteration order or having
|
||||
// to store the dropout matrix. its possible to do this in registers but
|
||||
// it ends up being very slow because each thread having noncontiguous
|
||||
// strips of the Pij tile means we have to skip around a lot, and also
|
||||
// have to generate a single random number at a time
|
||||
if (kSupportsDropout && p.use_dropout) {
|
||||
auto si = shared_storage.after_mm0.si.accum_ref();
|
||||
// each thread handles a contiguous sequence of elements from Sij, all
|
||||
// coming from the same row. the reason they have to come from the same
|
||||
// row is that the sampling random numbers from a contiguous random
|
||||
// number sequence is much more efficient than jumping around, and the
|
||||
// linear offset of each element of S (the global matrix) maps to an
|
||||
// offset in a random number sequence. for S, the end of a row and the
|
||||
// beginning of the next have adjacent offsets, but for Sij, this is not
|
||||
// necessarily the case.
|
||||
const int num_threads = blockDim.x * blockDim.y * blockDim.z;
|
||||
const int threads_per_row =
|
||||
cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
|
||||
const int elts_per_thread = cutlass::round_nearest(
|
||||
cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
|
||||
|
||||
const int thread_i = thread_id() / threads_per_row;
|
||||
const int thread_start_j =
|
||||
(thread_id() % threads_per_row) * elts_per_thread;
|
||||
|
||||
if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
|
||||
curandStatePhilox4_32_10_t curand_state = curand_state_init;
|
||||
skipahead(
|
||||
static_cast<unsigned long long>(
|
||||
(query_start + thread_i) * p.num_keys +
|
||||
(iter_key_start + thread_start_j)),
|
||||
&curand_state);
|
||||
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
|
||||
|
||||
// apply dropout scaling to elements this thread is responsible for,
|
||||
// in chunks of 4
|
||||
for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
|
||||
cutlass::fast_min(thread_start_j + elts_per_thread,
|
||||
problem_size_0_n);
|
||||
sij_start_col_idx += 4) {
|
||||
const float4 rand_uniform_quad = curand_uniform4(&curand_state);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
|
||||
si.at({thread_i, sij_start_col_idx + quad_idx}) *=
|
||||
static_cast<scalar_t>(
|
||||
dropout_scale *
|
||||
((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // p.use_dropout should have same value kernel-wide
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// MATMUL: Attn . V
|
||||
// Run the matmul `attn @ V` for a block of attn and V.
|
||||
@ -830,6 +1101,116 @@ struct AttentionKernel {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename WarpIteratorC,
|
||||
bool kFullColumns,
|
||||
bool kIsFirst>
|
||||
CUTLASS_DEVICE static void iterative_softmax(
|
||||
typename WarpIteratorC::Fragment& frag_o, // output so far
|
||||
typename WarpIteratorC::Fragment& frag,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
||||
int8_t lane_id,
|
||||
int8_t thread_id,
|
||||
int8_t warp_id,
|
||||
int16_t max_col,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
float scaling) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
|
||||
(1) Update `mi[r]` to the max value of the row `r`
|
||||
(2) In a second iteration do the following:
|
||||
(a) accum <- exp(accum - mi)
|
||||
(b) m_prime <- exp(m_prime - mi)
|
||||
(c) s_prime <- s_prime * m_prime + sum(accum)
|
||||
|
||||
All of this is done on registers, before we store all of this
|
||||
on shared memory for the next matmul with Value.
|
||||
*/
|
||||
using Fragment = typename WarpIteratorC::Fragment;
|
||||
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
||||
WarpIteratorC,
|
||||
accum_t,
|
||||
kWarpSize>::Iterator;
|
||||
// Convert to `accum_t` (rather than double)
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (!kIsFirst) {
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
m_prime[thread_id] = mi[thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
auto lane_offset =
|
||||
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
|
||||
|
||||
// First update `mi` to the max per-row
|
||||
{
|
||||
accum_t max;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
max = cutlass::fast_max(max, frag[idx]);
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {
|
||||
// Having 4x atomicMax seems faster than reduce within warp
|
||||
// first...
|
||||
atomicMaxFloat(&mi[accum_m], max * scaling);
|
||||
});
|
||||
}
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
// Make sure we all share the update values for `mi`
|
||||
__syncthreads();
|
||||
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
|
||||
m_prime[thread_id] = m_prime_exp;
|
||||
s_prime[thread_id] *= m_prime_exp;
|
||||
}
|
||||
__syncthreads(); // Update output fragments
|
||||
if (kKeepOutputInRF && !kIsFirst) {
|
||||
accum_t mp;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mp = m_prime[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
// Update accum_m, accum_n, ...
|
||||
{
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] = (kFullColumns || accum_n < max_col)
|
||||
? exp2f(frag[idx] - mi_row)
|
||||
: accum_t(0.0);
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { total_row = 0.0; },
|
||||
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
|
||||
[&](int accum_m) {
|
||||
if (LambdaIterator::reduceSameRow(
|
||||
lane_id, total_row, [](accum_t a, accum_t b) {
|
||||
return a + b;
|
||||
})) {
|
||||
atomicAdd(&s_prime[accum_m], total_row);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static CUTLASS_DEVICE int8_t lane_id() {
|
||||
return threadIdx.x;
|
||||
}
|
||||
@ -849,3 +1230,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
||||
}
|
||||
AK::attention_kernel(p);
|
||||
}
|
||||
|
||||
template <typename AK>
|
||||
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
||||
attention_kernel_batched(typename AK::Params params);
|
||||
|
||||
@ -0,0 +1,88 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holdvr nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
template <
|
||||
typename scalar_t, // scalar type
|
||||
typename ThreadblockTileShape, // size of tile to load
|
||||
int Threads, // number of participating threads
|
||||
int ElementsPerAccess> // thread access width in elements
|
||||
class TileSmemLoader {
|
||||
public:
|
||||
using SmemTile =
|
||||
cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
|
||||
|
||||
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
|
||||
cutlass::layout::PitchLinearShape<
|
||||
ThreadblockTileShape::kColumn, // contiguous
|
||||
ThreadblockTileShape::kRow>, // strided
|
||||
Threads, // Threads
|
||||
ElementsPerAccess>; // ElementsPerAccess
|
||||
|
||||
using GmemTileIterator =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
ThreadblockTileShape, // Shape
|
||||
scalar_t, // Element
|
||||
cutlass::layout::RowMajor, // Layout
|
||||
0, // AdvanceRank
|
||||
ThreadMap>; // ThreadMap
|
||||
|
||||
using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
|
||||
ThreadblockTileShape, // Shape
|
||||
scalar_t, // Element
|
||||
cutlass::layout::RowMajor, // Layout
|
||||
0, // AdvanceRank
|
||||
ThreadMap>; // ThreadMap
|
||||
|
||||
using Fragment = typename GmemTileIterator::Fragment;
|
||||
|
||||
/// load a tile from global memory into shared memory
|
||||
CUTLASS_DEVICE
|
||||
static void load(
|
||||
GmemTileIterator tile_load_iter,
|
||||
SmemTileIterator tile_store_iter) {
|
||||
Fragment tb_frag;
|
||||
tb_frag.clear();
|
||||
tile_load_iter.load(tb_frag);
|
||||
tile_store_iter.store(tb_frag);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user