ex77 backwards GQA (#2556)
* bwd GQA init
* Update examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu
* ref kernel type conversion fix

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
@@ -183,6 +183,9 @@ struct Options {
cmd.get_cmd_line_argument("h", h, -1);
if (h == -1) h = 2048 / d;

cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;

varlen = cmd.check_cmd_line_flag("varlen");

cmd.get_cmd_line_argument("q", q, -1);
@@ -298,6 +301,7 @@ struct Options {
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --h=<int> Sets the H extent\n"
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
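For readers following the GQA change: `--h_k` introduces a separate number of K/V heads, and the number of query heads per K/V head is `h / h_k` (a later hunk asserts `h % h_k == 0`). A minimal sketch of that mapping, with illustrative names only (not code from this PR):

```cpp
// Sketch only: how a flat query-head index relates to the (h_r, h_k) pair
// used by the nested ((H_R, H_K), B) head mode introduced in this PR.
#include <cassert>

struct HeadCoord { int h_r; int h_k; };

HeadCoord split_query_head(int head_idx, int h, int h_k) {
  assert(h % h_k == 0);          // same constraint the runner asserts
  int group = h / h_k;           // query heads that share one K/V head (H_R)
  return { head_idx % group,     // position within the group
           head_idx / group };   // which K/V head the group maps to
}
```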
@@ -405,25 +409,24 @@ struct BwdRunner {
#endif
using ElementAccumulator = float;

// Q K D (H B)
// Q K D D_VO ((H_R, H_K) B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, int, cute::tuple<int, int>>
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
>;

using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
using StrideK = TensorStride;
using StrideV = TensorStride;
using StrideO = TensorStride;
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)

// Backwards specific
using StrideDQ = TensorStride;
using StrideDK = TensorStride;
using StrideDV = TensorStride;
using StrideDO = TensorStride;
using StrideDQ = StrideQ;
using StrideDK = StrideK;
using StrideDV = StrideV;
using StrideDO = StrideO;
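A note on the new stride types: the head mode is now the nested pair `(H_R, H_K)`, and K/V use a compile-time `_0` stride in the `H_R` position, so every query head in a group addresses the same K/V head. A minimal sketch of the offset arithmetic that `_0` implies (made-up parameter names, plain C++, not the CuTe implementation):

```cpp
// Sketch only: offset implied by
//   StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>   // K D ((H_R, H_K), B)
// The _0 broadcasts one K/V head across all query heads of its group.
#include <cstdint>

std::int64_t k_offset(int k, int d, int h_r, int h_k, int b,
                      std::int64_t seq_stride,    // stride of the K (row) mode
                      std::int64_t head_stride,   // stride of the H_K mode
                      std::int64_t batch_stride) {
  // d is the contiguous (unit-stride) mode; h_r contributes nothing (stride 0)
  return k * seq_stride + d * 1 + h_r * 0 + h_k * head_stride + b * batch_stride;
}
```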
//
// Data members
@@ -468,43 +471,15 @@ struct BwdRunner {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [H, B] = HB;

Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,4>(problem_shape),
stride_Q);

Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,4>(problem_shape),
stride_K);

Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,3,4>(problem_shape),
stride_V);

Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,3,4>(problem_shape),
stride_O);

// keep going here! (this might be better in cursor)

Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,4>(problem_shape),
stride_LSE);

Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
select<0,2,4>(problem_shape),
stride_dQ);

Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
select<1,2,4>(problem_shape),
stride_dK);

Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
select<1,3,4>(problem_shape),
stride_dV);

Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
select<0,3,4>(problem_shape),
stride_dO);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO);

fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
@@ -549,6 +524,9 @@ struct BwdRunner {
}

auto initialize_problem_shape(Options const& options) {
int h_r = options.h / options.h_k;
assert(options.h % options.h_k == 0);

if constexpr (kIsVarlen) {
int num_batches = options.b;

@@ -599,14 +577,14 @@ struct BwdRunner {
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, options.d_vo, {options.h, options.b}
options.d, options.d_vo, {{h_r, options.h_k}, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1));
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1));

return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}};
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
@@ -616,22 +594,23 @@ struct BwdRunner {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, D_VO, HB] = tensor_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment

// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B

auto shape_Q = make_shape(Q, D, make_shape(H, B));
auto shape_O = make_shape(Q, D_VO, make_shape(H, B));
auto shape_K = make_shape(K, D, make_shape(H, B));
auto shape_V = make_shape(K, D_VO, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));
auto shape_Q = make_shape(Q, D, HB);
auto shape_K = make_shape(K, D, HB);
auto shape_V = make_shape(K, D_VO, HB);
auto shape_O = make_shape(Q, D_VO, HB);
auto shape_LSE = make_shape(Q, HB);

stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H));
stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K));
stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{}, D_VO*K), B == 1 ? 0 : D_VO*K*H_K));
stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K));
stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
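The replacement strides pack heads as (H_R, H_K), with K/V broadcast over H_R. As a sanity check, a small sketch of the same batch-stride arithmetic in plain C++ (illustrative names, assuming contiguous packing and B > 1):

```cpp
// Sketch only: element strides implied by the make_stride calls above,
// for [B][H_K][H_R][Q][D] packing of Q/O and [B][H_K][K][D] packing of K/V.
struct GqaStrides {
  long q_seq, q_head, q_batch;   // Q/O: per row, per (h_r within h_k), per batch
  long k_seq, k_head, k_batch;   // K/V: the h_r stride is 0 (broadcast)
};

GqaStrides make_gqa_strides(long Q, long K, long D, long H_R, long H_K) {
  GqaStrides s{};
  s.q_seq = D;  s.q_head = D * Q;  s.q_batch = D * Q * H_R * H_K;
  s.k_seq = D;  s.k_head = D * K;  s.k_batch = D * K * H_K;
  return s;
}
```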
stride_dQ = stride_Q;
stride_dK = stride_K;
@@ -642,20 +621,23 @@ struct BwdRunner {
return size(make_shape(1ull, shape));
};

auto size_K = lsize(K * D * H_K * B);
auto size_V = lsize(K * D_VO * H_K * B);

block_Q.reset(lsize(shape_Q));
block_K.reset(lsize(shape_K));
block_V.reset(lsize(shape_V));
block_K.reset(size_K);
block_V.reset(size_V);
block_O.reset(lsize(shape_O));
block_LSE.reset(lsize(shape_LSE));

block_dQ.reset(lsize(shape_Q));
block_dK.reset(lsize(shape_K));
block_dV.reset(lsize(shape_V));
block_dK.reset(size_K);
block_dV.reset(size_V);
block_dO.reset(lsize(shape_O));

block_ref_dQ.reset(lsize(shape_Q));
block_ref_dK.reset(lsize(shape_K));
block_ref_dV.reset(lsize(shape_V));
block_ref_dK.reset(size_K);
block_ref_dV.reset(size_V);

initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
@@ -689,7 +671,7 @@ struct BwdRunner {
select<0,4>(problem_shape),
stride_LSE);

if (! options.skip_reference) {
if (not options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}

@@ -820,7 +802,8 @@ struct BwdRunner {
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
flops *= static_cast<double>(get<4,0>(problem_shape));
flops *= static_cast<double>(get<4,0,0>(problem_shape));
flops *= static_cast<double>(get<4,0,1>(problem_shape));
flops *= static_cast<double>(get<4,1>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
@@ -1001,7 +984,7 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}

std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;
@@ -60,6 +60,38 @@ template<
class Mask
>
class Sm100FmhaBwd {
private:
template <typename T>
constexpr static auto to_bwd_shape(T shape) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(shape))::value;
auto HB = get<R-1>(shape);
auto rest = take<0,R-1>(shape);
return append(rest, make_shape(size<0>(HB), get<1>(HB)));
}
else {
return shape;
}
}

template <typename T>
constexpr static auto to_bwd_stride(T stride) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(stride))::value;
auto HB = get<R-1>(stride);
auto rest = take<0,R-1>(stride);
if constexpr (is_same_v<remove_cv_t<decltype(get<0,0>(HB))>, _0>) {
return append(rest, make_stride(get<0,1>(HB), get<1>(HB)));
}
else {
return append(rest, make_stride(get<0,0>(HB), get<1>(HB)));
}
}
else {
return stride;
}
}
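`to_bwd_shape` / `to_bwd_stride` collapse the nested `((H_R, H_K), B)` mode back into a flat `(H, B)` mode for the MLA path: the shape multiplies the two head extents, and the stride keeps whichever head stride is meaningful (the `H_K` stride when the `H_R` stride is the broadcast `_0`). A minimal sketch of the same idea on plain integers (illustrative structs standing in for CuTe tuples):

```cpp
// Sketch only: flattening ((H_R, H_K), B) -> (H, B), mirroring the logic above.
struct HeadMode    { int h_r, h_k, b; };
struct FlatMode    { int h, b; };
struct HeadStrides { long s_hr, s_hk, s_b; };   // s_hr == 0 means broadcast K/V
struct FlatStrides { long s_h, s_b; };

FlatMode    flatten_shape(HeadMode m)     { return { m.h_r * m.h_k, m.b }; }
FlatStrides flatten_stride(HeadStrides s) {
  // keep the non-broadcast head stride, as the _0 check above does
  return { s.s_hr == 0 ? s.s_hk : s.s_hr, s.s_b };
}
```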
public:
/// Argument structure: User API
struct Arguments {
@@ -67,26 +99,26 @@ public:
ProblemShape problem_shape;

const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_V;

const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_LSE;

const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dO;

Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dV;

ElementAccumulator softmax_scale;

@@ -106,9 +138,10 @@ public:
>
>;

using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{}));
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask
>
>;

@@ -134,10 +167,11 @@ private:
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
@@ -154,14 +188,15 @@ private:
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
return typename OperationConvert::Arguments {
args.problem_shape,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
@@ -171,22 +206,22 @@ private:

static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {

ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_dQ = {}) {

return typename Operation::Arguments{
args.problem_shape,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
to_bwd_shape(args.problem_shape),
{ args.ptr_Q, to_bwd_stride(args.stride_Q),
args.ptr_K, to_bwd_stride(args.stride_K),
args.ptr_V, to_bwd_stride(args.stride_V),
args.ptr_dO, to_bwd_stride(args.stride_dO),
scaled_lse, to_bwd_stride(stride_scaled_lse),
sum_OdO, to_bwd_stride(stride_sum_OdO),
dQ_acc, to_bwd_stride(stride_dQ),
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
{ args.ptr_dK, to_bwd_stride(args.stride_dK),
args.ptr_dV, to_bwd_stride(args.stride_dV) },
args.hw_info
};
}
@@ -220,7 +255,7 @@ public:
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
@@ -240,7 +275,7 @@ public:
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));

auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
@@ -269,7 +304,7 @@ public:
<< workspace << ", stream: " << (stream ? "non-null" : "null"));

auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
@@ -46,18 +46,18 @@ struct FmhaKernelBwdConvert {
ProblemShape problem_shape;

const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
tuple<int, _1, tuple<tuple<int, int>, int>> stride_src_dQ;
const ElementAcc* ptr_src_dK;
tuple<int, _1, tuple<int, int>> stride_src_dK;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dK;
const ElementAcc* ptr_src_dV;
tuple<int, _1, tuple<int, int>> stride_src_dV;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dV;

Element* ptr_dest_dQ;
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
tuple<int, _1, tuple<tuple<int, int>, int>> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, _1, tuple<int, int>> stride_dest_dK;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, _1, tuple<int, int>> stride_dest_dV;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dV;

ElementAcc scale = 1.0;
};
@@ -104,8 +104,8 @@ struct FmhaKernelBwdConvert {

template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;

int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {

@@ -46,18 +46,18 @@ struct FmhaKernelBwdSumOdO {
ProblemShape problem_shape;

const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_O;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_dO;

ElementAcc* ptr_sum_OdO;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_sum_OdO;

const ElementAcc* ptr_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_lse;

ElementAcc* ptr_scaled_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_scaled_lse;

ElementAcc sum_odo_scale = 1.0;
ElementAcc lse_scale = 1.0;
@@ -104,11 +104,11 @@ struct FmhaKernelBwdSumOdO {
}

CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);

auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
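For context on what this preprocessing kernel stores per (head, batch) slice: the backward pass needs, for every query row, the sum over the value dimension of O ⊙ dO (alongside a rescaled copy of the LSE). A minimal scalar sketch of that reduction (single-threaded, illustrative, not the CUDA implementation):

```cpp
// Sketch only: per-row sum of O * dO used by the backward softmax term.
// O and dO are [seqlen][d_vo] row-major slices for one (head, batch).
#include <vector>

std::vector<float> sum_o_do(const float* O, const float* dO,
                            int seqlen, int d_vo) {
  std::vector<float> out(seqlen, 0.0f);
  for (int q = 0; q < seqlen; ++q)
    for (int d = 0; d < d_vo; ++d)
      out[q] += O[q * d_vo + d] * dO[q * d_vo + d];
  return out;
}
```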
@@ -119,13 +119,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;

using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
using TensorStrideContiguousK = Stride<int, _1, Stride<Stride<int,int>, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<Stride<int,int>, int>>;
using TensorStrideContiguousK_GQA = Stride<int, _1, Stride<Stride<_0,int>, int>>;
using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride<Stride<_0,int>, int>>;

// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK_GQA, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDQK>,
@@ -137,7 +139,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// compute dP
using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK_GQA, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDVO>,
@@ -177,7 +179,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN_GQA, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
@@ -278,15 +280,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");

using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
using TensorStride_GQA = TensorStrideContiguousK_GQA;
using RowTensorStride = Stride<_1, Stride<Stride<int, int>, int>>; // S (H B)

struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
TensorStride_GQA stride_k;
const Element* ptr_v;
TensorStride stride_v;
TensorStride_GQA stride_v;
const Element* ptr_do;
TensorStride stride_do;

@@ -308,7 +311,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;

using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));

@@ -322,9 +325,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
TensorStride_GQA stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
TensorStride_GQA stride_dv;
};

struct Arguments {
@@ -346,7 +349,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
auto [H_R, H_K] = H;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
@@ -432,7 +436,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
@@ -447,6 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {

auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;

using X = Underscore;

@@ -590,6 +596,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;

while (iter_count > 0) {
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}

pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);

@@ -660,7 +671,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
@@ -1119,7 +1131,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
@@ -1141,6 +1154,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;

// in tmem, S & P overlap
// and dP and dQ overlap
@@ -1224,8 +1238,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tRT_cST_p = thread_r2t.partition_S(tDVcST);
auto tRT_cST = split_wg(tRT_cST_p);

bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape);

CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
@@ -1246,7 +1259,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
leading_causal_masking = warp_uniform(iter_index == iter_start);
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
@@ -1258,7 +1271,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k);
}

dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
@@ -1379,6 +1392,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

iter_count -= 1;
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
}
}

epilogue(
@@ -1391,7 +1407,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
@@ -1404,6 +1421,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using X = Underscore;

auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;

auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;

@@ -1415,7 +1433,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);
(_, _, _, _0{}, _);

Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));

@@ -1426,7 +1444,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto thread_t2r = tiled_t2r.get_slice(thread_idx);

Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);

@@ -1472,7 +1489,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}

@@ -1481,6 +1498,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

iter_count -= 1;
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
}
}

@@ -1683,12 +1704,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

pipeline_init_wait(size(ClusterShape{}));

auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
@@ -1699,7 +1720,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);

if (iter_count <= 0) {
epilogue_clear(
@@ -1720,6 +1741,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_offset,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.mainloop_params,
@@ -1741,6 +1763,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_coord,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
shared_storage.tensors,
@@ -1763,6 +1786,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_offset,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.epilogue,
@@ -1794,6 +1818,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_coord,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.mainloop_params,
@@ -1820,7 +1845,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
auto [H_R, H_K] = H;
dim3 grid(ceil_div(K, TileShapeK{}), H_K, B);
return grid;
}
};
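The launch configuration above places one CTA on each (K tile, K/V head, batch) triple, and the per-CTA iteration count is multiplied by `get<4,0,0>(problem_shape)` (H_R), so each CTA walks all query heads that share its K/V head. A host-side sketch of that trip count, with illustrative names standing in for the kernel's compile-time constants:

```cpp
// Sketch only: mainloop iterations one CTA performs under the GQA scheme above.
int ceil_div_int(int a, int b) { return (a + b - 1) / b; }

int cta_iterations(int seqlen_q, int tile_q, int h_r,
                   int iter_start /* e.g. from causal masking */) {
  int iter_end = ceil_div_int(seqlen_q, tile_q);   // query tiles per head
  return (iter_end - iter_start) * h_r;            // repeated for each grouped head
}
```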
@@ -56,12 +56,12 @@ void __global__ fmha_bwd_reference_dQ_kernel(
using namespace cutlass::fmha::collective;

using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;

extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);

ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));

for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
@@ -79,9 +79,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
@@ -94,20 +94,22 @@ void __global__ fmha_bwd_reference_dQ_kernel(
}

auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);

mS[idx_K] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
mS[idx_K] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_K

__syncthreads();

for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
ElementAcc acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
acc += mS[idx_K] * ElementAccumulator(mK(idx_K, idx_D, idx_L));
ElementAcc rK = mK(idx_K, idx_D, idx_L);
ElementAcc rDS = mS[idx_K];
acc += rDS * rK;
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
} // for idx_D
@@ -135,62 +137,83 @@ void __global__ fmha_bwd_reference_dK_kernel(
using namespace cutlass::fmha::collective;

using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;

extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);

ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));

for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [H, B] = get<4>(problem_shape_in);
auto [H_R, H_K] = H;

for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,4>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;

for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in);

for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
ElementAcc acc_dk = 0;
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
acc_qk += rQ * rK;
} // for idx_D0

for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) {
ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB);
ElementAcc rV = mV(idx_K, idx_D1, coord_HB);
ElementAcc rO = mO(idx_Q, idx_D1, coord_HB);
acc_dov += rDO * rV;
acc_doo += rDO * rO;
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);

mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q

__syncthreads();

int idx_D = threadIdx.x;
if (idx_D < D) {
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB);
ElementAcc rDS = mS[idx_Q];
acc_dk += rDS * rQ;
}
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
__syncthreads();
} // for idx_H_R

mS[idx_Q] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q

__syncthreads();

for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
acc += mS[idx_Q] * ElementAccumulator(mQ(idx_Q, idx_D, idx_L));
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
} // for idx_D
int idx_D = threadIdx.x;
if (idx_D < D) {
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
mDK(idx_K, idx_D, coord_HB) = static_cast<typename TensorDK::value_type>(acc_dk);
}
} // for idx_K
} // for idx_L
} // for idx_HB
}

/////////////////////////////////////////////////////////////////////////////////////////////////
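The restructured reference dK kernel makes the GQA reduction explicit: all H_R query heads in a group attend to the same K/V head, so their gradient contributions are accumulated into a single dK slice. A host-side sketch of that accumulation pattern (scalar, illustrative, ignoring masking and LSE details):

```cpp
// Sketch only: dK[k][d] accumulates over every query head h_r that shares
// this K/V head, mirroring the idx_H_R loop in the reference kernel above.
// dS(h_r, q, k) stands in for the masked softmax-gradient term.
#include <functional>

void accumulate_dk(int H_R, int Q, int D,
                   const std::function<float(int,int,int)>& dS,   // (h_r, q, k)
                   const std::function<float(int,int,int)>& Qval, // (h_r, q, d)
                   float* dK, int k) {                            // dK is [D]
  for (int d = 0; d < D; ++d) {
    float acc = 0.0f;
    for (int h_r = 0; h_r < H_R; ++h_r)
      for (int q = 0; q < Q; ++q)
        acc += dS(h_r, q, k) * Qval(h_r, q, d);
    dK[d] = acc;
  }
}
```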
@@ -216,54 +239,71 @@ void __global__ fmha_bwd_reference_dV_kernel(
using ElementAcc = typename TensorLSE::value_type;

extern __shared__ char mS_mem[];
ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);

ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));

for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [H, B] = get<4>(problem_shape_in);
auto [H_R, H_K] = H;

for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDV = domain_offset(select<1,3,4>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;

for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
} // for idx_D0
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in);

auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
ElementAcc acc_dv = 0;
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;

mS[idx_Q] = expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L));
} // for idx_Q
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
acc_qk += rQ * rK;
} // for idx_D0

__syncthreads();
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);

for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
ElementAcc rS = static_cast<Element>(mS[idx_Q]);
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;
}
mDV(idx_K, idx_D, idx_L) = static_cast<typename TensorDV::value_type>(acc);
} // for idx_D
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)));
} // for idx_Q

__syncthreads();

int idx_D_VO = threadIdx.x;
if (idx_D_VO < D_VO) {
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB);
ElementAcc rP = mS[idx_Q];
acc_dv += rP * rDO;
}
} // for idx_D

__syncthreads();
} // for idx_H_R

int idx_D_VO = threadIdx.x;
if (idx_D_VO < D_VO) {
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
mDV(idx_K, idx_D_VO, coord_HB) = static_cast<typename TensorDV::value_type>(acc_dv);
}
} // for idx_K
} // for idx_L
}
@@ -288,7 +328,7 @@ void fmha_bwd_reference_dQ(

dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorLSE::value_type);
int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type);
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}

@@ -310,9 +350,12 @@ void fmha_bwd_reference_dK(

using namespace cute;

dim3 grid(size<0>(mDK), size<2>(mDK), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
auto [K, D, HB] = mDK.shape();
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(K, H_K * B, 1);
dim3 block(std::max(D, 256));
int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type);
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}

@@ -334,9 +377,12 @@ void fmha_bwd_reference_dV(

using namespace cute;

dim3 grid(size<0>(mDV), size<2>(mDV), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
auto [K, D_VO, HB] = mDV.shape();
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(K, H_K * B, 1);
dim3 block(std::max(D_VO, 256));
int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type);
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}