From 2e2af190bd4e57b26d2be26fec8a6b34bc979ee5 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Thu, 5 Jun 2025 20:14:57 -0700 Subject: [PATCH] Revert "[ex77] fix mla split; add fwd lse; add bwd varlen (#2366)" (#2370) This reverts commit f12b1d75c904c05b10650809af39511080a06ff3. --- .../77_blackwell_fmha/77_blackwell_fmha.cu | 224 ++++-------- .../77_blackwell_fmha_bwd.cu | 165 ++------- .../77_blackwell_fmha_gen.cu | 6 - .../77_blackwell_fmha/77_blackwell_mla.cu | 10 +- examples/77_blackwell_fmha/CMakeLists.txt | 38 +-- .../collective/fmha_fusion.hpp | 121 +------ ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 14 +- ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 55 +-- .../device/fmha_device_bwd.hpp | 44 +-- .../kernel/fmha_kernel_bwd_convert.hpp | 29 +- .../kernel/fmha_kernel_bwd_sum_OdO.hpp | 22 +- .../77_blackwell_fmha/kernel/fmha_options.hpp | 2 +- .../kernel/fmha_tile_scheduler.hpp | 4 +- ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 320 ++++++------------ ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 5 +- .../kernel/sm100_fmha_mla_reduction.hpp | 4 +- .../reference/fmha_bwd_reference.hpp | 105 ++---- .../reference/fmha_fwd_reference.hpp | 2 +- .../reference/reference_abs_error.hpp | 2 - 19 files changed, 326 insertions(+), 846 deletions(-) diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 1753df7b..c8792122 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -117,17 +117,15 @@ struct Options { int q = 256; int k = 256; int d = 128; - int warmup_iterations = 1; int iterations = 3; - int tensor_ring_buffers = 1; bool verify = false; bool verbose = false; bool causal = false; bool residual = false; bool varlen = false; - bool persistent = false; int sm_count = 0; + std::string kernel_filter; InitStyle init_style_q = InitStyle::kRandom; @@ -191,15 +189,10 @@ struct Options { if (b == -1) b = 16384 / k; if (b == 0) b = 1; - cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations); cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); - cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers); - verify = cmd.check_cmd_line_flag("verify"); verbose = cmd.check_cmd_line_flag("verbose"); varlen = cmd.check_cmd_line_flag("varlen"); - persistent = cmd.check_cmd_line_flag("persistent"); - std::string mask; cmd.get_cmd_line_argument("mask", mask, ""); if (mask == "no" || mask == "") { @@ -217,7 +210,7 @@ struct Options { causal = false; } cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); - + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q); @@ -242,13 +235,10 @@ struct Options { << " --q= Sets the Q extent\n" << " --k= Sets the K extent\n" << " --d= Sets the D extentn" - << " --tensor_ring_buffers= Sets the number of tensor ring buffers\n" - << " --warmup_iterations= Sets the warmup iterations\n" << " --iterations= Benchmarking iterations\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" << " --mask= Enables masking\n" - << " --persistent Enables persistent scheduler\n" << " --varlen Enables variable sequence length\n" << " B*Q and B*K become the total sequence length\n" << " and are split B-ways, alternatingly 
+10% and -10%\n" @@ -389,55 +379,40 @@ struct FwdRunner { StrideLSE stride_LSE; uint64_t seed = 0; - struct DeviceBuffer { - DeviceAllocation block_Q; - DeviceAllocation block_K; - DeviceAllocation block_V; - DeviceAllocation block_O; - DeviceAllocation block_LSE; - DeviceAllocation block_ref_O; - DeviceAllocation block_ref_LSE; - DeviceAllocation device_cumulative_seqlen_q; - DeviceAllocation device_cumulative_seqlen_kv; - - DeviceBuffer() = default; - DeviceBuffer(const DeviceBuffer&) = delete; - DeviceBuffer& operator=(const DeviceBuffer&) = delete; - - size_t get_storage_size() const { - return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size() - + block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size() - + block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size() - + device_cumulative_seqlen_kv.get_storage_size(); - } - }; - - std::vector> buffers; + DeviceAllocation block_Q; + DeviceAllocation block_K; + DeviceAllocation block_V; + DeviceAllocation block_O; + DeviceAllocation block_LSE; + DeviceAllocation block_ref_O; + DeviceAllocation block_ref_LSE; std::vector cumulative_seqlen_q; std::vector cumulative_seqlen_kv; + DeviceAllocation device_cumulative_seqlen_q; + DeviceAllocation device_cumulative_seqlen_kv; // // Methods // - bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) { - Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()), + bool verify(const ProblemShapeType& problem_shape) { + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), select<0,2,3>(problem_shape), stride_Q); - Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()), + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), select<1,2,3>(problem_shape), stride_K); - Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()), + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), select<1,2,3>(problem_shape), stride_V); - Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()), + Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()), select<0,2,3>(problem_shape), stride_O); - Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()), + Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), select<0,3>(problem_shape), stride_LSE); @@ -456,7 +431,7 @@ struct FwdRunner { // Check if output from CUTLASS kernel and reference kernel are equal or not double max_diff = 0; double mean_diff = 0; - reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff); + reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_O) { @@ -464,13 +439,14 @@ struct FwdRunner { << " mean " << mean_diff << std::endl; } - reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff); + // reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); - bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); - if ( ! passed_LSE) { - std::cerr << "failed LSE: max diff " << max_diff - << " mean " << mean_diff << std::endl; - } + bool passed_LSE = true; // future work + // bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + // if ( ! 
passed_LSE) { + // std::cerr << "failed LSE: max diff " << max_diff + // << " mean " << mean_diff << std::endl; + // } return passed_O && passed_LSE; } @@ -583,71 +559,50 @@ struct FwdRunner { get<1,1>(stride_LSE) = 0; } - auto buffer_init_fn = [&](auto& buffer) { - buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); - buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); - buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); - buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); - buffer.block_LSE.reset(size(shape_LSE)); + block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + block_LSE.reset(size(shape_LSE)); + block_ref_O.reset(size(shape_QO)); + block_ref_LSE.reset(size(shape_LSE)); - initialize_block(buffer.block_Q, seed + 2023, options.init_style_q); - initialize_block(buffer.block_K, seed + 2022, options.init_style_k); - initialize_block(buffer.block_V, seed + 2021, options.init_style_v); + initialize_block(block_Q, seed + 2023, options.init_style_q); + initialize_block(block_K, seed + 2022, options.init_style_k); + initialize_block(block_V, seed + 2021, options.init_style_v); - if ( ! cumulative_seqlen_q.empty()) { - buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); - buffer.device_cumulative_seqlen_q.copy_from_host( - cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); - } - if ( ! cumulative_seqlen_kv.empty()) { - buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); - buffer.device_cumulative_seqlen_kv.copy_from_host( - cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); - } - }; - - buffers.push_back(std::make_unique()); - buffer_init_fn(*buffers.back()); - - int tensor_ring_buffers = options.tensor_ring_buffers; - - for (int i = 1; i < tensor_ring_buffers; i++) { - buffers.push_back(std::make_unique()); - buffer_init_fn(*buffers.back()); + if ( ! cumulative_seqlen_q.empty()) { + device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); + device_cumulative_seqlen_q.copy_from_host( + cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); + } + if ( ! 
cumulative_seqlen_kv.empty()) { + device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); + device_cumulative_seqlen_kv.copy_from_host( + cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); } if constexpr (kIsVarlen) { - get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get(); - get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get(); + get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get(); + get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get(); } return problem_shape; } - auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) { - auto problem_shape_ = problem_shape; - if constexpr (kIsVarlen) { - get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get(); - get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get(); - } - typename Operation::Arguments arguments{ - problem_shape_, - { buffers[buffer_index]->block_Q.get(), stride_Q, - buffers[buffer_index]->block_K.get(), stride_K, - buffers[buffer_index]->block_V.get(), stride_V }, - { buffers[buffer_index]->block_O.get(), stride_O, - buffers[buffer_index]->block_LSE.get(), stride_LSE }, - hw_info - }; - return arguments; - } - ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { ProblemShapeType problem_shape = initialize(options); - int buffer_index = 0; - typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index); + typename Operation::Arguments arguments{ + problem_shape, + { block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V }, + { block_O.get(), stride_O, + block_LSE.get(), stride_LSE }, + hw_info + }; Operation op; @@ -675,21 +630,11 @@ struct FwdRunner { } // Run - for (int i = 0; i < options.warmup_iterations; i++) { - status = op.run(); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " - << cudaGetErrorString(cudaGetLastError()) << std::endl; - return example_result; - } - buffer_index = (buffer_index + 1) % buffers.size(); - arguments = get_arguments(problem_shape, hw_info, buffer_index); - status = op.update(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: " - << std::endl; - return example_result; - } + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; } cudaError_t result = cudaDeviceSynchronize(); @@ -727,14 +672,6 @@ struct FwdRunner { << cudaGetErrorString(cudaGetLastError()) << std::endl; return example_result; } - buffer_index = (buffer_index + 1) % buffers.size(); - arguments = get_arguments(problem_shape, hw_info, buffer_index); - status = op.update(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Failed to update the CUTLASS kernel's parameters. 
Last CUDA error is: " - << std::endl; - return example_result; - } } // @@ -797,10 +734,10 @@ struct FwdRunner { // Verify that the result is correct bool passed = true; if (options.verify) { - passed = verify(problem_shape, *buffers[0]); + passed = verify(problem_shape); if (passed) example_result.verified = true; } - + if (!passed) { std::cerr << "Reference check failed" << std::endl; return example_result; @@ -852,14 +789,10 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn using HeadDim = _128; - if (options.persistent) { - // Persistent Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); - } - else { - // Individual Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); - } + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -885,14 +818,10 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf using HeadDim = _64; - if (options.persistent) { - // Persistent Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); - } - else { - // Individual Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); - } + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); } @@ -916,14 +845,10 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf using HeadDim = _32; #ifdef FP8 - if (options.persistent) { - // Persistent Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); - } - else { - // Individual Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); - } + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); #endif } @@ -988,7 +913,6 @@ int main_single(int argc, char const **args) { hw_info.sm_count = options.sm_count; } - std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " "; std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? 
"Residual" : "None")) << " "; std::cout << "#SM " << hw_info.sm_count << std::endl; diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu index b8f998ca..1c02a29e 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -120,8 +120,6 @@ struct Options { bool verbose = false; bool causal = false; - bool residual = false; - bool varlen = false; int sm_count = 0; std::string kernel_filter; @@ -192,21 +190,14 @@ struct Options { cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); verify = cmd.check_cmd_line_flag("verify"); verbose = cmd.check_cmd_line_flag("verbose"); - varlen = cmd.check_cmd_line_flag("varlen"); std::string mask; cmd.get_cmd_line_argument("mask", mask, ""); if (mask == "causal") { causal = true; } - else if (mask == "residual") { - residual = true; - } else { causal = defaults.causal; } - if (varlen) { - residual = true; - } skip_reference = cmd.check_cmd_line_flag("skip-reference"); cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); @@ -239,12 +230,7 @@ struct Options { << " --iterations= Benchmarking iterations\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" - << " --mask= Enables masking\n" - << " --varlen Enables variable sequence length\n" - << " B*Q and B*K become the total sequence length\n" - << " and are split B-ways, alternatingly +10% and -10%\n" - << " with the last batch sized to make it fit\n" - << " implies at least residual masking for correctness\n" + << " --mask= Enables masking\n" << " --sm-count Sets SM count rather than querying it\n" << " --kernel-filter= Sets regexp to match kernel against\n" << "\n"; @@ -321,7 +307,6 @@ struct ExampleResult { /////////////////////////////////////////////////////////////////////////////////////////////////// template< - bool kIsVarlen, class TileShape, class DispatchPolicy, class ActiveMask, @@ -337,11 +322,9 @@ struct BwdRunner { using ElementAccumulator = float; // Q K D (H B) - using ProblemShape = std::conditional_t< - kIsVarlen, - cute::tuple>, - cute::tuple> - >; + using ProblemShapeType = cute::tuple>; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; using TensorStride = Stride>; // Seq D (H B) using StrideQ = TensorStride; @@ -380,9 +363,6 @@ struct BwdRunner { DeviceAllocation block_O; DeviceAllocation block_LSE; - DeviceAllocation block_cumulative_seqlen_q; - DeviceAllocation block_cumulative_seqlen_kv; - DeviceAllocation block_dQ; DeviceAllocation block_dK; DeviceAllocation block_dV; @@ -395,7 +375,7 @@ struct BwdRunner { // // Methods // - bool verify(const ProblemShape& problem_shape) { + bool verify(const ProblemShapeType& problem_shape) { auto [Q, K, D, HB] = problem_shape; auto [H, B] = HB; @@ -479,85 +459,22 @@ struct BwdRunner { return passed_dQ && passed_dK && passed_dV; } - auto initialize_problem_shape(Options const& options) { - if constexpr (kIsVarlen) { - int num_batches = options.b; - - // generate Q as --b times - // gaussian (--Q, --Q / 2) sampled positive - // track cumulative - std::mt19937 rng(0x202305151552ull); - std::normal_distribution dist_q(options.q, options.q / 2); - std::normal_distribution dist_kv(options.k, options.k / 2); - - auto generate_positive_int = [](auto& dist, auto& gen) { - // "0" is a valid value we test here - return std::max(0, static_cast(dist(gen))); - }; - - std::vector cumulative_seqlen_q = {0}; - std::vector cumulative_seqlen_kv = {0}; 
- - int total_seqlen_q = 0; - int total_seqlen_kv = 0; - int max_seqlen_q = 0; - int max_seqlen_kv = 0; - - const bool kVarlenSame = false; - for (int i = 0; i < num_batches; i++) { - int seqlen_q = kVarlenSame ? options.q : generate_positive_int(dist_q, rng); - int seqlen_kv = kVarlenSame ? options.k : generate_positive_int(dist_kv, rng); - - total_seqlen_q += seqlen_q; - total_seqlen_kv += seqlen_kv; - - max_seqlen_q = std::max(max_seqlen_q, seqlen_q); - max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); - - cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); - cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); - } - - block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); - block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); - block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); - block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); - - ProblemShape problem_shape{ - {max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q}, - {max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv}, - options.d, {options.h, options.b} - }; - auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1)); - - return cute::make_tuple(problem_shape, tensor_shape); - } - else { - ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}}; - return cute::make_tuple(problem_shape, problem_shape); - } - } - /// Initialize operands to be used in the GEMM and reference GEMM - ProblemShape initialize(Options const& options) { - auto [problem_shape, tensor_shape] = initialize_problem_shape(options); - auto [Q, K, D, HB] = tensor_shape; + void initialize(const ProblemShapeType& problem_shape, Options const& options) { + auto [Q, K, D, HB] = problem_shape; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment - // for varlen, Q == total_Q, K == total_K, B = 1 - // but in problem_shape, they've got to be max_Q/max_K, and B = B - - auto shape_QO = make_shape(Q, D, make_shape(H, B)); - auto shape_KV = make_shape(K, D, make_shape(H, B)); - auto shape_LSE = make_shape(Q, make_shape(H, B)); - - stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H)); - stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H)); - stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H)); + auto shape_QO = select<0,2,3>(problem_shape); + auto shape_KV = select<1,2,3>(problem_shape); + auto shape_LSE = select<0,3>(problem_shape); + stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H)); stride_V = stride_K; stride_O = stride_Q; + stride_LSE = make_stride(_1{}, make_stride(Q, Q*H)); stride_dQ = stride_Q; stride_dK = stride_K; @@ -588,13 +505,6 @@ struct BwdRunner { initialize_block(block_V, seed + 2021, options.init_style_v); initialize_block(block_dO, seed + 2020, options.init_style_do); - initialize_block(block_dQ, seed + 2030, InitStyle::kOne); - initialize_block(block_dK, seed + 2031, InitStyle::kOne); - initialize_block(block_dV, seed + 2032, InitStyle::kOne); - initialize_block(block_ref_dQ, seed + 2033); - initialize_block(block_ref_dK, seed + 2034); - initialize_block(block_ref_dV, seed + 2035); - Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), select<0,2,3>(problem_shape), stride_Q); @@ -618,19 +528,15 @@ struct BwdRunner { if (! 
options.skip_reference) { fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); } - - return problem_shape; } ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { - auto problem_shape = initialize(options); + auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b)); + + initialize(problem_shape, options); ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d); - ExampleResult example_result; - - using Operation = cutlass::fmha::device::Sm100FmhaBwd; - typename Operation::Arguments arguments{ problem_shape, block_Q.get(), stride_Q, @@ -648,6 +554,8 @@ struct BwdRunner { Operation op; + ExampleResult example_result; + example_result.smem_size = Operation::Kernel::SharedStorageSize; size_t workspace_size = 0; @@ -742,7 +650,7 @@ struct BwdRunner { runtime_ms /= static_cast(options.iterations); - double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); + double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); flops *= static_cast(get<0>(problem_shape)); flops *= static_cast(get<1>(problem_shape)); flops *= static_cast(get<2>(problem_shape)); @@ -798,28 +706,14 @@ void print_result(const std::string& description, ExampleResult result, bool ver struct KernelCoop {}; -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -auto dispatch_bool(bool value, Fn fn) { - if (value) { - return fn(std::true_type{}); - } - else { - return fn(std::false_type{}); - } -} - ////////////////////////////////////////////////////////////////////////////////////////////////// template void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { - dispatch_bool(options.varlen, [&](auto is_varlen) { - BwdRunner runner; - auto result = runner.run(options, hw_info); - print_result(name, result, options.verbose); - }); + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); }; using HeadDim = _64; @@ -832,11 +726,9 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf template void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { - dispatch_bool(options.varlen, [&](auto is_varlen) { - BwdRunner runner; - auto result = runner.run(options, hw_info); - print_result(name, result, options.verbose); - }); + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); }; using HeadDim = _128; @@ -911,10 +803,7 @@ int main_single(int argc, char const **args) { auto with_causal = [&](auto fn) { if (options.causal) { - fn(CausalForBackwardMask{}); - } - else if (options.residual) { - fn(ResidualMaskForBackward{}); + fn(CausalMask{}); } else { fn(NoMask{}); diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu index 31432621..220f5fa8 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu @@ -394,7 +394,6 @@ struct ExampleRunner { fmha_fwd_gen_reference( problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(), mQ, mNewK, mNewV, mCacheK, mCacheV, mO); - cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { std::cerr << "Reference kernel failed. 
Last CUDA error: " @@ -409,7 +408,6 @@ struct ExampleRunner { double max_diff = 0; double mean_diff = 0; reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff); - bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_O) { std::cerr << "failed O: max diff " << max_diff @@ -417,7 +415,6 @@ struct ExampleRunner { } reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff); - bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if ( ! passed_K) { std::cerr << "failed Cache K: max diff " << max_diff @@ -425,7 +422,6 @@ struct ExampleRunner { } reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff); - bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if ( ! passed_V) { std::cerr << "failed Cache V: max diff " << max_diff @@ -507,7 +503,6 @@ struct ExampleRunner { block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size()); block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size()); - block_seqlen_kv.reset(seqlen_kv.size()); block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size()); @@ -726,7 +721,6 @@ int main_single(int argc, char const **args) { << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; return 0; } - // // Parse options // diff --git a/examples/77_blackwell_fmha/77_blackwell_mla.cu b/examples/77_blackwell_fmha/77_blackwell_mla.cu index ca024623..baa70fce 100644 --- a/examples/77_blackwell_fmha/77_blackwell_mla.cu +++ b/examples/77_blackwell_fmha/77_blackwell_mla.cu @@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel; /////////////////////////////////////////////////////////////////////////////////////////////////// enum class InitStyle { - kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone + kOne, kLinearStride128, kLinearStride1, kRandom, kNone }; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,9 +98,6 @@ struct Options { if (s == "r") { dst = InitStyle::kRandom; } - else if (s == "l") { - dst = InitStyle::kRandomLarge; - } else if (s == "1") { dst = InitStyle::kOne; } @@ -206,11 +203,6 @@ void initialize_block( block.get(), block.size(), seed, (Element) -1, (Element) 1); break; } - case InitStyle::kRandomLarge: { - cutlass::reference::device::BlockFillRandomGaussian( - block.get(), block.size(), seed, (Element) -1, (Element) 100); - break; - } case InitStyle::kLinearStride1: { std::vector data(block.size()); for (size_t i = 0; i < block.size() / 128; i ++) { diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index df637e66..f04ebe41 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -63,7 +63,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_COMMAND_OPTIONS TEST_BASIC # TEST_CAUSAL - TEST_VARLEN + # TEST_VARLEN # TEST_HDIM64 # TEST_GQA) ) @@ -119,7 +119,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha_bwd.cu TEST_COMMAND_OPTIONS TEST_BASIC - TEST_VARLEN + # TEST_GEN_VARLEN # TEST_GEN_HDIM64 # TEST_GEN_GQA # TEST_GEN_REMAP @@ -128,24 +128,20 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) 
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v) - endforeach() - # Add a target that builds all examples - add_custom_target(77_blackwell_fmha_all - DEPENDS - 77_blackwell_fmha_fp8 - 77_blackwell_fmha_fp16 - 77_blackwell_fmha_gen_fp8 - 77_blackwell_fmha_gen_fp16 - 77_blackwell_mla_2sm_fp8 - 77_blackwell_mla_2sm_fp16 - 77_blackwell_mla_2sm_cpasync_fp8 - 77_blackwell_mla_2sm_cpasync_fp16 - 77_blackwell_mla_b2b_2sm_fp8 - 77_blackwell_mla_b2b_2sm_fp16 - 77_blackwell_fmha_bwd_fp8 - 77_blackwell_fmha_bwd_fp16 - 77_blackwell_fmha_bwd_sat_fp8 - 77_blackwell_fmha_bwd_sat_fp16 - ) + cutlass_example_add_executable( + 77_blackwell_fmha_bwd_sat_${PREC} + 77_blackwell_fmha_bwd.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # TEST_GEN_VARLEN + TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC) + target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v) + endforeach() endif() diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 445cc3f2..f31c8024 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -132,58 +132,6 @@ struct ResidualMask : NoMask { } }; -struct ResidualMaskForBackward : NoMask { - - using Base = NoMask; - - template - CUTLASS_DEVICE int get_masked_trip_count( - BlkCoord const& blk_coord, - TileShape const& tile_shape, - ProblemSize const& problem_size) { - - if (get<1>(problem_size) % get<1>(tile_shape) != 0) { - return 1; - } - return 0; - } - - template - CUTLASS_DEVICE - int get_unmasked_trip_count( - BlkCoord const& blk_coord, - TileShape const& tile_shape, - ProblemSize const& problem_size) { - - // if the sequence length does not divide the tile size evenly - if (get<1>(problem_size) % get<1>(tile_shape) != 0) { - return get_trip_count(blk_coord, tile_shape, problem_size) - 1; - } - return get_trip_count(blk_coord, tile_shape, problem_size); - } - - template - CUTLASS_DEVICE - void apply_mask( - AccQK& acc_qk, - IndexQK const& index_qk, - ProblemSize const& problem_size) { - - // This is useful is seqlen_k % kBlockN != 0 since it masks - // the remaining elements out from softmax. - // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar - // issues as they are transparently taken care of by TMA and the - // epilogue, if it is instantiated with predication support. - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(acc_qk); i++) { - auto pos = index_qk(i); - if (! 
elem_less(pos, select<0,1>(problem_size))) { - acc_qk(i) = -INFINITY; - } - } - } -}; - struct CausalMask : NoMask { using Base = NoMask; @@ -209,7 +157,8 @@ struct CausalMask : NoMask { TileShape const& tile_shape, ProblemSize const& problem_size) { - return ceil_div(get<0>(tile_shape), get<1>(tile_shape)); + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } template @@ -248,57 +197,25 @@ struct CausalMask : NoMask { }; -struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { - - using Base = CausalMask; - - template - CUTLASS_DEVICE - void apply_mask( - AccQK& acc_qk, - IndexQK const& index_qk, - ProblemSize const& problem_size) { - - // There are two ways to do causal if N_Q != N_K - // (1) is to assume that the Q is at the beginning of the matrix - // - this is what we demonstrate here - // (2) is that it is at the end of the matrix - // - this is usually what we want for inference settings - // where we only compute the next row and use cache for the rest - // - if you'd like this, you only need to add an offset like so: - // get<0>(pos) + offset_q < get<1>(pos) - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(acc_qk); i++) { - auto pos = index_qk(i); - bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size); - if (masked) { - acc_qk(i) = -INFINITY; - } - } - } - -}; - struct VariableLength { int max_length; int* cumulative_length = nullptr; - int total_length = -1; CUTE_HOST_DEVICE operator int() const { return max_length; } }; -template struct is_variable_length_impl : std::false_type {}; -template<> struct is_variable_length_impl : std::true_type {}; -template constexpr bool is_variable_length_v = is_variable_length_impl>::value; +template struct is_variable_length : std::false_type {}; +template<> struct is_variable_length : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length::value; template CUTE_HOST_DEVICE constexpr auto apply_variable_length(Shape const& shape, Idx const& idx) { return transform_leaf(shape, [&](auto const& s) { - if constexpr (is_variable_length_v) { + if constexpr (is_variable_length_v>) { return s.cumulative_length[idx+1] - s.cumulative_length[idx]; } else { @@ -313,7 +230,7 @@ constexpr auto apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { auto new_shape = apply_variable_length(shape, idx); auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { - if constexpr (is_variable_length_v) { + if constexpr (is_variable_length_v>) { return cute::make_tuple(c, s.cumulative_length[idx]); } else { @@ -323,30 +240,6 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { return cute::make_tuple(new_shape, new_coord); } -template -CUTE_HOST_DEVICE -constexpr auto -apply_variable_length_offset(Shape const& shape, Coord const& coord) { - auto idx = back(back(coord)); - auto result_shape = transform_leaf(shape, [&](auto const& s) { - if constexpr (is_variable_length_v) { - return s.cumulative_length[idx+1] - s.cumulative_length[idx]; - } - else { - return s; - } - }); - auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { - if constexpr (is_variable_length_v) { - return s.cumulative_length[idx]; - } - else { - return _0{}; - } - }); - return cute::make_tuple(result_shape, result_offset); -} - } // namespace cutlass::fmha::collective namespace cute { diff --git 
a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp index 2740c6b8..82400801 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -42,7 +42,7 @@ template< class ElementAcc, class TileShape, // Q, D, _ class StrideO, // Q, D, B - class StrideLSE_ // Q, B + class StrideLSE // Q, B > struct Sm100FmhaFwdEpilogueTmaWarpspecialized { @@ -54,7 +54,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { // using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); using SmemLayoutO_ = SmemLayoutO; - using StrideLSE = StrideLSE_; struct TensorStorage { @@ -80,9 +79,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { struct Params { TMA_O tma_store_o; - - ElementAcc* ptr_LSE; - StrideLSE dLSE; }; template @@ -114,9 +110,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { ); return { - tma_store_o, - args.ptr_LSE, - args.dLSE + tma_store_o }; } @@ -125,10 +119,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); } - const Params& params; - - CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} - template CUTLASS_DEVICE auto store( diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 76cb3e11..1eaea0ce 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); - + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { order_s.arrive(); } @@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; - + row_sum = local_row_sum; if (final_call) { @@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32 / sizeof(ElementOut); - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOsO = mma.get_slice(0).partition_C(sO); - + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); @@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); - + Tensor 
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); @@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); - + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); - + #ifndef ONLY_SOFTMAX CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO); j += 2) { @@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 16; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; - + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); - + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); @@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { float2 scale_f32x2 = make_float2(scale, scale); Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); - + auto copy_in = [&](int i) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); @@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } } - template + template CUTLASS_DEVICE auto correction( BlkCoord const& blk_coord, @@ -951,8 +951,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, - PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, - CollectiveEpilogue& epilogue) { + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) { int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); @@ -962,7 +961,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); @@ -1061,25 +1060,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // F2FP // store to smem Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); - Tensor gLSE = 
make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE); - correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); - if (epilogue.params.ptr_LSE != nullptr) { - int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); - - ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); - - if (row_idx < get<0>(problem_shape)) { - gLSE(row_idx, get<2>(blk_coord)) = lse; - } - } - cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; - + pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; @@ -1096,16 +1083,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); - if (epilogue.params.ptr_LSE != nullptr) { - int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); - - ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); - - if (row_idx < get<0>(problem_shape)) { - gLSE(row_idx, get<2>(blk_coord)) = lse; - } - } - cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp index 3c8f7195..80fcdf9f 100644 --- a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -50,19 +50,13 @@ namespace cutlass::fmha::device { ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -template< - class ProblemShape, - class Element, - class ElementAccumulator, - class TileShape, - class Mask -> +template class Sm100FmhaBwd { public: /// Argument structure: User API struct Arguments { // Q K D HB - ProblemShape problem_shape; + cute::tuple> problem_size; const Element* ptr_Q; cute::tuple> stride_Q; @@ -92,16 +86,14 @@ public: }; using OperationSumOdO = cutlass::fmha::device::FMHA< - cutlass::fmha::kernel::FmhaKernelBwdSumOdO + cutlass::fmha::kernel::FmhaKernelBwdSumOdO >; using OperationConvert = cutlass::fmha::device::FMHA< - cutlass::fmha::kernel::FmhaKernelBwdConvert + cutlass::fmha::kernel::FmhaKernelBwdConvert >; using Operation = cutlass::fmha::device::FMHA< - cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< - ProblemShape, Element, ElementAccumulator, TileShape, Mask - > + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized >; using Kernel = typename Operation::Kernel; @@ -121,15 +113,15 @@ private: ElementAccumulator* sum_odo = nullptr, ElementAccumulator* scaled_lse = nullptr) { using namespace cute; - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q, K, D, HB] = args.problem_size; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); auto log2_e = log2f(expf(1.0f)); return typename OperationSumOdO::Arguments { - args.problem_shape, + args.problem_size, args.ptr_O, args.stride_O, args.ptr_dO, args.stride_dO, sum_odo, 
stride_sum_OdO, @@ -141,13 +133,13 @@ private: static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { using namespace cute; - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q, K, D, HB] = args.problem_size; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); return typename OperationConvert::Arguments { - args.problem_shape, + args.problem_size, src, stride_src_dQ, nullptr, stride_src_dQ, nullptr, stride_src_dQ, @@ -164,7 +156,7 @@ private: ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { return typename Operation::Arguments{ - args.problem_shape, + args.problem_size, { args.ptr_Q, args.stride_Q, args.ptr_K, args.stride_K, args.ptr_V, args.stride_V, @@ -207,10 +199,10 @@ public: /// Gets the workspace size static size_t get_workspace_size(Arguments const& args) { - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q, K, D, HB] = args.problem_size; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment size_t workspace_bytes = 0; // OdO vector workspace_bytes += B*H*Q * sizeof(ElementAccumulator); @@ -227,10 +219,10 @@ public: CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q, K, D, HB] = args.problem_size; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); @@ -256,10 +248,10 @@ public: CUTLASS_TRACE_HOST("Universal::initialize() - workspace " << workspace << ", stream: " << (stream ? 
"non-null" : "null")); - auto [Q_, K, D, HB] = args.problem_shape; + auto [Q, K, D, HB] = args.problem_size; auto [H, B] = HB; D = cutlass::round_up(D, 8); // Alignment - int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); workspace_chr += B*H*Q * sizeof(ElementAccumulator); diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp index c7f869f9..c2618bcb 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel { using namespace cute; -template +template struct FmhaKernelBwdConvert { struct Arguments { - ProblemShape problem_shape; + tuple> problem_size; const ElementAcc* ptr_src_dQ; tuple> stride_src_dQ; @@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert { static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; static bool can_implement(Arguments const& args) { - return get<2>(args.problem_shape) % kElementsPerLoad == 0; + return get<2>(args.problem_size) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { - dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq)); return grid; } @@ -102,25 +102,18 @@ struct FmhaKernelBwdConvert { return args; } - template - CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) { + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; - int seqlen = count; - if constexpr (is_variable_length_v) { - int offset = count.cumulative_length[blockIdx.y]; - ptr_dest_bh += offset * get<0>(stride_dest); - seqlen = count.cumulative_length[blockIdx.y + 1] - offset; - } - for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { int idx_s = idx_s_t + kBlockSeq * blockIdx.z; - if (idx_s >= seqlen) continue; + if (idx_s >= count) continue; auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); - for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { ElementAcc value_src[kElementsPerLoad]; Element value_dest[kElementsPerLoad]; @@ -139,13 +132,13 @@ struct FmhaKernelBwdConvert { CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { if (params.ptr_src_dQ != nullptr) { - copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape)); + copy(params, params.ptr_src_dQ, params.stride_src_dQ, 
params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size)); } if (params.ptr_src_dK != nullptr) { - copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape)); + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size)); } if (params.ptr_src_dV != nullptr) { - copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape)); + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size)); } } }; diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp index 98c127da..44080e2d 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -39,11 +39,11 @@ namespace cutlass::fmha::kernel { using namespace cute; -template +template struct FmhaKernelBwdSumOdO { struct Arguments { - ProblemShape problem_shape; + cute::tuple> problem_size; const Element* ptr_O; cute::tuple> stride_O; @@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO { static const int kIterationsQ = kBlockQ / kNumThreadsQ; static bool can_implement(Arguments const& args) { - return get<2>(args.problem_shape) % kElementsPerLoad == 0; + return get<2>(args.problem_size) % kElementsPerLoad == 0; } static dim3 get_grid_shape(Params const& params) { - dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape)); + dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size)); return grid; } @@ -110,20 +110,10 @@ struct FmhaKernelBwdSumOdO { auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); - auto problem_q = get<0>(params.problem_shape); - int seqlen_q = problem_q; - if constexpr (is_variable_length_v) { - int offset = problem_q.cumulative_length[blockIdx.z]; - ptr_O_bh += offset * get<0>(params.stride_O); - ptr_dO_bh += offset * get<0>(params.stride_dO); - ptr_lse_bh += offset * get<0>(params.stride_lse); - seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; - } - CUTLASS_PRAGMA_UNROLL for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { int idx_q = idx_q_t + kBlockQ * blockIdx.x; - if (idx_q >= seqlen_q) continue; + if (idx_q >= get<0>(params.problem_size)) continue; ElementAcc acc = 0; auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); @@ -131,7 +121,7 @@ struct FmhaKernelBwdSumOdO { auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); - for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { Element value_O[kElementsPerLoad]; Element value_dO[kElementsPerLoad]; diff --git a/examples/77_blackwell_fmha/kernel/fmha_options.hpp 
b/examples/77_blackwell_fmha/kernel/fmha_options.hpp index 387c9dac..d4faa8d2 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_options.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_options.hpp @@ -82,4 +82,4 @@ struct Option { using option_value = Value; }; -} // namespace cutlass::fmha::kernel \ No newline at end of file +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp b/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp index 35964cb6..119f069c 100644 --- a/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp +++ b/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp @@ -90,8 +90,8 @@ struct PersistentTileScheduler { struct Params { int num_blocks; FastDivmod divmod_m_block; - FastDivmod divmod_b; FastDivmod divmod_h; + FastDivmod divmod_b; KernelHardwareInfo hw_info; }; @@ -146,7 +146,7 @@ struct PersistentTileScheduler { params.divmod_m_block(block_decode, m_block, block_decode); params.divmod_b(block_decode, bidb, block_decode); params.divmod_h(block_decode, bidh, block_decode); - return make_coord(m_block, _0{}, make_coord(bidb, bidh)); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); } CUTLASS_DEVICE diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index efc3d6a4..e1bd43d5 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -43,8 +43,6 @@ #include "collective/fmha_common.hpp" -#include - namespace cutlass::fmha::kernel { using namespace cutlass::fmha::collective; @@ -52,7 +50,6 @@ using namespace cutlass::fmha::collective; using namespace cute; template< - class ProblemShape, class Element, class ElementAcc, class TileShape, @@ -121,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { using TensorStrideContiguousK = Stride>; using TensorStrideContiguousMN = Stride<_1, int, Stride>; - + // compute S using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, @@ -277,6 +274,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + using ProblemShape = Shape>; // Q K D (H B), eventuall D = (D_QK, D_VO) using TensorStride = TensorStrideContiguousK; // S D (H B) using RowTensorStride = Stride<_1, Stride>; // S (H B) @@ -362,16 +360,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { static Params to_underlying_arguments(Arguments const& args, void*) { - auto [Q_, K_, D, HB] = args.problem_shape; - int Q = Q_; - int K = K_; - - if constexpr (is_variable_length_v) { - Q = Q_.total_length; - } - if constexpr (is_variable_length_v) { - K = K_.total_length; - } + auto [Q, K, D, HB] = args.problem_shape; auto params_kq = CollectiveMmaKQ::to_underlying_arguments( make_shape(K, Q, D, HB), @@ -389,10 +378,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TMA_DQ tma_red_dq = make_tma_copy( SM90_TMA_REDUCE_ADD{}, - make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc), SmemLayoutDQ{}(_, _, _0{}) ); - + return Params{ args.problem_shape, args.mainloop, @@ 
-427,11 +416,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } - template + template CUTLASS_DEVICE void load( BlkCoord const& blk_coord, - BlkOffset const& blk_offset, - ProblemShape_ const& problem_shape, + ProblemShape const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -452,15 +440,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { uint16_t mcast_mask = 0; - auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); - auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); - auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); - auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); - - auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in); - auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in); - auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in); - auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in); + auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); + auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); @@ -469,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); - + auto tSTgK = cta_mma_kq.partition_A(gK); auto tSTgQ = cta_mma_kq.partition_B(gQ); auto tDPTgV = cta_mma_vdo.partition_A(gV); @@ -494,8 +477,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); // set up lse and sum_odo - - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); @@ -512,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } // load Q - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -532,14 +515,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); - for (int i = 0; i < 4; i++) { - cutlass::arch::cp_async_zfill<4>( - shared_tensors.smem_lse.begin() + smem_idx + i, - &mLSE(gmem_idx + i, blk_coord_batch), - gmem_idx + i < Q - ); - } - + cutlass::arch::cp_async_zfill<16>( + shared_tensors.smem_lse.begin() + smem_idx, + &mLSE(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; @@ -548,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); - + // load 
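The cp_async_zfill loads above bring LSE and sum-OdO rows into shared memory with out-of-bounds elements forced to zero rather than read from memory. A host-side stand-in for that zero-fill predication (illustrative only; buffer names are assumptions):

```cpp
// Zero-fill-if-out-of-bounds load semantics: elements past the valid sequence
// length are written as 0 instead of being fetched.
#include <cstddef>
#include <vector>

void load_zfill(const std::vector<float>& src, float* dst,
                std::size_t first, std::size_t count, std::size_t valid_len) {
  for (std::size_t i = 0; i < count; ++i) {
    std::size_t idx = first + i;
    dst[i] = (idx < valid_len) ? src[idx] : 0.0f;  // zfill out-of-range lanes
  }
}
```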
V if (cute::elect_one_sync()) { cute::copy( @@ -559,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } // load dO - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -575,13 +556,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); - for (int i = 0; i < 4; i++) { - cutlass::arch::cp_async_zfill<4>( - shared_tensors.smem_sum_odo.begin() + smem_idx + i, - &mSumOdO(gmem_idx + i, blk_coord_batch), - gmem_idx + i < Q - ); - } + cutlass::arch::cp_async<16>( + shared_tensors.smem_sum_odo.begin() + smem_idx, + &mSumOdO(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; @@ -594,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); // load Q - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -605,26 +584,24 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ++pipeline_load_mma_q_producer_state; pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); - + // load LSE smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; - for (int i = 0; i < 4; i++) { - cutlass::arch::cp_async_zfill<4>( - shared_tensors.smem_lse.begin() + smem_idx + i, - &mLSE(gmem_idx + i, blk_coord_batch), - gmem_idx + i < Q - ); - } - + cutlass::arch::cp_async<16>( + shared_tensors.smem_lse.begin() + smem_idx, + &mLSE(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); - // load dO - if (cute::elect_one_sync()) { + // load dO + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -635,18 +612,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ++pipeline_load_mma_do_producer_state; pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); - + // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; - for (int i = 0; i < 4; i++) { - cutlass::arch::cp_async_zfill<4>( - shared_tensors.smem_sum_odo.begin() + smem_idx + i, - &mSumOdO(gmem_idx + i, blk_coord_batch), - gmem_idx + i < Q - ); - } - + cutlass::arch::cp_async_zfill<16>( + shared_tensors.smem_sum_odo.begin() + smem_idx, + &mSumOdO(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, 
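The smem_idx / gmem_idx arithmetic in this load loop stages each TileShapeQ-sized chunk into a shared-memory slot chosen by the pipeline state's stage index while the global offset keeps advancing, i.e. a small circular buffer. A minimal model of that indexing, assuming hypothetical tile and stage counts:

```cpp
// Circular staging-buffer indexing: the shared-memory slot cycles with the
// pipeline stage, the global offset marches forward with the iteration.
#include <cstdio>

int main() {
  const int kTileQ = 128, kStages = 2;
  for (int iter = 0; iter < 5; ++iter) {
    int stage    = iter % kStages;     // pipeline_state.index()
    int smem_idx = kTileQ * stage;     // slot base inside shared memory
    int gmem_idx = kTileQ * iter;      // offset into global memory
    std::printf("iter %d -> smem slot %d, gmem offset %d\n", iter, smem_idx, gmem_idx);
  }
}
```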
cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; @@ -656,31 +631,31 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } - template + template CUTLASS_DEVICE void mma( BlkCoord const& blk_coord, - ProblemShape_ const& problem_shape, + ProblemShape const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, - PipelineLoadMmaQ& pipeline_load_mma_q, - typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, - PipelineLoadMmaDO& pipeline_load_mma_do, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, - PipelineMmaComputeS& pipeline_mma_compute_s, + PipelineMmaComputeS& pipeline_mma_compute_s, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, - PipelineMmaComputeDP& pipeline_mma_compute_dp, + PipelineMmaComputeDP& pipeline_mma_compute_dp, typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, - PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, - PipelineComputeMmaP& pipeline_compute_mma_p, + PipelineComputeMmaP& pipeline_compute_mma_p, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, - PipelineComputeMmaDS& pipeline_compute_mma_ds, + PipelineComputeMmaDS& pipeline_compute_mma_ds, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { - + auto [Q, K, D, HB] = problem_shape; auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); @@ -710,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); tDVrP.data() = TmemAllocation::kP; Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); - + TiledMmaKQ tiled_mma_kq; TiledMmaVDO tiled_mma_vdo; TiledMmaDSK tiled_mma_dsk; @@ -948,8 +923,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TensorC const& coord, TensorShape const& tensor_shape) { - Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); - auto copy_op = make_cotiled_copy( Copy_Atom, Element>{}, make_layout(make_shape(_1{}, Int{})), @@ -957,91 +930,42 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ); auto thr_copy = copy_op.get_slice(_0{}); - Tensor tCg = thr_copy.partition_D(gmem); - Tensor tCr = thr_copy.partition_S(quantize(regs)); - Tensor tPc = thr_copy.partition_D(preds); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); - copy_if(copy_op, tPc, tCr, tCg); - } + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); - - template - CUTLASS_DEVICE void epilogue_clear( - BlkCoord const& blk_coord, - BlkOffset const& blk_offset, - ProblemShape_ const& problem_shape, - MainloopArguments const& mainloop_args, - EpilogueArguments const& epilogue_args) { - - auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, 
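The store helper above (the elem_less predicate fed to copy_if) writes a register tile back to global memory only where the element's coordinate is still inside the tensor, so edge tiles do not write out of bounds. A scalar C++ sketch of that bounds-predicated store, with illustrative names:

```cpp
// Bounds-predicated tile store: a tile may hang over the edge of the tensor,
// so each element is written only if its global coordinate is in range.
#include <vector>

void store_tile(std::vector<float>& dst, int rows, int cols,        // full tensor
                const std::vector<float>& tile, int tile_m, int tile_n,
                int row0, int col0) {                                // tile origin
  for (int i = 0; i < tile_m; ++i) {
    for (int j = 0; j < tile_n; ++j) {
      int r = row0 + i, c = col0 + j;
      if (r < rows && c < cols) {                // predicate: inside the tensor?
        dst[r * cols + c] = tile[i * tile_n + j];
      }                                           // out-of-bounds elements skipped
    }
  }
}
```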
blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; - - auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); - auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in); - auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) - (_, _, blk_coord_k, _0{}, blk_coord_batch); - - Tensor cDK = domain_offset( - make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), - make_identity_tensor(take<0,2>(TileShapeDSQ{})) - ); - - auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); - auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in); - auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) - (_, _, blk_coord_k, _0{}, blk_coord_batch); - - Tensor cDV = domain_offset( - make_coord(blk_coord_k * TileShapeK{}, _0{}), - make_identity_tensor(take<0,2>(TileShapePDO{})) - ); - - if (threadIdx.x >= 256) { - return; + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); } - auto tiled_copy = make_cotiled_copy( - Copy_Atom, Element>{}, - make_ordered_layout(make_shape(_256{}, Int{}), Step<_1, _0>{}), - make_ordered_layout(make_shape(TileShapeK{}, TileShapeDQK{}), Step<_1, _0>{})); - - auto thr_copy = tiled_copy.get_slice(threadIdx.x); - auto tCgDK = thr_copy.partition_D(gDK); - auto tCcDK = thr_copy.partition_S(cDK); - auto tCrDK = make_tensor(shape(tCcDK)); - - clear(tCrDK); - store(tCgDK, tCrDK, tCcDK, select<1,2>(problem_shape)); - - auto tCgDV = thr_copy.partition_D(gDV); - auto tCcDV = thr_copy.partition_S(cDV); - auto tCrDV = make_tensor(shape(tCcDV)); - - clear(tCrDV); - store(tCgDV, tCrDV, tCcDV, select<1,2>(problem_shape)); + copy_if(copy_op, tCp_v, tCr_v, tCg_v); } - template + template CUTLASS_DEVICE void epilogue( BlkCoord const& blk_coord, - BlkOffset const& blk_offset, - ProblemShape_ const& problem_shape, + ProblemShape const& problem_shape, MainloopArguments const& mainloop_args, EpilogueArguments const& epilogue_args, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDKtDK.data() = TmemAllocation::kDK; - auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); - auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in); + auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); @@ -1076,13 +1000,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); tDVtDV.data() = TmemAllocation::kDV; - auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); - auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in); + auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), 
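The masking applied in the compute stage below translates a tile-local (q, k) position back to global coordinates before deciding whether the element is allowed under the causal constraint. A host-side sketch of that idea, with hypothetical tile origins (not the Mask type's actual interface):

```cpp
// Tile-level causal masking: map local (q, k) back to global positions and
// null out scores where the key lies in the future of the query.
#include <vector>

void apply_causal_mask(std::vector<float>& S_tile, int tile_q, int tile_k,
                       int q0, int k0 /* global tile origins */) {
  const float kNegInf = -1e30f;     // stand-in for -infinity before softmax
  for (int q = 0; q < tile_q; ++q)
    for (int k = 0; k < tile_k; ++k)
      if (k0 + k > q0 + q)          // key is ahead of this query
        S_tile[q * tile_k + k] = kNegInf;
}
```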
epilogue_args.stride_dv); auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, blk_coord_k, _0{}, blk_coord_batch); Tensor cDV = domain_offset( - make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); @@ -1126,11 +1049,10 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } - template + template CUTLASS_DEVICE void compute( BlkCoord const& blk_coord, - BlkOffset const& blk_offset, - ProblemShape_ const& problem_shape, + ProblemShape const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -1151,7 +1073,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { - + auto [Q, K, D, HB] = problem_shape; // in tmem, S & P overlap @@ -1192,7 +1114,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST)); Tensor tTR_rST = make_tensor(shape(tTR_cST)); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); - + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); Tensor tTR_cDPT = split_wg(tTR_cDPT_p); Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); @@ -1214,9 +1136,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST)); - bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); - int last_iter = iter_count - 1 + iter_index; - CUTLASS_PRAGMA_NO_UNROLL while (iter_count > 0) { // wait for S and P @@ -1233,28 +1152,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { fn(cute::false_type{}); } }; - - bool leading_causal_masking = false; - if constexpr (std::is_base_of_v) { - leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); - } - bool trailing_residual_masking = false; - if constexpr (std::is_base_of_v) { - trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); - } - - dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + dispatch_bool(std::is_base_of_v && + warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) { // compute P = softmax(S, LSE) cute::copy(tiled_t2r, tTR_tST, tTR_rST); - - if constexpr (decltype(is_masked_tile)::value) { + + if constexpr (std::is_base_of_v && decltype(is_causal_masked_tile)::value) { Mask{}.apply_mask(tTR_rST, [&](int i) { auto c_transpose = tTR_cST(i); return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); }, problem_shape); } - + ElementAcc log2_e = static_cast(M_LOG2E); float2 softmax_scale_log2_e; softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; @@ -1273,16 +1184,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tTR_rST(i) = ::exp2f(out.x); tTR_rST(i+1) = ::exp2f(out.y); } - + auto tRT_rST = quantize(tTR_rST); auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); - + cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::NamedBarrier( kNumComputeWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransformBarrier ).arrive_and_wait(); - + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); }); @@ -1364,15 +1275,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } epilogue( - blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + 
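The exp2f-based softmax recompute above relies on exp(x) = 2^(x·log2 e), so P = exp(scale·S - LSE) can be evaluated with the scale and LSE pre-multiplied by log2(e). A small check of that identity in plain C++ (values are arbitrary):

```cpp
// exp(scale*S - LSE) evaluated via exp2, as the backward compute stage does.
#include <cmath>
#include <cstdio>

int main() {
  const float kLog2E = 1.4426950408889634f;  // log2(e)
  float S = 3.7f, lse = 5.2f, scale = 0.125f;

  float p_ref  = std::exp(scale * S - lse);
  float p_exp2 = std::exp2(scale * kLog2E * S - kLog2E * lse);

  std::printf("%g %g\n", p_ref, p_exp2);     // identical up to rounding
}
```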
blk_coord, problem_shape, mainloop_args, epilogue_args, pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state ); } - template + template CUTLASS_DEVICE void reduce( BlkCoord const& blk_coord, - ProblemShape_ const& problem_shape, + ProblemShape const& problem_shape, int iter_index, int iter_count, MainloopArguments const& mainloop_args, @@ -1382,12 +1293,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, PipelineReduceTmaStore& pipeline_reduce_tma_store, typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { - + using X = Underscore; - + auto [Q, K, D, HB] = problem_shape; - auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord; + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; // must match TileShapeDQ auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; @@ -1396,7 +1307,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tDQtDQ.data() = TmemAllocation::kDQ; Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); - auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) (_, _, _, _0{}, blk_coord_batch); Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); @@ -1465,7 +1376,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_index += 1; } } - + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { int warp_idx = cutlass::canonical_warp_idx_sync(); @@ -1650,7 +1561,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; - + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); auto pipeline_load_mma_do_producer_state = make_producer_start_state(); auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); @@ -1665,45 +1576,27 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { pipeline_init_wait(size(ClusterShape{})); - auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); - auto [problem_shape, blk_offset] = apply_variable_length_offset( - params.problem_shape, - blk_coord - ); + auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z)); + auto problem_shape = params.problem_shape; int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); int iter_start = 0; if constexpr (std::is_base_of_v) { iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; } - if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { - return; - } iter_count -= iter_start; - if (iter_count <= 0) { - epilogue_clear( - blk_coord, - blk_offset, - problem_shape, - params.mainloop, - params.epilogue - ); - return; - } - if (role == WarpRole::Load) { warpgroup_reg_set(); - + load( blk_coord, - blk_offset, problem_shape, iter_start, iter_count, params.mainloop, params.mainloop_params, shared_storage.tensors, - pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, pipeline_load_mma_do, pipeline_load_mma_do_producer_state, pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state @@ 
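The reduce stage writes dQ through a TMA reduce-add (SM90_TMA_REDUCE_ADD above) rather than a plain store, because each CTA owns a single K-tile and dQ needs the sum of every K-tile's contribution. A host-side sketch of that accumulation:

```cpp
// dQ accumulation across K-tiles: partial contributions are added into the
// global buffer instead of overwriting it.
#include <vector>

void accumulate_dq(std::vector<float>& dQ_global,        // [Q*D], zero-initialized
                   const std::vector<float>& dQ_partial, // [Q*D] from one K-tile
                   int Q, int D) {
  for (int i = 0; i < Q * D; ++i) {
    dQ_global[i] += dQ_partial[i];   // reduce-add instead of plain store
  }
}
```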
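The variable-length handling removed by this revert (apply_variable_length_offset and the cumulative_length lookups) packs all sequences back to back and derives each batch's offset and length from a prefix-sum array. A minimal illustration of that bookkeeping, with hypothetical lengths:

```cpp
// Cumulative-length bookkeeping for packed variable-length batches.
#include <cstdio>
#include <vector>

struct SeqView { int offset; int seqlen; };

SeqView batch_view(const std::vector<int>& cumulative_length, int b) {
  SeqView v;
  v.offset = cumulative_length[b];
  v.seqlen = cumulative_length[b + 1] - cumulative_length[b];
  return v;
}

int main() {
  std::vector<int> cum = {0, 128, 384, 500};   // three sequences: 128, 256, 116
  SeqView v = batch_view(cum, 1);
  std::printf("offset=%d seqlen=%d\n", v.offset, v.seqlen);  // offset=128 seqlen=256
}
```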
-1715,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); __syncwarp(); - + mma( blk_coord, problem_shape, @@ -1723,7 +1616,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_count, params.mainloop, shared_storage.tensors, - pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, @@ -1736,10 +1629,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else if (role == WarpRole::Compute) { warpgroup_reg_set(); - + compute( blk_coord, - blk_offset, problem_shape, iter_start, iter_count, @@ -1768,7 +1660,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else if (role == WarpRole::Reduce) { warpgroup_reg_set(); - + reduce( blk_coord, problem_shape, @@ -1785,9 +1677,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else { warpgroup_reg_set(); - + /* no-op */ - + } } diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index e297e731..fbb8d362 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); CollectiveMainloop mainloop; - CollectiveEpilogue epilogue{params.epilogue}; + CollectiveEpilogue epilogue; if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { warpgroup_reg_set(); @@ -407,8 +407,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s1_corr, pipeline_s1_corr_consumer_state, pipeline_mma_corr, pipeline_mma_corr_consumer_state, - pipeline_corr_epi, pipeline_corr_epi_producer_state, - epilogue + pipeline_corr_epi, pipeline_corr_epi_producer_state ); diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp index 98f40ce8..c6a05750 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp @@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel { ElementAcc sum_lse = 0; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kNLsePerThread; ++i) { - sum_lse = sum_lse + expf(local_lse[i] - lse_max); + sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max); } CUTLASS_PRAGMA_UNROLL @@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel { sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); - ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
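The MLA reduction hunk here merges per-split LSE values in the usual numerically stable way: subtract the maximum before exponentiating, then add it back after the log. The hunk also folds a scale factor into that maximum, which the sketch below omits. A plain C++ version of the merge:

```cpp
// Stable merge of per-split log-sum-exp values:
// global_lse = log(sum_i exp(lse_i - m)) + m, with m = max_i lse_i.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

float merge_lse(const std::vector<float>& lse) {
  float m = -std::numeric_limits<float>::infinity();
  for (float v : lse) m = std::max(m, v);
  float sum = 0.0f;
  for (float v : lse) sum += std::exp(v - m);      // never overflows
  if (sum == 0.0f || sum != sum) {                 // empty or NaN input
    return std::numeric_limits<float>::infinity();
  }
  return std::log(sum) + m;
}
```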
std::numeric_limits::infinity() : logf(sum_lse) + params.scale * lse_max; if (threadIdx.x == 0 and params.ptr_lse != nullptr) { gLSE(0) = global_lse; } diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp index 96a4965b..bb8cfb34 100644 --- a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -44,10 +44,10 @@ template< class Fusion > void __global__ fmha_bwd_reference_dQ_kernel( - ProblemShape problem_shape_in, - TensorQ mQ_in, TensorK mK_in, TensorV mV_in, - TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, - TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */ + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ Fusion fusion) { using namespace cute; @@ -58,28 +58,15 @@ void __global__ fmha_bwd_reference_dQ_kernel( extern __shared__ char mS_mem[]; Element* mS = reinterpret_cast(mS_mem); - Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<2>(problem_shape_in))); + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { - auto [problem_shape, offset] = apply_variable_length_offset( - problem_shape_in, - make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) - ); - // problem_shape = problem_shape_in; - // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); - auto mK = domain_offset(select<1,2,3>(offset), mK_in); - auto mV = domain_offset(select<1,2,3>(offset), mV_in); - auto mO = domain_offset(select<0,2,3>(offset), mO_in); - auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); - auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); - auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in); - for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) { - for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { ElementAccumulator acc_qk = 0; ElementAccumulator acc_dov = 0; ElementAccumulator acc_doo = 0; - for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); @@ -96,9 +83,9 @@ void __global__ fmha_bwd_reference_dQ_kernel( __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { ElementAccumulator acc = 0; - for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); } mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); @@ -117,10 +104,10 @@ template< class Fusion > void __global__ fmha_bwd_reference_dK_kernel( - ProblemShape problem_shape_in, - TensorQ mQ_in, TensorK mK_in, TensorV mV_in, - TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, - /* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */ + 
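For reference, the standard attention-backward formulation that a dQ reference kernel of this shape computes is P_qk = exp(scale·(Q_q·K_k) - LSE_q), dS_qk = P_qk·scale·(dO_q·V_k - delta_q) with delta_q = dO_q·O_q, and dQ_q = Σ_k dS_qk·K_k. A compact host-side version (row-major buffers, illustrative signature):

```cpp
// Host reference for dQ in the attention backward pass (no masking).
#include <cmath>
#include <vector>

void reference_dQ(const std::vector<float>& Q, const std::vector<float>& K,
                  const std::vector<float>& V, const std::vector<float>& O,
                  const std::vector<float>& dO, const std::vector<float>& LSE,
                  std::vector<float>& dQ, int nQ, int nK, int D, float scale) {
  for (int q = 0; q < nQ; ++q) {
    float delta = 0.0f;
    for (int d = 0; d < D; ++d) delta += dO[q * D + d] * O[q * D + d];
    for (int d = 0; d < D; ++d) dQ[q * D + d] = 0.0f;
    for (int k = 0; k < nK; ++k) {
      float s = 0.0f, dov = 0.0f;
      for (int d = 0; d < D; ++d) {
        s   += Q[q * D + d] * K[k * D + d];
        dov += dO[q * D + d] * V[k * D + d];
      }
      float p  = std::exp(scale * s - LSE[q]);   // recomputed softmax probability
      float ds = p * scale * (dov - delta);      // dS_qk
      for (int d = 0; d < D; ++d) dQ[q * D + d] += ds * K[k * D + d];
    }
  }
}
```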
ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ Fusion fusion) { using namespace cute; @@ -131,28 +118,15 @@ void __global__ fmha_bwd_reference_dK_kernel( extern __shared__ char mS_mem[]; Element* mS = reinterpret_cast(mS_mem); - Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<2>(problem_shape_in))); + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { - auto [problem_shape, offset] = apply_variable_length_offset( - problem_shape_in, - make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in))) - ); - // problem_shape = problem_shape_in; - // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); - auto mK = domain_offset(select<1,2,3>(offset), mK_in); - auto mV = domain_offset(select<1,2,3>(offset), mV_in); - auto mO = domain_offset(select<0,2,3>(offset), mO_in); - auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); - auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); - auto mDK = domain_offset(select<1,2,3>(offset), mDK_in); - for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { - for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { + for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { ElementAccumulator acc_qk = 0; ElementAccumulator acc_dov = 0; ElementAccumulator acc_doo = 0; - for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); @@ -169,9 +143,9 @@ void __global__ fmha_bwd_reference_dK_kernel( __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { ElementAccumulator acc = 0; - for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); } mDK(idx_K, idx_D, idx_L) = static_cast(acc); @@ -190,10 +164,10 @@ template< class Fusion > void __global__ fmha_bwd_reference_dV_kernel( - ProblemShape problem_shape_in, - TensorQ mQ_in, TensorK mK_in, TensorV mV_in, - TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, - /* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in, + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, Fusion fusion) { using namespace cute; @@ -204,27 +178,14 @@ void __global__ fmha_bwd_reference_dV_kernel( extern __shared__ char mS_mem[]; Element* mS = reinterpret_cast(mS_mem); - Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<2>(problem_shape_in))); + ElementAcc softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); - for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { - auto [problem_shape, offset] = apply_variable_length_offset( - problem_shape_in, - make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, 
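The dK reference follows the same derivation with the roles of queries and keys swapped in the final reduction: with dS as in the dQ sketch above, dK_k = Σ_q dS_qk·Q_q. A short host-side sketch, taking dS as a dense matrix for clarity:

```cpp
// Host reference for dK given a precomputed dS matrix.
#include <vector>

void reference_dK(const std::vector<float>& dS,  // [nQ * nK]
                  const std::vector<float>& Q,   // [nQ * D]
                  std::vector<float>& dK,        // [nK * D], zero-initialized
                  int nQ, int nK, int D) {
  for (int k = 0; k < nK; ++k)
    for (int q = 0; q < nQ; ++q)
      for (int d = 0; d < D; ++d)
        dK[k * D + d] += dS[q * nK + k] * Q[q * D + d];
}
```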
get<3>(problem_shape_in))) - ); - // problem_shape = problem_shape_in; - // offset = repeat_like(problem_shape_in, _0{}); - auto mQ = domain_offset(select<0,2,3>(offset), mQ_in); - auto mK = domain_offset(select<1,2,3>(offset), mK_in); - auto mV = domain_offset(select<1,2,3>(offset), mV_in); - auto mO = domain_offset(select<0,2,3>(offset), mO_in); - auto mLSE = domain_offset(select<0,3>(offset), mLSE_in); - auto mDO = domain_offset(select<0,2,3>(offset), mDO_in); - auto mDV = domain_offset(select<1,2,3>(offset), mDV_in); - for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) { - for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) { + for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { ElementAcc acc_qk = 0; - for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L); ElementAcc rK = mK(idx_K, idx_D0, idx_L); acc_qk += rQ * rK; @@ -241,9 +202,9 @@ void __global__ fmha_bwd_reference_dV_kernel( __syncthreads(); - for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { ElementAcc acc = 0; - for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) { + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { ElementAcc rS = mS[idx_Q]; ElementAcc rDO = mDO(idx_Q, idx_D, idx_L); acc += rS * rDO; diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp index 68718c6b..b7c6b412 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel( mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(acc * scale); } - if (threadIdx.x == 0 && mLSE.data() != nullptr) { + if (threadIdx.x == 0) { mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; } diff --git a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp index a4d4b262..6d833ad1 100644 --- a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp +++ b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp @@ -75,8 +75,6 @@ struct DeviceAllocation { size_t size() const { return size_; } - size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } - void copy_from_host(const T* ptr, size_t sz) { auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); assert(ret == cudaSuccess);
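The forward reference at the end of this patch stores LSE as log(sum) + softmax_scale·maxS, which is the numerically stable evaluation of log Σ_k exp(softmax_scale·s_k). A small host check of that equivalence (arbitrary values):

```cpp
// Stable vs naive log-sum-exp of scaled attention scores.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> s = {1.0f, 4.0f, 2.5f};
  float scale = 0.35f;

  float maxS = s[0];
  for (float v : s) maxS = std::max(maxS, v);

  float sum = 0.0f;
  for (float v : s) sum += std::exp(scale * v - scale * maxS);
  float lse_stable = std::log(sum) + scale * maxS;   // what the reference writes

  float sum_naive = 0.0f;
  for (float v : s) sum_naive += std::exp(scale * v);
  std::printf("%g %g\n", lse_stable, std::log(sum_naive));  // equal up to rounding
}
```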