diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu index 12d2e3f6..4521d87f 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -183,6 +183,9 @@ struct Options { cmd.get_cmd_line_argument("h", h, -1); if (h == -1) h = 2048 / d; + cmd.get_cmd_line_argument("h_k", h_k, -1); + if (h_k == -1) h_k = h; + varlen = cmd.check_cmd_line_flag("varlen"); cmd.get_cmd_line_argument("q", q, -1); @@ -298,6 +301,7 @@ struct Options { << " --help If specified, displays this usage statement\n\n" << " --b= Sets the B extent\n" << " --h= Sets the H extent\n" + << " --h_k= Sets the H_K/V extent (for GQA/MQA)\n" << " --q= Sets the Q extent\n" << " --k= Sets the K extent\n" << " --varlen-q=: Sets the variable Q extent per batch (colon separated)\n" @@ -405,25 +409,24 @@ struct BwdRunner { #endif using ElementAccumulator = float; - // Q K D (H B) + // Q K D D_VO ((H_R, H_K) B) using ProblemShape = std::conditional_t< kIsVarlen, - cute::tuple>, - cute::tuple> + cute::tuple, int>>, + cute::tuple, int>> >; - using TensorStride = Stride>; // Seq D (H B) - using StrideQ = TensorStride; - using StrideK = TensorStride; - using StrideV = TensorStride; - using StrideO = TensorStride; - using StrideLSE = Stride<_1, Stride>; // Seq (H B) + using StrideQ = Stride, int>>; // Q D ((H_R, H_K), B) + using StrideK = Stride, int>>; // K D ((H_R, H_K), B) + using StrideV = StrideK; // K D_VO ((H_R, H_K), B) + using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B) + using StrideLSE = Stride<_1, Stride, int>>; // Q ((H_R, H_K), B) // Backwards specific - using StrideDQ = TensorStride; - using StrideDK = TensorStride; - using StrideDV = TensorStride; - using StrideDO = TensorStride; + using StrideDQ = StrideQ; + using StrideDK = StrideK; + using StrideDV = StrideV; + using StrideDO = StrideO; // // Data members @@ -468,43 +471,15 @@ struct BwdRunner { auto [Q, K, D, D_VO, HB] = problem_shape; auto [H, B] = HB; - Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), - select<0,2,4>(problem_shape), - stride_Q); - - Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), - select<1,2,4>(problem_shape), - stride_K); - - Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), - select<1,3,4>(problem_shape), - stride_V); - - Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), - select<0,3,4>(problem_shape), - stride_O); - - // keep going here! 
(this might be better in cursor) - - Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), - select<0,4>(problem_shape), - stride_LSE); - - Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), - select<0,2,4>(problem_shape), - stride_dQ); - - Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), - select<1,2,4>(problem_shape), - stride_dK); - - Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), - select<1,3,4>(problem_shape), - stride_dV); - - Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), - select<0,3,4>(problem_shape), - stride_dO); + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q); + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K); + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V); + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O); + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE); + Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ); + Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK); + Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV); + Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO); fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{}); @@ -549,6 +524,9 @@ struct BwdRunner { } auto initialize_problem_shape(Options const& options) { + int h_r = options.h / options.h_k; + assert(options.h % options.h_k == 0); + if constexpr (kIsVarlen) { int num_batches = options.b; @@ -599,14 +577,14 @@ struct BwdRunner { ProblemShape problem_shape{ {max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q}, {max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv}, - options.d, options.d_vo, {options.h, options.b} + options.d, options.d_vo, {{h_r, options.h_k}, options.b} }; - auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1)); + auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1)); return cute::make_tuple(problem_shape, tensor_shape); } else { - ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}}; + ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}}; return cute::make_tuple(problem_shape, problem_shape); } } @@ -616,22 +594,23 @@ struct BwdRunner { auto [problem_shape, tensor_shape] = initialize_problem_shape(options); auto [Q, K, D, D_VO, HB] = tensor_shape; auto [H, B] = HB; + auto [H_R, H_K] = H; D = cutlass::round_up(D, 8); // Alignment // for varlen, Q == total_Q, K == total_K, B = 1 // but in problem_shape, they've got to be max_Q/max_K, and B = B - auto shape_Q = make_shape(Q, D, make_shape(H, B)); - auto shape_O = make_shape(Q, D_VO, make_shape(H, B)); - auto shape_K = make_shape(K, D, make_shape(H, B)); - auto shape_V = make_shape(K, D_VO, make_shape(H, B)); - auto shape_LSE = make_shape(Q, make_shape(H, B)); + auto shape_Q = make_shape(Q, D, HB); + auto shape_K = make_shape(K, D, HB); + auto shape_V = make_shape(K, D_VO, HB); + auto shape_O = make_shape(Q, D_VO, HB); + auto shape_LSE = make_shape(Q, HB); - stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 
0 : D*Q*H)); - stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H)); - stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H)); - stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H)); - stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H)); + stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K)); + stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K)); + stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{},D_VO*K), B == 1 ? 0 : D_VO*K*H_K)); + stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K)); + stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); stride_dQ = stride_Q; stride_dK = stride_K; @@ -642,20 +621,23 @@ struct BwdRunner { return size(make_shape(1ull, shape)); }; + auto size_K = lsize(K * D * H_K * B); + auto size_V = lsize(K * D_VO * H_K * B); + block_Q.reset(lsize(shape_Q)); - block_K.reset(lsize(shape_K)); - block_V.reset(lsize(shape_V)); + block_K.reset(size_K); + block_V.reset(size_V); block_O.reset(lsize(shape_O)); block_LSE.reset(lsize(shape_LSE)); block_dQ.reset(lsize(shape_Q)); - block_dK.reset(lsize(shape_K)); - block_dV.reset(lsize(shape_V)); + block_dK.reset(size_K); + block_dV.reset(size_V); block_dO.reset(lsize(shape_O)); block_ref_dQ.reset(lsize(shape_Q)); - block_ref_dK.reset(lsize(shape_K)); - block_ref_dV.reset(lsize(shape_V)); + block_ref_dK.reset(size_K); + block_ref_dV.reset(size_V); initialize_block(block_Q, seed + 2023, options.init_style_q); initialize_block(block_K, seed + 2022, options.init_style_k); @@ -689,7 +671,7 @@ struct BwdRunner { select<0,4>(problem_shape), stride_LSE); - if (! options.skip_reference) { + if (not options.skip_reference) { fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); } @@ -820,7 +802,8 @@ struct BwdRunner { flops *= static_cast(get<0>(problem_shape)); flops *= static_cast(get<1>(problem_shape)); flops *= (3 * static_cast(get<2>(problem_shape)) + 2 * static_cast(get<3>(problem_shape))); - flops *= static_cast(get<4,0>(problem_shape)); + flops *= static_cast(get<4,0,0>(problem_shape)); + flops *= static_cast(get<4,0,1>(problem_shape)); flops *= static_cast(get<4,1>(problem_shape)); double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); example_result.tflops_tc_s = tflops_s; @@ -1001,7 +984,7 @@ int main_single(int argc, char const **args) { hw_info.sm_count = options.sm_count; } - std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " "; + std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " "; std::cout << "Backward" << " " << (options.causal ? 
"Causal" : "Full") << " "; std::cout << "#SM " << hw_info.sm_count << std::endl; diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp index 5c5de849..9e4efb34 100644 --- a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -60,6 +60,38 @@ template< class Mask > class Sm100FmhaBwd { +private: + template + constexpr static auto to_bwd_shape(T shape) { + if constexpr (IsMla) { // remove GQA mode + constexpr int R = decltype(rank(shape))::value; + auto HB = get(shape); + auto rest = take<0,R-1>(shape); + return append(rest, make_shape(size<0>(HB), get<1>(HB))); + } + else { + return shape; + } + } + + template + constexpr static auto to_bwd_stride(T stride) { + if constexpr (IsMla) { // remove GQA mode + constexpr int R = decltype(rank(stride))::value; + auto HB = get(stride); + auto rest = take<0,R-1>(stride); + if constexpr (is_same_v(HB))>, _0>) { + return append(rest, make_stride(get<0,1>(HB), get<1>(HB))); + } + else { + return append(rest, make_stride(get<0,0>(HB), get<1>(HB))); + } + } + else { + return stride; + } + } + public: /// Argument structure: User API struct Arguments { @@ -67,26 +99,26 @@ public: ProblemShape problem_shape; const Element* ptr_Q; - cute::tuple> stride_Q; + cute::tuple, int>> stride_Q; const Element* ptr_K; - cute::tuple> stride_K; + cute::tuple, int>> stride_K; const Element* ptr_V; - cute::tuple> stride_V; + cute::tuple, int>> stride_V; const Element* ptr_O; - cute::tuple> stride_O; + cute::tuple, int>> stride_O; const ElementAccumulator* ptr_LSE; - cute::tuple> stride_LSE; + cute::tuple, int>> stride_LSE; const Element* ptr_dO; - cute::tuple> stride_dO; + cute::tuple, int>> stride_dO; Element* ptr_dQ; - cute::tuple> stride_dQ; + cute::tuple, int>> stride_dQ; Element* ptr_dK; - cute::tuple> stride_dK; + cute::tuple, int>> stride_dK; Element* ptr_dV; - cute::tuple> stride_dV; + cute::tuple, int>> stride_dV; ElementAccumulator softmax_scale; @@ -106,9 +138,10 @@ public: > >; + using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{})); using OperationMla = cutlass::fmha::device::FMHA< cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< - ProblemShape, Element, ElementAccumulator, TileShape, Mask + ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask > >; @@ -134,10 +167,11 @@ private: using namespace cute; auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; + auto [H_R, H_K] = H; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment - auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); - auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); + auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); auto log2_e = log2f(expf(1.0f)); return typename OperationSumOdO::Arguments { args.problem_shape, @@ -154,14 +188,15 @@ private: using namespace cute; auto [Q_, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; + auto [H_R, H_K] = H; D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment - auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 
0 : D*Q*H_R*H_K)); return typename OperationConvert::Arguments { args.problem_shape, src, stride_src_dQ, - nullptr, stride_src_dQ, - nullptr, stride_src_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, args.ptr_dQ, args.stride_dQ, nullptr, args.stride_dK, nullptr, args.stride_dV, @@ -171,22 +206,22 @@ private: static typename Operation::Arguments to_bwd_arguments( Arguments const& args, - ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, - ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, - ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { - + ElementAccumulator* sum_OdO = nullptr, cute::tuple, int>> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple, int>> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple, int>> const& stride_dQ = {}) { + return typename Operation::Arguments{ - args.problem_shape, - { args.ptr_Q, args.stride_Q, - args.ptr_K, args.stride_K, - args.ptr_V, args.stride_V, - args.ptr_dO, args.stride_dO, - scaled_lse, stride_scaled_lse, - sum_OdO, stride_sum_OdO, - dQ_acc, stride_dQ, + to_bwd_shape(args.problem_shape), + { args.ptr_Q, to_bwd_stride(args.stride_Q), + args.ptr_K, to_bwd_stride(args.stride_K), + args.ptr_V, to_bwd_stride(args.stride_V), + args.ptr_dO, to_bwd_stride(args.stride_dO), + scaled_lse, to_bwd_stride(stride_scaled_lse), + sum_OdO, to_bwd_stride(stride_sum_OdO), + dQ_acc, to_bwd_stride(stride_dQ), args.softmax_scale }, - { args.ptr_dK, args.stride_dK, - args.ptr_dV, args.stride_dV }, + { args.ptr_dK, to_bwd_stride(args.stride_dK), + args.ptr_dV, to_bwd_stride(args.stride_dV) }, args.hw_info }; } @@ -220,7 +255,7 @@ public: static size_t get_workspace_size(Arguments const& args) { auto [Q_, K, D, D_VO, HB] = args.problem_shape; - auto [H, B] = HB; + auto [H, B] = product_each(HB); D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment size_t workspace_bytes = 0; @@ -240,7 +275,7 @@ public: << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); auto [Q_, K, D, D_VO, HB] = args.problem_shape; - auto [H, B] = HB; + auto [H, B] = product_each(HB); D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); @@ -269,7 +304,7 @@ public: << workspace << ", stream: " << (stream ? 
"non-null" : "null")); auto [Q_, K, D, D_VO, HB] = args.problem_shape; - auto [H, B] = HB; + auto [H, B] = product_each(HB); D = cutlass::round_up(D, 8); // Alignment int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp index f18fa3bf..7bee3d5f 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -46,18 +46,18 @@ struct FmhaKernelBwdConvert { ProblemShape problem_shape; const ElementAcc* ptr_src_dQ; - tuple> stride_src_dQ; + tuple, int>> stride_src_dQ; const ElementAcc* ptr_src_dK; - tuple> stride_src_dK; + tuple, int>> stride_src_dK; const ElementAcc* ptr_src_dV; - tuple> stride_src_dV; + tuple, int>> stride_src_dV; Element* ptr_dest_dQ; - tuple> stride_dest_dQ; + tuple, int>> stride_dest_dQ; Element* ptr_dest_dK; - tuple> stride_dest_dK; + tuple, int>> stride_dest_dK; Element* ptr_dest_dV; - tuple> stride_dest_dV; + tuple, int>> stride_dest_dV; ElementAcc scale = 1.0; }; @@ -104,8 +104,8 @@ struct FmhaKernelBwdConvert { template CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { - auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; - auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; int seqlen = count; if constexpr (is_variable_length_v) { diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp index 4a26d768..66780f35 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -46,18 +46,18 @@ struct FmhaKernelBwdSumOdO { ProblemShape problem_shape; const Element* ptr_O; - cute::tuple> stride_O; + cute::tuple, int>> stride_O; const Element* ptr_dO; - cute::tuple> stride_dO; + cute::tuple, int>> stride_dO; ElementAcc* ptr_sum_OdO; - cute::tuple> stride_sum_OdO; + cute::tuple, int>> stride_sum_OdO; const ElementAcc* ptr_lse = nullptr; - cute::tuple> stride_lse; + cute::tuple, int>> stride_lse; ElementAcc* ptr_scaled_lse = nullptr; - cute::tuple> stride_scaled_lse; + cute::tuple, int>> stride_scaled_lse; ElementAcc sum_odo_scale = 1.0; ElementAcc lse_scale = 1.0; @@ -104,11 +104,11 @@ struct FmhaKernelBwdSumOdO { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { - auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); - auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); - auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); - auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); - auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * 
get<1,1>(params.stride_scaled_lse); + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); auto problem_q = get<0>(params.problem_shape); int seqlen_q = problem_q; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index fce00fd9..4cc42dc4 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -119,13 +119,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static constexpr int Alignment = 128 / sizeof_bits_v; static constexpr int kStages = 2; - using TensorStrideContiguousK = Stride>; - using TensorStrideContiguousMN = Stride<_1, int, Stride>; + using TensorStrideContiguousK = Stride, int>>; + using TensorStrideContiguousMN = Stride<_1, int, Stride, int>>; + using TensorStrideContiguousK_GQA = Stride, int>>; + using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride, int>>; // compute S using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK_GQA, Alignment, Element, TensorStrideContiguousK, Alignment, ElementAcc, Shape, @@ -137,7 +139,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { // compute dP using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK_GQA, Alignment, Element, TensorStrideContiguousK, Alignment, ElementAcc, Shape, @@ -177,7 +179,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // somewhat arbitrary since we dump to smem, need to agree with the previous one Element, TensorStrideContiguousMN, Alignment, - Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN_GQA, Alignment, ElementAcc, Shape, ClusterShape, cutlass::gemm::collective::StageCount, @@ -278,15 +280,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); using TensorStride = TensorStrideContiguousK; // S D (H B) - using RowTensorStride = Stride<_1, Stride>; // S (H B) + using TensorStride_GQA = TensorStrideContiguousK_GQA; + using RowTensorStride = Stride<_1, Stride, int>>; // S (H B) struct MainloopArguments { const Element* ptr_q; TensorStride stride_q; const Element* ptr_k; - TensorStride stride_k; + TensorStride_GQA stride_k; const Element* ptr_v; - TensorStride stride_v; + TensorStride_GQA stride_v; const Element* ptr_do; TensorStride stride_do; @@ -308,7 +311,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; using TMA_DQ = 
decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, - make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}), SmemLayoutDQ{}(_, _, _0{}) )); @@ -322,9 +325,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { struct EpilogueArguments { Element* ptr_dk; - TensorStride stride_dk; + TensorStride_GQA stride_dk; Element* ptr_dv; - TensorStride stride_dv; + TensorStride_GQA stride_dv; }; struct Arguments { @@ -346,7 +349,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static bool can_implement(Arguments const& args) { auto [Q, K, D, D_VO, HB] = args.problem_shape; auto [H, B] = HB; - if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) { + auto [H_R, H_K] = H; + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) { return false; } if (D % Alignment != 0 || D_VO % Alignment != 0) { @@ -432,7 +436,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, - int iter_index, + int iter_start, + int iter_end, int iter_count, MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, @@ -447,6 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; using X = Underscore; @@ -590,6 +596,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_index += 1; while (iter_count > 0) { + if (iter_index == iter_end) { + iter_index = iter_start; + get<0,0>(blk_coord_batch) += 1; + } + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); @@ -660,7 +671,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { CUTLASS_DEVICE void mma( BlkCoord const& blk_coord, ProblemShape_ const& problem_shape, - int iter_index, + int iter_start, + int iter_end, int iter_count, MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, @@ -1119,7 +1131,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { BlkCoord const& blk_coord, BlkOffset const& blk_offset, ProblemShape_ const& problem_shape, - int iter_index, + int iter_start, + int iter_end, int iter_count, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args, @@ -1141,6 +1154,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; // in tmem, S & P overlap // and dP and dQ overlap @@ -1224,8 +1238,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto tRT_cST_p = thread_r2t.partition_S(tDVcST); auto tRT_cST = split_wg(tRT_cST_p); - bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); - int last_iter = iter_count - 1 + iter_index; + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape); CUTLASS_PRAGMA_NO_UNROLL while (iter_count > 0) { @@ -1246,7 +1259,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { bool leading_causal_masking = false; if constexpr (std::is_base_of_v, Mask>) { - leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + leading_causal_masking = warp_uniform(iter_index == iter_start); } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); int kv_left = 
get<1>(blk_coord) * TileShapeK{}; @@ -1258,7 +1271,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } bool trailing_residual_masking = false; if constexpr (std::is_base_of_v) { - trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k); } dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { @@ -1379,6 +1392,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_count -= 1; iter_index += 1; + if (iter_index == iter_end) { + iter_index = iter_start; + } } epilogue( @@ -1391,7 +1407,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { CUTLASS_DEVICE void reduce( BlkCoord const& blk_coord, ProblemShape_ const& problem_shape, - int iter_index, + int iter_start, + int iter_end, int iter_count, MainloopArguments const& mainloop_args, MainloopParams const& mainloop_params, @@ -1404,6 +1421,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { using X = Underscore; auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; @@ -1415,7 +1433,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) - (_, _, _, _0{}, blk_coord_batch); + (_, _, _, _0{}, _); Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); @@ -1426,7 +1444,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto thread_t2r = tiled_t2r.get_slice(thread_idx); Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); - Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); @@ -1472,7 +1489,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ).arrive_and_wait(); if (lane_predicate) { // launch tma store - copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch)); pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); } @@ -1481,6 +1498,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_count -= 1; iter_index += 1; + if (iter_index == iter_end) { + iter_index = iter_start; + get<0,0>(blk_coord_batch) += 1; + } } } @@ -1683,12 +1704,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { pipeline_init_wait(size(ClusterShape{})); - auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z)); auto [problem_shape, blk_offset] = apply_variable_length_offset( params.problem_shape, blk_coord ); - int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{}); int iter_start = 0; if constexpr (std::is_base_of_v, Mask>) { iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; @@ -1699,7 +1720,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { return; } - iter_count -= iter_start; + int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape); if (iter_count <= 0) { epilogue_clear( @@ -1720,6 +1741,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized 
{ blk_offset, problem_shape, iter_start, + iter_end, iter_count, params.mainloop, params.mainloop_params, @@ -1741,6 +1763,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { blk_coord, problem_shape, iter_start, + iter_end, iter_count, params.mainloop, shared_storage.tensors, @@ -1763,6 +1786,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { blk_offset, problem_shape, iter_start, + iter_end, iter_count, params.mainloop, params.epilogue, @@ -1794,6 +1818,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { blk_coord, problem_shape, iter_start, + iter_end, iter_count, params.mainloop, params.mainloop_params, @@ -1820,7 +1845,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static dim3 get_grid_shape(Params const& params) { auto [Q, K, D, D_VO, HB] = params.problem_shape; auto [H, B] = HB; - dim3 grid(ceil_div(K, TileShapeK{}), H, B); + auto [H_R, H_K] = H; + dim3 grid(ceil_div(K, TileShapeK{}), H_K, B); return grid; } }; diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp index 6f4a0c7a..465a9871 100644 --- a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -56,12 +56,12 @@ void __global__ fmha_bwd_reference_dQ_kernel( using namespace cutlass::fmha::collective; using Element = typename TensorO::value_type; - using ElementAccumulator = typename TensorLSE::value_type; + using ElementAcc = typename TensorLSE::value_type; extern __shared__ char mS_mem[]; - ElementAccumulator* mS = reinterpret_cast(mS_mem); + Element* mS = reinterpret_cast(mS_mem); - ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in))); + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { auto [problem_shape, offset] = apply_variable_length_offset( @@ -79,9 +79,9 @@ void __global__ fmha_bwd_reference_dQ_kernel( auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in); for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) { for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { - ElementAccumulator acc_qk = 0; - ElementAccumulator acc_dov = 0; - ElementAccumulator acc_doo = 0; + ElementAcc acc_qk = 0; + ElementAcc acc_dov = 0; + ElementAcc acc_doo = 0; for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); // acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); @@ -94,20 +94,22 @@ void __global__ fmha_bwd_reference_dQ_kernel( } auto id = make_identity_tensor(make_shape(1, 1)); - auto frag = make_tensor(Shape<_1, _1>{}); + auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc_qk; fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); acc_qk = frag(0); - mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + mS[idx_K] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); } // for idx_K __syncthreads(); for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { - ElementAccumulator acc = 0; + ElementAcc acc = 0; for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { - acc += mS[idx_K] * ElementAccumulator(mK(idx_K, idx_D, idx_L)); + ElementAcc rK = mK(idx_K, idx_D, idx_L); + ElementAcc rDS = 
mS[idx_K]; + acc += rDS * rK; } mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); } // for idx_D @@ -135,62 +137,83 @@ void __global__ fmha_bwd_reference_dK_kernel( using namespace cutlass::fmha::collective; using Element = typename TensorO::value_type; - using ElementAccumulator = typename TensorLSE::value_type; + using ElementAcc = typename TensorLSE::value_type; extern __shared__ char mS_mem[]; - ElementAccumulator* mS = reinterpret_cast(mS_mem); + Element* mS = reinterpret_cast(mS_mem); - ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in))); + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); - for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { + auto [H, B] = get<4>(problem_shape_in); + auto [H_R, H_K] = H; + + for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) { + auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B)); auto [problem_shape, offset] = apply_variable_length_offset( - problem_shape_in, - make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in))) + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B)) ); - // problem_shape = problem_shape_in; - // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); - auto mK = domain_offset(select<1,2,4>(offset), mK_in); - auto mV = domain_offset(select<1,3,4>(offset), mV_in); - auto mO = domain_offset(select<0,3,4>(offset), mO_in); - auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); - auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); - auto mDK = domain_offset(select<1,2,4>(offset), mDK_in); - for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { - for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { - ElementAccumulator acc_qk = 0; - ElementAccumulator acc_dov = 0; - ElementAccumulator acc_doo = 0; - for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { - acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); - // acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); - // acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); - } // for idx_D0 + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset; - for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) { - acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); - acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); + auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in); + auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in); + auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in); + auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in); + auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in); + auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in); + auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in); + + for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) { + ElementAcc acc_dk = 0; + for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) { + auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B); + for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + ElementAcc acc_dov = 0; + ElementAcc acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < D; idx_D0++) { + ElementAcc rQ = 
mQ(idx_Q, idx_D0, coord_HB); + ElementAcc rK = mK(idx_K, idx_D0, coord_HB); + acc_qk += rQ * rK; + } // for idx_D0 + + for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) { + ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB); + ElementAcc rV = mV(idx_K, idx_D1, coord_HB); + ElementAcc rO = mO(idx_Q, idx_D1, coord_HB); + acc_dov += rDO * rV; + acc_doo += rDO * rO ; + } + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_Q + + __syncthreads(); + + int idx_D = threadIdx.x; + if (idx_D < D) { + for (int idx_Q = 0; idx_Q < Q; idx_Q++) { + ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB); + ElementAcc rDS = mS[idx_Q]; + acc_dk += rDS * rQ; + } } - auto id = make_identity_tensor(make_shape(1, 1)); - auto frag = make_tensor(Shape<_1, _1>{}); - frag(0) = acc_qk; - fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); - acc_qk = frag(0); + __syncthreads(); + } // for idx_H_R - mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); - } // for idx_Q - - __syncthreads(); - - for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { - ElementAccumulator acc = 0; - for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { - acc += mS[idx_Q] * ElementAccumulator(mQ(idx_Q, idx_D, idx_L)); - } - mDK(idx_K, idx_D, idx_L) = static_cast(acc); - } // for idx_D + int idx_D = threadIdx.x; + if (idx_D < D) { + auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B); + mDK(idx_K, idx_D, coord_HB) = static_cast(acc_dk); + } } // for idx_K - } // for idx_L + } // for idx_HB } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -216,54 +239,71 @@ void __global__ fmha_bwd_reference_dV_kernel( using ElementAcc = typename TensorLSE::value_type; extern __shared__ char mS_mem[]; - ElementAcc* mS = reinterpret_cast(mS_mem); + Element* mS = reinterpret_cast(mS_mem); - ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in))); + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); - for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { + auto [H, B] = get<4>(problem_shape_in); + auto [H_R, H_K] = H; + + for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) { + auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B)); auto [problem_shape, offset] = apply_variable_length_offset( - problem_shape_in, - make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in))) + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B)) ); - // problem_shape = problem_shape_in; - // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); - auto mK = domain_offset(select<1,2,4>(offset), mK_in); - auto mV = domain_offset(select<1,3,4>(offset), mV_in); - auto mO = domain_offset(select<0,3,4>(offset), mO_in); - auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); - auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); - auto mDV = domain_offset(select<1,3,4>(offset), mDV_in); - for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K 
+= gridDim.x) { - for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { - ElementAcc acc_qk = 0; + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset; - for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { - ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L); - ElementAcc rK = mK(idx_K, idx_D0, idx_L); - acc_qk += rQ * rK; - } // for idx_D0 + auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in); + auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in); + auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in); + auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in); + auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in); + auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in); + auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in); - auto id = make_identity_tensor(make_shape(1, 1)); - auto frag = make_tensor(Shape<_1, _1>{}); - frag(0) = acc_qk; - fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); - acc_qk = frag(0); + for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) { + ElementAcc acc_dv = 0; + for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) { + auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B); + for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; - mS[idx_Q] = expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)); - } // for idx_Q + for (int idx_D0 = 0; idx_D0 < D; idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB); + ElementAcc rK = mK(idx_K, idx_D0, coord_HB); + acc_qk += rQ * rK; + } // for idx_D0 - __syncthreads(); + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); - for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) { - ElementAcc acc = 0; - for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { - ElementAcc rS = static_cast(mS[idx_Q]); - ElementAcc rDO = mDO(idx_Q, idx_D, idx_L); - acc += rS * rDO; - } - mDV(idx_K, idx_D, idx_L) = static_cast(acc); - } // for idx_D + mS[idx_Q] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB))); + } // for idx_Q + + __syncthreads(); + + int idx_D_VO = threadIdx.x; + if (idx_D_VO < D_VO) { + for (int idx_Q = 0; idx_Q < Q; idx_Q++) { + ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB); + ElementAcc rP = mS[idx_Q]; + acc_dv += rP * rDO; + } + } // for idx_D + + __syncthreads(); + } // for idx_H_R + + int idx_D_VO = threadIdx.x; + if (idx_D_VO < D_VO) { + auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B); + mDV(idx_K, idx_D_VO, coord_HB) = static_cast(acc_dv); + } } // for idx_K } // for idx_L } @@ -288,7 +328,7 @@ void fmha_bwd_reference_dQ( dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); dim3 block(256); - int shared_mem = size<0>(mK) * sizeof(typename TensorLSE::value_type); + int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type); fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); } @@ -310,9 +350,12 @@ void fmha_bwd_reference_dK( using namespace cute; - dim3 grid(size<0>(mDK), size<2>(mDK), 1); - dim3 block(256); - int shared_mem = size<0>(mDO) * sizeof(typename 
TensorLSE::value_type);
+  auto [K, D, HB] = mDK.shape();
+  auto [H, B] = HB;
+  auto [H_R, H_K] = H;
+  dim3 grid(K, H_K * B, 1);
+  dim3 block(std::max(D, 256));
+  int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type);
 
   fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
 }
@@ -334,9 +377,12 @@ void fmha_bwd_reference_dV(
 
   using namespace cute;
 
-  dim3 grid(size<0>(mDV), size<2>(mDV), 1);
-  dim3 block(256);
-  int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
+  auto [K, D_VO, HB] = mDV.shape();
+  auto [H, B] = HB;
+  auto [H_R, H_K] = H;
+  dim3 grid(K, H_K * B, 1);
+  dim3 block(std::max(D_VO, 256));
+  int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type);
 
   fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
 }
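
The sketch below is not part of the patch; it is a minimal, standalone illustration (hypothetical file name, illustrative sizes, compiled with nvcc against the CUTLASS include directory) of the ((H_R, H_K), B) head grouping this change introduces. K and V carry a _0 stride over the H_R mode, so every query head in a GQA group reads the same K/V head, while Q keeps one slice per (h_r, h_k) pair. The B == 1 batch-stride special case used by the varlen path is omitted for brevity.

// gqa_layout_sketch.cu (hypothetical, not part of this diff)
#include <cute/tensor.hpp>
#include <cstdio>

int main() {
  using namespace cute;

  // Illustrative sizes only.
  int Q = 256, K = 128, D = 64, H_R = 4, H_K = 2, B = 3;

  // K/V layout as set up in BwdRunner::initialize(): stride _0 over H_R
  // broadcasts one K/V head across all H_R query heads of its group.
  auto layout_KV = make_layout(
      make_shape(K, D, make_shape(make_shape(H_R, H_K), B)),
      make_stride(D, _1{}, make_stride(make_stride(_0{}, D * K), D * K * H_K)));

  // Q keeps a distinct slice per (h_r, h_k) pair.
  auto layout_Q = make_layout(
      make_shape(Q, D, make_shape(make_shape(H_R, H_K), B)),
      make_stride(D, _1{}, make_stride(make_stride(D * Q, D * Q * H_R), D * Q * H_R * H_K)));

  // Same (seq, d, h_k, b) coordinate, two different h_r within one group.
  auto kv0 = layout_KV(make_coord(7, 3, make_coord(make_coord(0, 1), 2)));
  auto kv1 = layout_KV(make_coord(7, 3, make_coord(make_coord(1, 1), 2)));
  auto q0  = layout_Q(make_coord(7, 3, make_coord(make_coord(0, 1), 2)));
  auto q1  = layout_Q(make_coord(7, 3, make_coord(make_coord(1, 1), 2)));

  printf("K/V offsets: %d %d (equal -> head shared within the group)\n", int(kv0), int(kv1));
  printf("Q   offsets: %d %d (distinct per query head)\n", int(q0), int(q1));
  return 0;
}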