ex77 backwards GQA (#2556)

* bwd GQA init

* Update examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu

* ref kernel type conversion fix

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Richard Cai authored 2025-09-09 09:53:28 -07:00, committed by GitHub
parent 76c96b0be3
commit 56f0718a97
6 changed files with 341 additions and 251 deletions

View File

@ -183,6 +183,9 @@ struct Options {
cmd.get_cmd_line_argument("h", h, -1);
if (h == -1) h = 2048 / d;
cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;
varlen = cmd.check_cmd_line_flag("varlen");
cmd.get_cmd_line_argument("q", q, -1);
@ -298,6 +301,7 @@ struct Options {
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --h=<int> Sets the H extent\n"
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
@ -405,25 +409,24 @@ struct BwdRunner {
#endif
using ElementAccumulator = float;
// Q K D (H B)
// Q K D D_VO ((H_R, H_K) B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, int, cute::tuple<int, int>>
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
using StrideK = TensorStride;
using StrideV = TensorStride;
using StrideO = TensorStride;
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)
// Backwards specific
using StrideDQ = TensorStride;
using StrideDK = TensorStride;
using StrideDV = TensorStride;
using StrideDO = TensorStride;
using StrideDQ = StrideQ;
using StrideDK = StrideK;
using StrideDV = StrideV;
using StrideDO = StrideO;
//
// Data members
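
For orientation, here is a minimal host-side sketch (made-up sizes, not part of the example) of how this ((H_R, H_K), B) head mode behaves: the _0 stride in the H_R position of the K/V strides means every query head within a group resolves to the same K/V head, while Q/O keep a distinct stride per query head.

#include <cute/tensor.hpp>
#include <cstdio>

int main() {
  using namespace cute;
  int K = 8, D = 16, H_R = 4, H_K = 2, B = 1;   // hypothetical extents
  // Same conventions as the strides above: _0 in the H_R mode for K/V,
  // batch stride 0 when B == 1.
  auto shape_K  = make_shape(K, D, make_shape(make_shape(H_R, H_K), B));
  auto stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), 0));
  auto layout_K = make_layout(shape_K, stride_K);
  // All query heads of group h_k = 1 read the same K/V head:
  for (int h_r = 0; h_r < H_R; ++h_r) {
    printf("k=0 d=0 (h_r=%d, h_k=1) -> offset %d\n",
           h_r, int(layout_K(0, 0, make_coord(make_coord(h_r, 1), 0))));
  }
  return 0;
}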
@ -468,43 +471,15 @@ struct BwdRunner {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [H, B] = HB;
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,4>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,4>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,3,4>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,3,4>(problem_shape),
stride_O);
// keep going here! (this might be better in cursor)
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,4>(problem_shape),
stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
select<0,2,4>(problem_shape),
stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
select<1,2,4>(problem_shape),
stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
select<1,3,4>(problem_shape),
stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
select<0,3,4>(problem_shape),
stride_dO);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO);
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
@ -549,6 +524,9 @@ struct BwdRunner {
}
auto initialize_problem_shape(Options const& options) {
int h_r = options.h / options.h_k;
assert(options.h % options.h_k == 0);
if constexpr (kIsVarlen) {
int num_batches = options.b;
@ -599,14 +577,14 @@ struct BwdRunner {
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, options.d_vo, {options.h, options.b}
options.d, options.d_vo, {{h_r, options.h_k}, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1));
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1));
return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}};
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
@ -616,22 +594,23 @@ struct BwdRunner {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, D_VO, HB] = tensor_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B
auto shape_Q = make_shape(Q, D, make_shape(H, B));
auto shape_O = make_shape(Q, D_VO, make_shape(H, B));
auto shape_K = make_shape(K, D, make_shape(H, B));
auto shape_V = make_shape(K, D_VO, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));
auto shape_Q = make_shape(Q, D, HB);
auto shape_K = make_shape(K, D, HB);
auto shape_V = make_shape(K, D_VO, HB);
auto shape_O = make_shape(Q, D_VO, HB);
auto shape_LSE = make_shape(Q, HB);
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H));
stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K));
stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{}, D_VO*K), B == 1 ? 0 : D_VO*K*H_K));
stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K));
stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
stride_dQ = stride_Q;
stride_dK = stride_K;
@ -642,20 +621,23 @@ struct BwdRunner {
return size(make_shape(1ull, shape));
};
auto size_K = lsize(K * D * H_K * B);
auto size_V = lsize(K * D_VO * H_K * B);
block_Q.reset(lsize(shape_Q));
block_K.reset(lsize(shape_K));
block_V.reset(lsize(shape_V));
block_K.reset(size_K);
block_V.reset(size_V);
block_O.reset(lsize(shape_O));
block_LSE.reset(lsize(shape_LSE));
block_dQ.reset(lsize(shape_Q));
block_dK.reset(lsize(shape_K));
block_dV.reset(lsize(shape_V));
block_dK.reset(size_K);
block_dV.reset(size_V);
block_dO.reset(lsize(shape_O));
block_ref_dQ.reset(lsize(shape_Q));
block_ref_dK.reset(lsize(shape_K));
block_ref_dV.reset(lsize(shape_V));
block_ref_dK.reset(size_K);
block_ref_dV.reset(size_V);
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
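
As a rough, hypothetical sizing sketch of what the allocation change above buys (extents are made up): the K/V and dK/dV buffers now hold only H_K heads, while Q/O/dQ/dO still hold all H = H_R * H_K heads.

#include <cstdio>

int main() {
  long long Q = 4096, K = 4096, D = 128, D_VO = 128, H_R = 4, H_K = 8, B = 2;
  long long size_Q = Q * D    * (H_R * H_K) * B;  // lsize(shape_Q), all heads
  long long size_K = K * D    * H_K * B;          // lsize(K * D * H_K * B), K/V heads only
  long long size_V = K * D_VO * H_K * B;
  printf("Q/dQ: %lld elems, K/dK: %lld elems, V/dV: %lld elems\n", size_Q, size_K, size_V);
  return 0;
}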
@ -689,7 +671,7 @@ struct BwdRunner {
select<0,4>(problem_shape),
stride_LSE);
if (! options.skip_reference) {
if (not options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}
@ -820,7 +802,8 @@ struct BwdRunner {
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
flops *= static_cast<double>(get<4,0>(problem_shape));
flops *= static_cast<double>(get<4,0,0>(problem_shape));
flops *= static_cast<double>(get<4,0,1>(problem_shape));
flops *= static_cast<double>(get<4,1>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
@ -1001,7 +984,7 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;

View File

@ -60,6 +60,38 @@ template<
class Mask
>
class Sm100FmhaBwd {
private:
template <typename T>
constexpr static auto to_bwd_shape(T shape) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(shape))::value;
auto HB = get<R-1>(shape);
auto rest = take<0,R-1>(shape);
return append(rest, make_shape(size<0>(HB), get<1>(HB)));
}
else {
return shape;
}
}
template <typename T>
constexpr static auto to_bwd_stride(T stride) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(stride))::value;
auto HB = get<R-1>(stride);
auto rest = take<0,R-1>(stride);
if constexpr (is_same_v<remove_cv_t<decltype(get<0,0>(HB))>, _0>) {
return append(rest, make_stride(get<0,1>(HB), get<1>(HB)));
}
else {
return append(rest, make_stride(get<0,0>(HB), get<1>(HB)));
}
}
else {
return stride;
}
}
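
A host-side sketch (made-up extents) of what to_bwd_shape produces for the MLA path: the ((H_R, H_K), B) mode collapses back to an (H, B) mode, and to_bwd_stride analogously keeps whichever head stride is non-zero.

#include <cute/tensor.hpp>
#include <cstdio>

int main() {
  using namespace cute;
  // Q K D D_VO ((H_R, H_K), B) with hypothetical extents
  auto shape_in = make_shape(128, 256, 64, 64, make_shape(make_shape(4, 2), 3));
  auto HB       = get<4>(shape_in);
  auto flat_HB  = make_shape(size<0>(HB), get<1>(HB));   // (H, B) = (H_R*H_K, B)
  printf("H = %d, B = %d\n", int(get<0>(flat_HB)), int(get<1>(flat_HB)));  // H = 8, B = 3
  return 0;
}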
public:
/// Argument structure: User API
struct Arguments {
@ -67,26 +99,26 @@ public:
ProblemShape problem_shape;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_V;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dO;
Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dV;
ElementAccumulator softmax_scale;
@ -106,9 +138,10 @@ public:
>
>;
using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{}));
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask
>
>;
@ -134,10 +167,11 @@ private:
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
@ -154,14 +188,15 @@ private:
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
return typename OperationConvert::Arguments {
args.problem_shape,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
@ -171,22 +206,22 @@ private:
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
to_bwd_shape(args.problem_shape),
{ args.ptr_Q, to_bwd_stride(args.stride_Q),
args.ptr_K, to_bwd_stride(args.stride_K),
args.ptr_V, to_bwd_stride(args.stride_V),
args.ptr_dO, to_bwd_stride(args.stride_dO),
scaled_lse, to_bwd_stride(stride_scaled_lse),
sum_OdO, to_bwd_stride(stride_sum_OdO),
dQ_acc, to_bwd_stride(stride_dQ),
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
{ args.ptr_dK, to_bwd_stride(args.stride_dK),
args.ptr_dV, to_bwd_stride(args.stride_dV) },
args.hw_info
};
}
@ -220,7 +255,7 @@ public:
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
@ -240,7 +275,7 @@ public:
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
@ -269,7 +304,7 @@ public:
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
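
The product_each calls above flatten the nested head mode so that H and B bind to plain integers; for instance (hypothetical extents):

#include <cute/tensor.hpp>
#include <cstdio>

int main() {
  using namespace cute;
  auto HB = make_shape(make_shape(4, 8), 2);   // ((H_R, H_K), B)
  auto [H, B] = product_each(HB);              // (H_R*H_K, B) = (32, 2)
  printf("H = %d, B = %d\n", H, B);
  return 0;
}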

View File

@ -46,18 +46,18 @@ struct FmhaKernelBwdConvert {
ProblemShape problem_shape;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
tuple<int, _1, tuple<tuple<int, int>, int>> stride_src_dQ;
const ElementAcc* ptr_src_dK;
tuple<int, _1, tuple<int, int>> stride_src_dK;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dK;
const ElementAcc* ptr_src_dV;
tuple<int, _1, tuple<int, int>> stride_src_dV;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
tuple<int, _1, tuple<tuple<int, int>, int>> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, _1, tuple<int, int>> stride_dest_dK;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, _1, tuple<int, int>> stride_dest_dV;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dV;
ElementAcc scale = 1.0;
};
@ -104,8 +104,8 @@ struct FmhaKernelBwdConvert {
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {

View File

@ -46,18 +46,18 @@ struct FmhaKernelBwdSumOdO {
ProblemShape problem_shape;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_O;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_dO;
ElementAcc* ptr_sum_OdO;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_sum_OdO;
const ElementAcc* ptr_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_lse;
ElementAcc* ptr_scaled_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_scaled_lse;
ElementAcc sum_odo_scale = 1.0;
ElementAcc lse_scale = 1.0;
@ -104,11 +104,11 @@ struct FmhaKernelBwdSumOdO {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;

View File

@ -119,13 +119,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
using TensorStrideContiguousK = Stride<int, _1, Stride<Stride<int,int>, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<Stride<int,int>, int>>;
using TensorStrideContiguousK_GQA = Stride<int, _1, Stride<Stride<_0,int>, int>>;
using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride<Stride<_0,int>, int>>;
// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK_GQA, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDQK>,
@ -137,7 +139,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// compute dP
using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK_GQA, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDVO>,
@ -177,7 +179,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN_GQA, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
@ -278,15 +280,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
using TensorStride_GQA = TensorStrideContiguousK_GQA;
using RowTensorStride = Stride<_1, Stride<Stride<int, int>, int>>; // S (H B)
struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
TensorStride_GQA stride_k;
const Element* ptr_v;
TensorStride stride_v;
TensorStride_GQA stride_v;
const Element* ptr_do;
TensorStride stride_do;
@ -308,7 +311,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));
@ -322,9 +325,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
TensorStride_GQA stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
TensorStride_GQA stride_dv;
};
struct Arguments {
@ -346,7 +349,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
auto [H_R, H_K] = H;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
@ -432,7 +436,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
@ -447,6 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;
using X = Underscore;
@ -590,6 +596,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
while (iter_count > 0) {
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
@ -660,7 +671,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
@ -1119,7 +1131,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
@ -1141,6 +1154,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;
// in tmem, S & P overlap
// and dP and dQ overlap
@ -1224,8 +1238,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tRT_cST_p = thread_r2t.partition_S(tDVcST);
auto tRT_cST = split_wg(tRT_cST_p);
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape);
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
@ -1246,7 +1259,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
leading_causal_masking = warp_uniform(iter_index == iter_start);
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
@ -1258,7 +1271,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
@ -1379,6 +1392,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count -= 1;
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
}
}
epilogue(
@ -1391,7 +1407,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
@ -1404,6 +1421,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using X = Underscore;
auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
@ -1415,7 +1433,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);
(_, _, _, _0{}, _);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1426,7 +1444,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
@ -1472,7 +1489,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}
@ -1481,6 +1498,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count -= 1;
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
}
}
@ -1683,12 +1704,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
@ -1699,7 +1720,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);
if (iter_count <= 0) {
epilogue_clear(
@ -1720,6 +1741,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_offset,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.mainloop_params,
@ -1741,6 +1763,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_coord,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
shared_storage.tensors,
@ -1763,6 +1786,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_offset,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.epilogue,
@ -1794,6 +1818,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_coord,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.mainloop_params,
@ -1820,7 +1845,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
auto [H_R, H_K] = H;
dim3 grid(ceil_div(K, TileShapeK{}), H_K, B);
return grid;
}
};
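
Taken together with the iter_start/iter_end plumbing above, each CTA now owns one (K tile, h_k, b) coordinate and revisits its Q-tile range once per query head in the group. A simplified, hypothetical sketch of that iteration order (not the kernel itself):

#include <cstdio>

int main() {
  int iter_start = 0, iter_end = 4;                    // made-up Q-tile range
  int H_R = 3;                                         // query heads per K/V head
  int iter_count = (iter_end - iter_start) * H_R;      // as in the kernel above
  int iter_index = iter_start, h_r = 0;
  while (iter_count > 0) {
    printf("q_tile=%d h_r=%d\n", iter_index, h_r);     // load/mma/compute for this tile
    --iter_count;
    if (++iter_index == iter_end) {                    // wrapped: same K/V head, next query head
      iter_index = iter_start;
      ++h_r;
    }
  }
  return 0;
}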

View File

@ -56,12 +56,12 @@ void __global__ fmha_bwd_reference_dQ_kernel(
using namespace cutlass::fmha::collective;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
@ -79,9 +79,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
@ -94,20 +94,22 @@ void __global__ fmha_bwd_reference_dQ_kernel(
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_K] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
mS[idx_K] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_K
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
ElementAcc acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
acc += mS[idx_K] * ElementAccumulator(mK(idx_K, idx_D, idx_L));
ElementAcc rK = mK(idx_K, idx_D, idx_L);
ElementAcc rDS = mS[idx_K];
acc += rDS * rK;
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
} // for idx_D
@ -135,62 +137,83 @@ void __global__ fmha_bwd_reference_dK_kernel(
using namespace cutlass::fmha::collective;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [H, B] = get<4>(problem_shape_in);
auto [H_R, H_K] = H;
for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,4>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in);
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
ElementAcc acc_dk = 0;
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
acc_qk += rQ * rK;
} // for idx_D0
for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) {
ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB);
ElementAcc rV = mV(idx_K, idx_D1, coord_HB);
ElementAcc rO = mO(idx_Q, idx_D1, coord_HB);
acc_dov += rDO * rV;
acc_doo += rDO * rO;
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q
__syncthreads();
int idx_D = threadIdx.x;
if (idx_D < D) {
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB);
ElementAcc rDS = mS[idx_Q];
acc_dk += rDS * rQ;
}
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
__syncthreads();
} // for idx_H_R
mS[idx_Q] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
acc += mS[idx_Q] * ElementAccumulator(mQ(idx_Q, idx_D, idx_L));
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
} // for idx_D
int idx_D = threadIdx.x;
if (idx_D < D) {
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
mDK(idx_K, idx_D, coord_HB) = static_cast<typename TensorDK::value_type>(acc_dk);
}
} // for idx_K
} // for idx_L
} // for idx_HB
}
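
The dK reference above (and the dV reference below) now accumulates across every query head that shares a K/V head. A scalar CPU sketch of just that reduction, with masking, softmax, and LSE omitted and made-up extents:

#include <cstdio>
#include <vector>

int main() {
  const int Q = 4, K = 3, D = 2, H_R = 2;
  // dS[h_r][q][k] and Qh[h_r][q][d] stand in for the softmax-backward term and
  // the forward Q input; filled with placeholder values here.
  std::vector<float> dS(H_R * Q * K, 0.5f), Qh(H_R * Q * D, 1.0f);
  std::vector<float> dK(K * D, 0.0f);
  for (int h_r = 0; h_r < H_R; ++h_r)        // sum over the query heads of the group
    for (int k = 0; k < K; ++k)
      for (int d = 0; d < D; ++d)
        for (int q = 0; q < Q; ++q)
          dK[k * D + d] += dS[(h_r * Q + q) * K + k] * Qh[(h_r * Q + q) * D + d];
  printf("dK[0][0] = %f\n", dK[0]);          // H_R * Q * 0.5 = 4.0
  return 0;
}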
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -216,54 +239,71 @@ void __global__ fmha_bwd_reference_dV_kernel(
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [H, B] = get<4>(problem_shape_in);
auto [H_R, H_K] = H;
for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDV = domain_offset(select<1,3,4>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
} // for idx_D0
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in);
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
ElementAcc acc_dv = 0;
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
mS[idx_Q] = expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L));
} // for idx_Q
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
acc_qk += rQ * rK;
} // for idx_D0
__syncthreads();
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
ElementAcc rS = static_cast<Element>(mS[idx_Q]);
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;
}
mDV(idx_K, idx_D, idx_L) = static_cast<typename TensorDV::value_type>(acc);
} // for idx_D
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)));
} // for idx_Q
__syncthreads();
int idx_D_VO = threadIdx.x;
if (idx_D_VO < D_VO) {
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB);
ElementAcc rP = mS[idx_Q];
acc_dv += rP * rDO;
}
} // for idx_D
__syncthreads();
} // for idx_H_R
int idx_D_VO = threadIdx.x;
if (idx_D_VO < D_VO) {
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
mDV(idx_K, idx_D_VO, coord_HB) = static_cast<typename TensorDV::value_type>(acc_dv);
}
} // for idx_K
} // for idx_L
}
@ -288,7 +328,7 @@ void fmha_bwd_reference_dQ(
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorLSE::value_type);
int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type);
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}
@ -310,9 +350,12 @@ void fmha_bwd_reference_dK(
using namespace cute;
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
auto [K, D, HB] = mDK.shape();
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(K, H_K * B, 1);
dim3 block(std::max(D, 256));
int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type);
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}
@ -334,9 +377,12 @@ void fmha_bwd_reference_dV(
using namespace cute;
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
auto [K, D_VO, HB] = mDV.shape();
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(K, H_K * B, 1);
dim3 block(std::max(D_VO, 256));
int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type);
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}