# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse import math from typing import Type, Tuple, Optional, Callable from types import SimpleNamespace from functools import partial import torch import torch.nn.functional as F import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute import cutlass.cute.testing as testing import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.cute.nvgpu.cpasync as cpasync import cutlass.utils as utils import cutlass.pipeline as pipeline import cutlass.torch as cutlass_torch import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cute.runtime import from_dlpack """ A Multi-Head Latent Attention (MLA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting functionality to minimize latency when processing long KV sequences. The kernel implements key optimizations including: - Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) - Pipeline stages between different warps for overlapping computation and memory access - Support for different precision data types - Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing To run this example: .. code-block:: bash python examples/blackwell/mla.py \ --batch_size 4 --latent_dim 512 --rope_dim 64 \ --num_heads 128 --seq_len 1024 \ --in_dtype Float8E4M3FN --out_dtype Float16 \ --acc_dtype Float32 --lse_dtype Float32 \ --use_page_table --is_var_seq --is_var_split_kv \ --is_persistent The above example runs Multi-Head Latent Attention (MLA) with the following configuration: - Batch size: 4 - Sequence length: 1024 - Latent dimension: 512 - RoPE dimension: 64 - Number of heads: 128 - Data types: Float8E4M3FN (input), Float16 (output), Float32 (accumulation and LSE) It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences and variable split KV processing with persistent scheduling. To collect performance with NCU profiler: .. code-block:: bash ncu python examples/blackwell/mla.py \ --batch_size 4 --latent_dim 512 --rope_dim 64 \ --num_heads 128 --seq_len 1024 \ --in_dtype Float8E4M3FN --out_dtype Float16 \ --acc_dtype Float32 --lse_dtype Float32 \ --use_page_table --is_var_seq --is_var_split_kv \ --is_persistent --warmup_iterations 3 \ --iterations 10 --skip_ref_check Constraints for this example: * Data type requirements: - Input/output: Float8E4M3FN or Float16 - Accumulation and LSE: Float32 * Fixed architecture parameters: - Number of attention heads: 128 - Latent dimension: 512 - RoPE dimension: 64 * Input query modes should be (NumHeads, LatentDim/RopeDim, BatchSize) * Input kv latent/rope modes should be (SeqLen, LatentDim/RopeDim, BatchSize) * Query sequence length must be 1 * Only supports 2-CTA instructions * Variable sequence length requires page table storage enabled """ class MLAStaticTileSchedulerParams: def __init__( self, is_persistent: bool, problem_shape_b: cute.Int32, cluster_shape_mnk: cute.Shape, split_kv: cutlass.Int32, *, loc=None, ip=None, ): """The static tile scheduler parameters prepared for MLA static tile scheduler. :param is_persistent: Whether to use persistent kernel mode :type is_persistent: bool :param problem_shape_b: The shape of the problem :type problem_shape_b: cute.Int32 :param cluster_shape_mnk: The shape of the cluster :type cluster_shape_mnk: cute.Shape :param split_kv: The scalar factor for split KV """ self.is_persistent = is_persistent self.problem_shape_b = problem_shape_b self.cluster_shape_mnk = cluster_shape_mnk self.split_kv = split_kv self.loc = loc self.ip = ip def __extract_mlir_values__(self): values = cutlass.extract_mlir_values(self.problem_shape_b) values += cutlass.extract_mlir_values(self.split_kv) return values def __new_from_mlir_values__(self, values): problem_shape_b = cutlass.new_from_mlir_values( self.problem_shape_b, (values[0],) ) split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[1],)) return MLAStaticTileSchedulerParams( self.is_persistent, problem_shape_b, self.cluster_shape_mnk, split_kv, loc=self.loc, ) def create_mla_static_tile_scheduler_params( is_persistent: bool, problem_shape_b: cute.Int32, cluster_shape_mnk: cute.Shape, split_kv: cutlass.Int32, ) -> MLAStaticTileSchedulerParams: return MLAStaticTileSchedulerParams( is_persistent, problem_shape_b, cluster_shape_mnk, split_kv ) class MLAStaticTileScheduler: def __init__( self, params: MLAStaticTileSchedulerParams, current_work_linear_idx: cutlass.Int32, blk_coord: cute.Coord, grid_shape: cute.Shape, *, is_valid: bool = True, loc=None, ip=None, ): """The static tile scheduler for MLA split kv kernel. Based on `is_persistent`, it provides 2 modes for use: - Persistent mode: Launch fixed blocks and reschedule the data blocks. - Non-persistent mode: Launch dynamic blocks and exit when the current work is done. :param params: The static tile scheduler parameters :type params: MLAStaticTileSchedulerParams :param current_work_linear_idx: The linear index of the current work :type current_work_linear_idx: cutlass.Int32 :param blk_coord: The coordinate of the current work :type blk_coord: cute.Coord :param grid_shape: The shape of the grid :type grid_shape: cute.Shape :param is_valid: Whether the current work is valid :type is_valid: bool """ self.params = params self.blk_coord = blk_coord self.grid_shape = grid_shape self.current_work_linear_idx = current_work_linear_idx if params.is_persistent: self.persistent_blk_layout = cute.make_layout( ( params.cluster_shape_mnk[0], 1, params.problem_shape_b, params.split_kv, ), loc=loc, ip=ip, ) self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip) # Used for persistent scheduling self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) else: self.is_valid = is_valid self.loc = loc self.ip = ip @staticmethod def get_grid_shape( params: MLAStaticTileSchedulerParams, max_active_clusters: int, *, loc=None, ip=None, ) -> cute.Shape: # called by host grid_shape = ( params.cluster_shape_mnk[0], params.problem_shape_b, params.split_kv, ) if params.is_persistent: return ( cutlass.min( max_active_clusters * cute.size(params.cluster_shape_mnk), cute.size(grid_shape, loc=loc, ip=ip), ), 1, 1, ) else: return grid_shape def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: is_valid = ( self.current_work_linear_idx < self.num_blocks if self.params.is_persistent else self.is_valid ) if self.params.is_persistent: blk_coord = self.persistent_blk_layout.get_hier_coord( self.current_work_linear_idx, loc=loc, ip=ip ) else: blk_coord = (self.blk_coord[0], 0, self.blk_coord[1], self.blk_coord[2]) return utils.WorkTileInfo(blk_coord, is_valid) def initial_work_tile_info(self, *, loc=None, ip=None): return self.get_current_work(loc=loc, ip=ip) def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): if self.params.is_persistent: self.current_work_linear_idx += advance_count * self.num_persistent_sm else: self.is_valid = False def __extract_mlir_values__(self): values = cutlass.extract_mlir_values(self.params) values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx)) values.extend(cutlass.extract_mlir_values(self.blk_coord)) values.extend(cutlass.extract_mlir_values(self.grid_shape)) return values def __new_from_mlir_values__(self, values): assert len(values) == 9 new_params = cutlass.new_from_mlir_values(self.params, values[0:2]) new_current_work_linear_idx = cutlass.new_from_mlir_values( self.current_work_linear_idx, [values[2]] ) new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[3:6]) new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[6:]) return MLAStaticTileScheduler( new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape ) def create_mla_static_tile_scheduler( params: MLAStaticTileSchedulerParams, blk_coord: cute.Coord, grid_shape: cute.Shape, ) -> MLAStaticTileScheduler: return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) LOG2_E = 1.4426950408889634074 # avoid register indexing on array. MAX_SPLITS = 256 class BlackwellMultiHeadLatentAttentionForward: def __init__( self, acc_dtype: Type[cutlass.Numeric], lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], max_active_clusters: int, is_persistent: bool, is_cpasync: bool, use_page_table: bool, is_var_seq: bool, is_var_split_kv: bool, ): """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. :param acc_dtype: Data type for accumulation S and O :type acc_dtype: Type[cutlass.Numeric] :param lse_dtype: Data type for output LSE :type lse_dtype: Type[cutlass.Numeric] :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S :type mma_s_tiler: Tuple[int, int] :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P :type mma_p_tiler: Tuple[int, int] :param max_active_clusters: Maximum number of active clusters :type max_active_clusters: int :param is_persistent: Whether to use persistent kernel mode :type is_persistent: bool :param is_cpasync: Whether to use CP async mode :type is_cpasync: bool :param use_page_table: Whether to use page table :type use_page_table: bool :param is_var_seq: Whether to use variable sequence length :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split KV :type is_var_split_kv: bool """ self.latent_dim = 512 self.rope_dim = 64 self.acc_dtype = acc_dtype self.lse_dtype = lse_dtype self.mma_qk_tiler_mn = mma_qk_tiler_mn self.mma_pv_tiler_mn = mma_pv_tiler_mn self.max_active_clusters = max_active_clusters self.is_persistent = is_persistent self.is_cpasync = is_cpasync self.use_page_table = use_page_table self.is_var_seq = is_var_seq self.is_var_split_kv = is_var_split_kv self.cluster_shape_mnk = (2, 1, 1) self.use_2cta_instrs = True # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), # while warps 2-3 handle accumulation for second half [n/2, n) self.warps_in_n = 2 self.num_compute_warps = 4 self.threads_per_warp = 32 self.num_load_warps = 2 if self.is_cpasync else 1 mma_qk_tiler_k = self.rope_dim if self.is_cpasync else self.rope_dim * 2 self.mma_qk_tiler = ( self.mma_qk_tiler_mn[0], self.mma_qk_tiler_mn[1], mma_qk_tiler_k, ) self.mma_pv_tiler = ( self.mma_pv_tiler_mn[0], self.mma_pv_tiler_mn[1], self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], ) self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] self.iterations_qk_rope = mma_qk_tiler_k // self.mma_qk_tiler[2] self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] # Set specialized warp ids self.compute_warp_ids = (0, 1, 2, 3) self.correction_warp_ids = (4, 5, 6, 7) self.mma_warp_id = 8 if self.is_cpasync: self.load_cp_async_warp_ids = (9, 10) self.load_pt_warp_id = 11 self.threads_per_cta = self.threads_per_warp * len( ( self.mma_warp_id, *self.load_cp_async_warp_ids, self.load_pt_warp_id, *self.compute_warp_ids, *self.correction_warp_ids, ) ) else: self.load_tma_warp_id = 9 self.empty_warp_ids = (10, 11) self.threads_per_cta = self.threads_per_warp * len( ( self.mma_warp_id, self.load_tma_warp_id, *self.compute_warp_ids, *self.correction_warp_ids, *self.empty_warp_ids, ) ) # register settings self.softmax_reg_num = 192 self.correction_reg_num = 192 self.other_reg_num = 112 # Named barriers self.tmem_ptr_sync_bar = pipeline.NamedBarrier( barrier_id=1, num_threads=( self.threads_per_warp + self.threads_per_warp * self.num_compute_warps * 2 ), ) self.softmax_exchange_sync_bar = pipeline.NamedBarrier( barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) ) self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) ) def _setup_attributes(self): """Set up configurations and parameters for the MLA kernel operation. This method initializes and configures various attributes required for the execution of the multi-head latent attention kernel, mainly about the pipeline stages: - Sets up staging parameters for Q, K, V inputs and accumulator data - Configures pipeline stages for softmax, correction, and epilogue operations """ self.load_q_stage = self.iterations_qk self.load_kv_stage = (24 if self.is_cpasync else 12) // ( self.k_dtype.width // 8 ) self.mma_s_stage = 2 self.p_mma_stage = 2 self.p_cor_stage = 2 self.mma_o_stage = 1 self.load_pt_stage = self.load_kv_stage if self.is_cpasync else 1 self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n self.correction_factor_offset = ( self.tmem_o_offset + self.latent_dim // self.warps_in_n ) @cute.jit def __call__( self, q_latent: cute.Tensor, q_rope: cute.Tensor, c_latent: cute.Tensor, c_rope: cute.Tensor, page_table: cute.Tensor, o: cute.Tensor, lse: cute.Tensor, workspace: cute.Tensor, split_kv: cutlass.Int32, cache_seqs: Optional[cute.Tensor], block_split_kvs: Optional[cute.Tensor], softmax_scale: cutlass.Float32, output_scale: cutlass.Float32, stream: cuda.CUstream, ): """Execute the Multi-Head Latent Attention operation on the provided tensors. The method handles: 1. Initialization of workspace for temporary split KV buffers 2. Validation of tensor data types 3. Initialization of hardware-specific parameters and memory layouts 4. Configuration of TMA (Tensor Memory Access) operations 5. Grid and work scheduling computation 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters :param q_latent: The query tensor with shape [num_head, latent_dim, batch_size] :type q_latent: cute.Tensor :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, batch_size] :type q_rope: cute.Tensor :param c_latent: The key tensor with shape [seq_len, latent_dim, batch_size] :type c_latent: cute.Tensor :param c_rope: The key RoPE tensor with shape [seq_len, rope_dim, batch_size] :type c_rope: cute.Tensor :param page_table: The page table tensor with shape [page_count, batch_size] :type page_table: cute.Tensor :param o: The output tensor with shape [num_head, latent_dim, batch_size] :type o: cute.Tensor :param lse: The LSE tensor with shape [num_head, batch_size] :type lse: cute.Tensor :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse :type workspace: cute.Tensor :param split_kv: The scalar factor for split KV :type split_kv: cutlass.Int32 :param cache_seqs: The cache sequences tensor with shape [batch_size] :type cache_seqs: cute.Tensor :param block_split_kvs: The block split KV tensor with shape [batch_size] :type block_split_kvs: cute.Tensor :param softmax_scale: The scale factor for softmax :type softmax_scale: cutlass.Float32 :param output_scale: The scale factor for the output :type output_scale: cutlass.Float32 :param stream: The CUDA stream to execute the kernel on :type stream: cuda.CUstream :raises TypeError: If tensor data types don't match or aren't supported """ # setup static attributes before smem/grid/tma computation self.q_dtype = q_latent.element_type self.k_dtype = c_latent.element_type self.v_dtype = c_latent.element_type self.o_dtype = o.element_type # check type consistency if cutlass.const_expr( self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype ): raise TypeError( f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" ) # check leading dimensions of input/output if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1): raise ValueError("q_latent and q_rope must have leading dimension 1") if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1): raise ValueError("c_latent and c_rope must have leading dimension 1") if cutlass.const_expr(o.stride[1] != 1): raise ValueError("o must have leading dimension 1") if cutlass.const_expr(lse.stride[0] != 1): raise ValueError("lse must have leading dimension 0") acc_o, acc_lse = self.initialize_workspace( q_latent.shape[0], q_latent.shape[1], q_latent.shape[2], split_kv, self.acc_dtype, workspace, ) c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) c_latent_transpose = cute.make_tensor( c_latent.iterator, c_latent_tranpose_layout ) self.q_major_mode = tcgen05.OperandMajorMode.K self.k_major_mode = tcgen05.OperandMajorMode.K self.v_major_mode = tcgen05.OperandMajorMode.MN self._setup_attributes() cta_group = tcgen05.CtaGroup.TWO # the intermediate tensor p is from smem & k-major p_major_mode = tcgen05.OperandMajorMode.K qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.q_dtype, self.q_major_mode, self.k_major_mode, self.acc_dtype, cta_group, self.mma_qk_tiler[:2], ) pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.v_dtype, p_major_mode, self.v_major_mode, self.acc_dtype, cta_group, self.mma_pv_tiler[:2], ) cta_layout_vmnk = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), (qk_tiled_mma.thr_id.shape,), ) self.epi_tile = self.mma_pv_tiler[:2] q_smem_layout_staged = sm100_utils.make_smem_layout_a( qk_tiled_mma, self.mma_qk_tiler, self.q_dtype, self.load_q_stage, ) kc_smem_layout_staged = sm100_utils.make_smem_layout_b( qk_tiled_mma, self.mma_qk_tiler, self.k_dtype, self.load_kv_stage, ) p_smem_layout_staged = sm100_utils.make_smem_layout_a( pv_tiled_mma, self.mma_pv_tiler, self.q_dtype, (self.iterations_pv_k * self.p_mma_stage), ) p_smem_layout_staged = cute.logical_divide( p_smem_layout_staged, (None, None, None, self.iterations_pv_k) ) vc_smem_layout_staged = sm100_utils.make_smem_layout_b( pv_tiled_mma, self.mma_pv_tiler, self.v_dtype, self.load_kv_stage, ) if cutlass.const_expr(not self.is_cpasync): # TMA load for Q latent and rope tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, q_latent, q_smem_layout, self.mma_qk_tiler, qk_tiled_mma, cta_layout_vmnk.shape, ) tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, q_rope, q_smem_layout, self.mma_qk_tiler, qk_tiled_mma, cta_layout_vmnk.shape, ) # TMA load for c latent and k rope kc_smem_layout = cute.select(kc_smem_layout_staged, mode=[0, 1, 2]) tma_atom_c_latent, tma_tensor_c_latent = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, c_latent, kc_smem_layout, self.mma_qk_tiler, qk_tiled_mma, cta_layout_vmnk.shape, ) tma_atom_c_rope, tma_tensor_c_rope = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, c_rope, kc_smem_layout, self.mma_qk_tiler, qk_tiled_mma, cta_layout_vmnk.shape, ) # TMA load for c latent transpose vc_smem_layout = cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, c_latent_transpose, vc_smem_layout, self.mma_pv_tiler, pv_tiled_mma, cta_layout_vmnk.shape, ) ) q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) * cute.size( qk_tiled_mma.thr_id.shape ) kc_copy_size = cute.size_in_bytes(self.k_dtype, kc_smem_layout) * cute.size( qk_tiled_mma.thr_id.shape ) vc_copy_size = cute.size_in_bytes(self.v_dtype, vc_smem_layout) * cute.size( pv_tiled_mma.thr_id.shape ) assert kc_copy_size == vc_copy_size, ( "kc_copy_size and vc_copy_size must be the same" ) self.tma_copy_q_bytes = q_copy_size self.tma_copy_kc_bytes = kc_copy_size else: self.tma_copy_q_bytes = 0 self.tma_copy_kc_bytes = 0 tile_sched_params, grid = self._compute_grid( o, split_kv, self.cluster_shape_mnk, self.max_active_clusters, self.is_persistent, ) @cute.struct class SplitKVKernelSharedStorage: # Pipeline barriers load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] load_kv_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.load_kv_stage * 2 ] mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] load_pt_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.load_pt_stage * 2 ] # Smem tensors softmax_smem_exchange: cute.struct.MemRange[ self.acc_dtype, self.num_compute_warps * self.threads_per_warp ] epilogue_smem_exchange: cute.struct.MemRange[ self.acc_dtype, self.num_compute_warps * self.threads_per_warp ] smem_page_table: cute.struct.MemRange[ cutlass.Int32, self.load_pt_stage * self.mma_qk_tiler[1] ] smem_q: cute.struct.Align[ cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], 1024, ] smem_kc: cute.struct.Align[ cute.struct.MemRange[self.k_dtype, cute.cosize(kc_smem_layout_staged)], 1024, ] smem_p: cute.struct.Align[ cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], 1024, ] # Tmem dealloc cluster barrier tmem_dealloc_mbar_ptr: cutlass.Int64 # Tmem holding buffer tmem_holding_buf: cutlass.Int32 softmax_scale_log2 = softmax_scale * LOG2_E # Launch the kernel synchronously if cutlass.const_expr(self.is_cpasync): self.split_kv_kernel( qk_tiled_mma, pv_tiled_mma, None, q_latent, None, q_rope, None, c_latent, None, c_rope, None, c_latent_transpose, page_table, o, lse, acc_o, acc_lse, split_kv, cache_seqs, block_split_kvs, softmax_scale_log2, output_scale, q_smem_layout_staged, kc_smem_layout_staged, p_smem_layout_staged, vc_smem_layout_staged, cta_layout_vmnk, tile_sched_params, SplitKVKernelSharedStorage, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, smem=SplitKVKernelSharedStorage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) else: self.split_kv_kernel( qk_tiled_mma, pv_tiled_mma, tma_atom_q_latent, tma_tensor_q_latent, tma_atom_q_rope, tma_tensor_q_rope, tma_atom_c_latent, tma_tensor_c_latent, tma_atom_c_rope, tma_tensor_c_rope, tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose, page_table, o, lse, acc_o, acc_lse, split_kv, cache_seqs, block_split_kvs, softmax_scale_log2, output_scale, q_smem_layout_staged, kc_smem_layout_staged, p_smem_layout_staged, vc_smem_layout_staged, cta_layout_vmnk, tile_sched_params, SplitKVKernelSharedStorage, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, smem=SplitKVKernelSharedStorage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) if cutlass.const_expr(acc_o is not None): self.reduction_kernel( o, lse, acc_o, acc_lse, split_kv, cache_seqs, block_split_kvs, ).launch( grid=(q_latent.shape[0], 1, q_latent.shape[2]), block=[self.threads_per_warp * self.num_compute_warps, 1, 1], smem=MAX_SPLITS * self.acc_dtype.width // 8, stream=stream, min_blocks_per_mp=1, ) @cute.kernel def split_kv_kernel( self, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tma_atom_q_latent: Optional[cute.CopyAtom], mQL: cute.Tensor, tma_atom_q_rope: Optional[cute.CopyAtom], mQR: cute.Tensor, tma_atom_c_latent: Optional[cute.CopyAtom], mCL: cute.Tensor, tma_atom_c_rope: Optional[cute.CopyAtom], mKR: cute.Tensor, tma_atom_c_latent_transpose: Optional[cute.CopyAtom], mCLT: cute.Tensor, mPT: cute.Tensor, mO: Optional[cute.Tensor], mLSE: Optional[cute.Tensor], mAccO: Optional[cute.Tensor], mAccLSE: Optional[cute.Tensor], split_kv: cutlass.Int32, cache_seqs: cute.Tensor, block_split_kvs: cute.Tensor, softmax_scale_log2: cutlass.Float32, output_scale: cutlass.Float32, q_smem_layout_staged: cute.ComposedLayout, kc_smem_layout_staged: cute.ComposedLayout, p_smem_layout_staged: cute.ComposedLayout, vc_smem_layout_staged: cute.ComposedLayout, cta_layout_vmnk: cute.Layout, tile_sched_params: MLAStaticTileSchedulerParams, SharedStorage: cutlass.Constexpr, ): """The device split_kv kernel implementation of the Multi-Head Latent Attention. This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results to global memory The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results that will later be combined by a reduction kernel. The kernel implements a complex pipeline with overlapping computation and memory operations, using tensor memory access (TMA) for efficient data loading, warp specialization for different computation phases. :param tiled_mma_qk: Tiled MMA for Q*K^T :type tiled_mma_qk: cute.TiledMma :param tiled_mma_pv: Tiled MMA for P*V :type tiled_mma_pv: cute.TiledMma :param tma_atom_q_latent: TMA copy atom for query latent tensor :type tma_atom_q_latent: cute.CopyAtom :param mQL: query latent tensor :type mQL: cute.Tensor :param tma_atom_q_rope: TMA copy atom for query rope tensor :type tma_atom_q_rope: cute.CopyAtom :param mKR: Compressed rope tensor :type mKR: cute.Tensor :param tma_atom_c_latent: TMA copy atom for c latent tensor :type tma_atom_c_latent: cute.CopyAtom :param mCL: Compressed latent tensor :type mCL: cute.Tensor :param tma_atom_c_rope: TMA copy atom for c rope tensor :type tma_atom_c_rope: cute.CopyAtom :param mCLT: Compressed latent transpose tensor :type mCLT: cute.Tensor :param mPT: Page table tensor :type mPT: cute.Tensor :param mO: Output tensor :type mO: cute.Tensor :param mLSE: Log-sum-exp tensor :type mLSE: cute.Tensor :param mAccO: Intermediate accumulator output tensor :type mAccO: cute.Tensor :param mAccLSE: Intermediate accumulator log-sum-exp tensor :type mAccLSE: cute.Tensor :param split_kv: The split_kv parameter :type split_kv: cutlass.Int32 :param cache_seqs: The variable sequence length tensor :type cache_seqs: cute.Tensor :param block_split_kvs: The per-block split_kv values tensor :type block_split_kvs: cute.Tensor :param softmax_scale_log2: The log2 scale factor for softmax :type softmax_scale_log2: cutlass.Float32 :param output_scale: The scale factor for the output :type output_scale: cutlass.Float32 :param q_smem_layout_staged: Shared memory layout for query tensor :type q_smem_layout_staged: cute.ComposedLayout :param kc_smem_layout_staged: Shared memory layout for key tensor :type kc_smem_layout_staged: cute.ComposedLayout :param p_smem_layout_staged: Shared memory layout for probability matrix :type p_smem_layout_staged: cute.ComposedLayout :param vc_smem_layout_staged: Shared memory layout for value tensor :type vc_smem_layout_staged: cute.ComposedLayout :param cta_layout_vmnk: Layout for compute threads :type cta_layout_vmnk: cute.Layout :param tile_sched_params: Scheduling parameters for work distribution :type tile_sched_params: MLAStaticTileSchedulerParams :param SharedStorage: Shared storage for the kernel :type SharedStorage: cutlass.Constexpr """ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 # Coords inside cluster cta_rank_in_cluster = cute.arch.make_warp_uniform( cute.arch.block_idx_in_cluster() ) # Prefetch tma descriptor if cutlass.const_expr(not self.is_cpasync): if warp_idx == self.mma_warp_id: cpasync.prefetch_descriptor(tma_atom_q_latent) cpasync.prefetch_descriptor(tma_atom_q_rope) cpasync.prefetch_descriptor(tma_atom_c_latent) cpasync.prefetch_descriptor(tma_atom_c_rope) cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) # Alloc smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr tmem_holding_buf = storage.tmem_holding_buf # Tensor memory dealloc barrier init if warp_idx == self.mma_warp_id: num_tmem_dealloc_threads = self.threads_per_warp * self.num_compute_warps with cute.arch.elect_one(): cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads) cute.arch.mbarrier_init_fence() load_q_pipeline = self.make_and_init_load_qkv_pipeline( storage.load_q_mbar_ptr.data_ptr(), cta_layout_vmnk, self.load_q_stage, self.tma_copy_q_bytes, self.is_cpasync, ) load_kv_pipeline = self.make_and_init_load_qkv_pipeline( storage.load_kv_mbar_ptr.data_ptr(), cta_layout_vmnk, self.load_kv_stage, self.tma_copy_kc_bytes, self.is_cpasync, ) mma_s_pipeline = self.make_and_init_mma_s_pipeline( storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk ) p_mma_pipeline = self.make_and_init_p_mma_pipeline( storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk ) p_cor_pipeline = self.make_and_init_p_cor_pipeline( storage.p_cor_mbar_ptr.data_ptr() ) mma_o_pipeline = self.make_and_init_mma_o_pipeline( storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk ) if cutlass.const_expr(self.is_cpasync): load_pt_pipeline = self.make_and_init_load_pt_pipeline( storage.load_pt_mbar_ptr.data_ptr() ) # Cluster arrive after barrier init if cutlass.const_expr(cute.size(self.cluster_shape_mnk) > 1): cute.arch.cluster_arrive_relaxed() # Generate smem tensor Q/KC/VC/exchange # (MMA, MMA_H, MMA_R, PIPE) sQ = storage.smem_q.get_tensor( q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner ) # (MMA, MMA_K, MMA_R, PIPE) sKC = storage.smem_kc.get_tensor( kc_smem_layout_staged.outer, swizzle=kc_smem_layout_staged.inner ) # (MMA, MMA_D, MMA_K, PIPE) # reuse smem sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner) sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer) # (MMA, MMA_H, MMA_K) sP = storage.smem_p.get_tensor( p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner ) # (compute_threads,) softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( cute.make_layout(self.num_compute_warps * self.threads_per_warp) ) epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( cute.make_layout(self.num_compute_warps * self.threads_per_warp) ) # # Cluster wait before tensor memory alloc # if cutlass.const_expr(cute.size(self.cluster_shape_mnk) > 1): cute.arch.cluster_wait() else: pipeline.sync(barrier_id=4) # /////////////////////////////////////////////////////////////////////////////// # Load warps, including page table and data tensors # /////////////////////////////////////////////////////////////////////////////// if cutlass.const_expr(self.is_cpasync): sPT = storage.smem_page_table.get_tensor( cute.make_layout((self.mma_qk_tiler[1], self.load_pt_stage)) ) # Load page table when isasync is true if warp_idx == self.load_pt_warp_id: cute.arch.warpgroup_reg_dealloc(self.other_reg_num) load_pt_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_pt_stage ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord, ) if k_tile_count > 0: load_pt_common_params = SimpleNamespace( blk_coord=blk_coord, load_pt_pipeline=load_pt_pipeline, mPT=mPT, sPT=sPT, tidx=tidx, page_size=mCL.shape[0], ) load_pt_producer_state = self.load_page_table( load_pt_common_params, k_index, k_tile_count, load_pt_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() load_pt_pipeline.producer_tail(load_pt_producer_state) if ( warp_idx == self.load_cp_async_warp_ids[0] or warp_idx == self.load_cp_async_warp_ids[1] ): cute.arch.warpgroup_reg_dealloc(self.other_reg_num) load_pt_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_pt_stage ) load_pt_release_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_pt_stage ) load_q_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_q_stage ) load_kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_kv_stage ) load_kv_commit_state = load_kv_producer_state.clone() tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord, ) if k_tile_count > 0: load_cpasync_common_params = SimpleNamespace( blk_coord=blk_coord, load_pt_pipeline=load_pt_pipeline, load_q_pipeline=load_q_pipeline, load_kv_pipeline=load_kv_pipeline, sPT=sPT, tidx=tidx, page_size=mCL.shape[0], ) load_cpasync_qk_params = SimpleNamespace( tiled_mma_qk=tiled_mma_qk, mQL=mQL, mQR=mQR, mCL=mCL, mKR=mKR, sQ=sQ, sKC=sKC, ) load_cpasync_v_params = SimpleNamespace( tiled_mma_pv=tiled_mma_pv, mCLT=mCLT, sVC=sVC, ) ( load_pt_consumer_state, load_pt_release_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, ) = self.load_cpasync( load_cpasync_common_params, load_cpasync_qk_params, load_cpasync_v_params, k_index, k_tile_count, load_pt_consumer_state, load_pt_release_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() load_q_pipeline.producer_tail(load_q_producer_state) load_kv_pipeline.producer_tail(load_kv_producer_state) else: if ( warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1] ): cute.arch.warpgroup_reg_dealloc(self.other_reg_num) if warp_idx == self.load_tma_warp_id: cute.arch.warpgroup_reg_dealloc(self.other_reg_num) load_q_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_q_stage ) load_kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_kv_stage ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord, ) if k_tile_count > 0: # Construct fixed common/tma_qk/tma_pv params for load_tma tma_common_params = SimpleNamespace( blk_coord=blk_coord, local_split_kv=local_split_kv, load_q_pipeline=load_q_pipeline, load_kv_pipeline=load_kv_pipeline, mPT=mPT, ) tma_qk_params = SimpleNamespace( tiled_mma_qk=tiled_mma_qk, tma_atom_q_latent=tma_atom_q_latent, tma_atom_q_rope=tma_atom_q_rope, tma_atom_c_latent=tma_atom_c_latent, tma_atom_c_rope=tma_atom_c_rope, mQL=mQL, mQR=mQR, mCL=mCL, mKR=mKR, sQ=sQ, sKC=sKC, ) tma_pv_params = SimpleNamespace( tiled_mma_pv=tiled_mma_pv, tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, mCL=mCL, mKR=mKR, mCLT=mCLT, sVC=sVC, ) # Load tma load_q_producer_state, load_kv_producer_state = self.load_tma( tma_common_params, tma_qk_params, tma_pv_params, k_index, k_tile_count, load_q_producer_state, load_kv_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() load_q_pipeline.producer_tail(load_q_producer_state) load_kv_pipeline.producer_tail(load_kv_producer_state) # /////////////////////////////////////////////////////////////////////////////// # MMA warp # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.other_reg_num) # Alloc tensor memory buffer cute.arch.alloc_tmem( cute.arch.SM100_TMEM_CAPACITY_COLUMNS, tmem_holding_buf, is_two_cta=self.use_2cta_instrs, ) # sync with compute warp before tmem ptr is retrieved self.tmem_ptr_sync_bar.arrive() # Retrieving tensor memory ptr and make accumulator tensor tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf, ) load_q_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_q_stage ) load_kv_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_kv_stage ) mma_s_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_s_stage ) p_mma_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.p_mma_stage ) mma_o_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_o_stage ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord ) if k_tile_count > 0: mma_common_params = SimpleNamespace( blk_coord=blk_coord, local_split_kv=local_split_kv, load_q_pipeline=load_q_pipeline, load_kv_pipeline=load_kv_pipeline, tmem_ptr=tmem_ptr, is_leader_cta=is_leader_cta, L=mCL.shape[1], ) mma_qk_params = SimpleNamespace( mma_s_pipeline=mma_s_pipeline, sQ=sQ, sKC=sKC, ) mma_pv_params = SimpleNamespace( p_mma_pipeline=p_mma_pipeline, mma_o_pipeline=mma_o_pipeline, sP=sP, sVC=sVC, ) ( tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) = self.mma( mma_common_params, mma_qk_params, mma_pv_params, k_tile_count, tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() mma_s_pipeline.producer_tail(mma_s_producer_state) mma_o_pipeline.producer_tail(mma_o_producer_state) cute.arch.relinquish_tmem_alloc_permit(is_two_cta=self.use_2cta_instrs) # Dealloc the tensor memory buffer cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) cute.arch.dealloc_tmem( tmem_ptr, cute.arch.SM100_TMEM_CAPACITY_COLUMNS, is_two_cta=self.use_2cta_instrs, ) # /////////////////////////////////////////////////////////////////////////////// # Compute warp # /////////////////////////////////////////////////////////////////////////////// if ( warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1] ): cute.arch.warpgroup_reg_alloc(self.softmax_reg_num) mma_s_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_s_stage ) p_mma_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.p_mma_stage ) p_cor_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.p_cor_stage ) mma_o_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_o_stage ) # sync with mma warp before retrieving tmem ptr self.tmem_ptr_sync_bar.wait() tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf, ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord ) if k_tile_count > 0: compute_common_params = SimpleNamespace( blk_coord=blk_coord, split_kv=split_kv, local_split_kv=local_split_kv, smem_exchange=softmax_smem_exchange, mAccO=mAccO, mO=mO, K=cache_seqs[blk_coord[2]], L=mCL.shape[1], tmem_ptr=tmem_ptr, tidx=tidx, p_cor_pipeline=p_cor_pipeline, ) compute_softmax_params = SimpleNamespace( tiled_mma_qk=tiled_mma_qk, sP=sP, mma_s_pipeline=mma_s_pipeline, p_mma_pipeline=p_mma_pipeline, softmax_scale_log2=softmax_scale_log2, ) mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( self.compute( compute_common_params, compute_softmax_params, k_index=k_index, k_tile_count=k_tile_count, mma_s_consumer_state=mma_s_consumer_state, p_mma_producer_state=p_mma_producer_state, p_cor_producer_state=p_cor_producer_state, ) ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() p_cor_pipeline.producer_tail(p_cor_producer_state) # /////////////////////////////////////////////////////////////////////////////// # Correction warp # /////////////////////////////////////////////////////////////////////////////// if ( warp_idx >= self.correction_warp_ids[0] and warp_idx <= self.correction_warp_ids[-1] ): cute.arch.warpgroup_reg_alloc(self.correction_reg_num) p_cor_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.p_cor_stage ) mma_o_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_o_stage ) # sync with mma warp before retrieving tmem ptr self.tmem_ptr_sync_bar.wait() tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf, ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord ) if k_tile_count > 0: compute_common_params = SimpleNamespace( blk_coord=blk_coord, split_kv=split_kv, local_split_kv=local_split_kv, smem_exchange=epilogue_smem_exchange, mAccO=mAccO, mO=mO, K=cache_seqs[blk_coord[2]], L=mCL.shape[1], H=mQL.shape[0], tmem_ptr=tmem_ptr, tidx=tidx, tiled_mma_pv=tiled_mma_pv, p_cor_pipeline=p_cor_pipeline, mma_o_pipeline=mma_o_pipeline, ) compute_epilogue_params = SimpleNamespace( output_scale=output_scale, softmax_scale_log2=softmax_scale_log2, mAccLSE=mAccLSE, mLSE=mLSE, ) p_cor_consumer_state, mma_o_consumer_state = self.correction( compute_common_params, compute_epilogue_params, k_tile_count=k_tile_count, p_cor_consumer_state=p_cor_consumer_state, mma_o_consumer_state=mma_o_consumer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() # Arrive for the tensor memory deallocation barrier cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1) return @cute.kernel def reduction_kernel( self, mO: cute.Tensor, mLSE: cute.Tensor, mAccO: cute.Tensor, mAccLSE: cute.Tensor, split_kv: cutlass.Int32, cache_seqs: cute.Tensor, block_split_kvs: cute.Tensor, ): """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results from multiple split_kv blocks into final outputs. :param mO: Output tensor for storing final results :type mO: cute.Tensor :param mLSE: Log-sum-exp tensor for storing final LSE values :type mLSE: cute.Tensor :param mAccO: Accumulated output tensor from split_kv blocks :type mAccO: cute.Tensor :param mAccLSE: Accumulated LSE tensor from split_kv blocks :type mAccLSE: cute.Tensor :param split_kv: Number of split_kv blocks :type split_kv: cutlass.Int32 :param cache_seqs: Cache sequence lengths tensor :type cache_seqs: cute.Tensor :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) :type block_split_kvs: cute.Tensor """ bidx, _, bidz = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() blk_coord = (bidx, 0, bidz) local_split_kv = ( block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv ) k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) # Alloc shared memory smem = utils.SmemAllocator() storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) gLSE = mAccLSE[blk_coord[0], None, blk_coord[2]] warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 0: # calculate the global lse and exp ^ (local_lse - global_lse) lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) local_lse = cute.make_rmem_tensor( cute.make_layout(lse_per_thread), self.lse_dtype ) lse_max = -self.lse_dtype.inf # find the max lse for i in cutlass.range_constexpr(lse_per_thread): split_kv_idx = tidx + i * self.threads_per_warp local_lse[i] = ( gLSE[split_kv_idx] if cute.elem_less(split_kv_idx, local_split_kv) else -self.lse_dtype.inf ) # reduce the local lse lse_max = cute.arch.fmax(lse_max, local_lse[i]) lse_max = cute.arch.warp_reduction_max(lse_max) lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 # calculate sum_lse sum_lse = 0.0 for i in cutlass.range_constexpr(lse_per_thread): sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) sum_lse = cute.arch.warp_reduction_sum(sum_lse) # calculate the global_lse global_lse = ( lse_max + cute.math.log2(sum_lse, fastmath=True) if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse else self.lse_dtype.inf ) if tidx == 0: mLSE[blk_coord[0], blk_coord[2]] = global_lse # store the scale to shared memory for i in cutlass.range_constexpr(lse_per_thread): split_kv_idx = tidx + i * self.threads_per_warp if cute.elem_less(split_kv_idx, local_split_kv): smem_lse_scale[split_kv_idx] = cute.math.exp2( local_lse[i] - global_lse, fastmath=True ) pipeline.sync(barrier_id=4) elements_per_thread = cute.ceil_div( self.latent_dim, self.threads_per_warp * self.num_compute_warps ) gAccO = mAccO[blk_coord[0], None, None, blk_coord[2]] rAccO = cute.make_rmem_tensor( cute.make_layout(elements_per_thread), self.acc_dtype ) rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) rAccO.fill(0.0) for i in range(local_split_kv): for j in cutlass.range_constexpr(elements_per_thread): element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] rO.store(rAccO.load().to(self.o_dtype)) for j in cutlass.range_constexpr(elements_per_thread): element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps mO[blk_coord[0], element_idx, blk_coord[2]] = rO[j] return @staticmethod def get_split_kv( B: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int ) -> int: """Get the proper split_kv value for the MLA kernel based on parameters. :param B: Batch size :type B: int :param K: Sequence length :type K: int :param mma_qk_tiler_mn: MLA tiling parameters :type mma_qk_tiler_mn: tuple :param max_active_blocks: Maximum number of active blocks :type max_active_blocks: int :return: Split_kv value :rtype: int """ max_splits = ceil_div(K, mma_qk_tiler_mn[1]) blocks_per_batch = max(1, max_active_blocks // B) split_heur = min(max_splits, blocks_per_batch) k_waves = ceil_div(max_splits, split_heur) split_wave_aware = ceil_div(max_splits, k_waves) return split_wave_aware @cute.jit def get_k_tile_count( self, split_kv: cutlass.Int32, cache_seqs: cute.Tensor, block_split_kvs: cute.Tensor, blk_coord: cute.Coord, ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. :param split_kv: Split_kv value :type split_kv: cutlass.Int32 :param cache_seqs: Cache sequence lengths tensor :type cache_seqs: cute.Tensor :param block_split_kvs: Per-block split_kv values tensor :type block_split_kvs: cute.Tensor :param blk_coord: Block coordinate :type blk_coord: cute.Coord :return: k_index, k_tile_count, split_kv :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] """ K = cache_seqs[blk_coord[2]] if cutlass.const_expr(self.is_var_split_kv): split_kv = block_split_kvs[blk_coord[2]] k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) k_index = blk_coord[3] * k_tile_per_cta k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) return k_index, k_tile_count, split_kv @cute.jit def load_page_table( self, common_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_pt_producer_state: pipeline.PipelineState, ) -> pipeline.PipelineState: """Load warp to load page table. Updates the load pt producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param load_pt_producer_state: The load pt producer state :type load_pt_producer_state: pipeline.PipelineState :return: The load pt producer state :rtype: pipeline.PipelineState """ mPT = common_params.mPT[None, common_params.blk_coord[2]] page_per_tile = self.mma_qk_tiler[1] >> cute.arch.log2_of_pow2_int( common_params.page_size ) tidx = common_params.tidx % self.threads_per_warp load_pt_pipeline = common_params.load_pt_pipeline while k_tile_count > 0: load_pt_pipeline.producer_acquire(load_pt_producer_state) elem_per_thread = cute.ceil_div(page_per_tile, self.threads_per_warp) # atom_async_copy: async copy atom for page table load atom_async_copy = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), cutlass.Int32, num_bits_per_copy=cutlass.Int32.width, ) mPT_for_copy = cute.flat_divide(mPT, (1,)) sPT_for_copy = cute.flat_divide(common_params.sPT, (1,)) # elem_per_thread is a dynamic value depends on the page_size setting. for i in range(elem_per_thread): idx = i * self.threads_per_warp + tidx if cute.elem_less( k_index * page_per_tile + idx, mPT.shape[0] ) and cute.elem_less(idx, page_per_tile): cute.copy( atom_async_copy, mPT_for_copy[None, k_index * page_per_tile + idx], sPT_for_copy[None, idx, load_pt_producer_state.index], ) else: sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0) mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state) load_pt_pipeline.producer_commit(load_pt_producer_state) load_pt_producer_state.advance() k_index += 1 k_tile_count -= 1 return load_pt_producer_state @cute.jit def load_cpasync( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, v_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_pt_consumer_state: pipeline.PipelineState, load_pt_release_state: pipeline.PipelineState, load_q_producer_state: pipeline.PipelineState, load_kv_producer_state: pipeline.PipelineState, load_kv_commit_state: pipeline.PipelineState, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """Load warp to load cpasync. Updates the load cpasync producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param load_pt_consumer_state: The load pt consumer state :type load_pt_consumer_state: pipeline.PipelineState :param load_pt_release_state: The load pt release state :type load_pt_release_state: pipeline.PipelineState :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_kv_producer_state: The load kv producer state :type load_kv_producer_state: pipeline.PipelineState :param load_kv_commit_state: The load kv commit state :type load_kv_commit_state: pipeline.PipelineState :return: The load pt consumer state, the load pt release state, the load q producer state, the load kv producer state, the load kv commit state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ tidx = ( common_params.tidx - self.threads_per_warp * self.load_cp_async_warp_ids[0] ) # slice view the the global tensors for cpasync, their coords are from counting tensor coord. mCL_for_slice = cute.make_tensor( qk_params.mCL.iterator, cute.make_layout( ( (qk_params.mCL.shape[0], qk_params.mCL.shape[2]), qk_params.mCL.shape[1], ), stride=( (qk_params.mCL.stride[0], qk_params.mCL.stride[2]), qk_params.mCL.stride[1], ), ), ) mKR_for_slice = cute.make_tensor( qk_params.mKR.iterator, cute.make_layout( ( (qk_params.mKR.shape[0], qk_params.mKR.shape[2]), qk_params.mKR.shape[1], ), stride=( (qk_params.mKR.stride[0], qk_params.mKR.stride[2]), qk_params.mKR.stride[1], ), ), ) mCLT_for_slice = cute.make_tensor( v_params.mCLT.iterator, cute.make_layout( ( v_params.mCLT.shape[0], (v_params.mCLT.shape[1], v_params.mCLT.shape[2]), ), stride=( v_params.mCLT.stride[0], (v_params.mCLT.stride[1], v_params.mCLT.stride[2]), ), ), ) # make identity tensor for partition mCL_for_partition = cute.make_identity_tensor( (qk_params.mCL.shape[0] * qk_params.mCL.shape[2], qk_params.mCL.shape[1]) ) mKR_for_partition = cute.make_identity_tensor( (qk_params.mKR.shape[0] * qk_params.mKR.shape[2], qk_params.mKR.shape[1]) ) mCLT_for_partition = cute.make_identity_tensor( (v_params.mCLT.shape[0], v_params.mCLT.shape[1] * v_params.mCLT.shape[2]) ) # Flatten divide and partition global tensors for QK TMA load # (bM, bK, rM, rK, rL) mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk) mma_qk_tiler_nk = cute.select(self.mma_qk_tiler, mode=[1, 2]) gCL = cute.flat_divide(mCL_for_partition, mma_qk_tiler_nk) gKR = cute.flat_divide(mKR_for_partition, mma_qk_tiler_nk) thr_mma_qk = qk_params.tiled_mma_qk.get_slice( common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) ) tSgQL = thr_mma_qk.partition_A(gQL) tSgQR = thr_mma_qk.partition_A(gQR) tSgCL = thr_mma_qk.partition_B(gCL) tSgKR = thr_mma_qk.partition_B(gKR) # create cpasync tiled copy qk cpasync_bits = 128 # thread for copy thread = self.threads_per_warp * self.num_load_warps # Value for copy value = cpasync_bits // self.q_dtype.width cpasync_q_tiled_copy = cute.make_cotiled_copy( cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), self.q_dtype, num_bits_per_copy=cpasync_bits, ), cute.make_ordered_layout((thread, value), (1, 0)), qk_params.sQ[None, None, None, 0].layout, ) cpasync_kc_tiled_copy = cute.make_cotiled_copy( cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), self.q_dtype, num_bits_per_copy=cpasync_bits, ), cute.make_ordered_layout((thread, value), (1, 0)), qk_params.sKC[None, None, None, 0].layout, ) cpasync_q_thr_copy = cpasync_q_tiled_copy.get_slice(tidx) cpasync_kc_thr_copy = cpasync_kc_tiled_copy.get_slice(tidx) # copy async partition tQgQL = cpasync_q_thr_copy.partition_S(tSgQL) tQgQR = cpasync_q_thr_copy.partition_S(tSgQR) tQsQ = cpasync_q_thr_copy.partition_D(qk_params.sQ) tKCgCL = cpasync_kc_thr_copy.partition_S(tSgCL) tKCgKR = cpasync_kc_thr_copy.partition_S(tSgKR) tKCsKC = cpasync_kc_thr_copy.partition_D(qk_params.sKC) gCLT = cute.flat_divide( mCLT_for_partition, cute.select(self.mma_pv_tiler, mode=[1, 2]) ) thr_mma_pv = v_params.tiled_mma_pv.get_slice( common_params.blk_coord[0] % cute.size(v_params.tiled_mma_pv.thr_id) ) tOgCLT = thr_mma_pv.partition_B(gCLT) cpasync_v_tiled_copy = cute.make_cotiled_copy( cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), self.q_dtype, num_bits_per_copy=cpasync_bits, ), cute.make_ordered_layout((thread, value), (1, 0)), v_params.sVC[None, None, None, 0].layout, ) cpasync_v_thr_copy = cpasync_v_tiled_copy.get_slice(tidx) tVCgCLT = cpasync_v_thr_copy.partition_S(tOgCLT) tVCsVC = cpasync_v_thr_copy.partition_D(v_params.sVC) # Use to record the in-flight cpasync stage count, wait and producer commit until `load_kv_stage - 1` cpasync arrive copy_in_flight_count = cutlass.Int32(0) qk_params.tiled_copy_q = cpasync_q_tiled_copy qk_params.tiled_copy_kc = cpasync_kc_tiled_copy qk_params.mCL_for_slice = mCL_for_slice qk_params.mKR_for_slice = mKR_for_slice qk_params.tQgQL = tQgQL qk_params.tQgQR = tQgQR qk_params.tQsQ = tQsQ qk_params.tKCgCL = tKCgCL qk_params.tKCgKR = tKCgKR qk_params.tKCsKC = tKCsKC v_params.tiled_copy_vc = cpasync_v_tiled_copy v_params.tVCgCLT = tVCgCLT v_params.tVCsVC = tVCsVC v_params.mCLT_for_slice = mCLT_for_slice # first load qk latent/rope ( load_pt_consumer_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_qk_one_k_tile( common_params, qk_params, k_index, load_pt_consumer_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, load_q=True, ) k_index += 1 k_tile_count -= 1 # mainloop, load qk and v while k_tile_count > 0: ( load_pt_consumer_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_qk_one_k_tile( common_params, qk_params, k_index, load_pt_consumer_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, load_q=False, ) ( load_pt_release_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_v_one_k_tile( common_params, v_params, k_index - 1, load_pt_release_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) k_index += 1 k_tile_count -= 1 # load last tile of v ( load_pt_release_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_v_one_k_tile( common_params, v_params, k_index - 1, load_pt_release_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) padding_in_flight = 0 while copy_in_flight_count + padding_in_flight < self.load_kv_stage - 1: padding_in_flight += 1 cute.arch.cp_async_commit_group() # wait for previous cpasync arrive load_kv_pipeline = common_params.load_kv_pipeline while copy_in_flight_count > 0: cute.arch.cp_async_commit_group() cute.arch.cp_async_wait_group(self.load_kv_stage - 1) load_kv_pipeline.producer_commit(load_kv_commit_state) load_kv_commit_state.advance() copy_in_flight_count -= 1 # wait all cpasync arrive cute.arch.cp_async_wait_group(0) return ( load_pt_consumer_state, load_pt_release_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, ) @cute.jit def load_cpasync_one_smem_stage( self, common_params: SimpleNamespace, load_q_producer_state: pipeline.PipelineState, load_kv_producer_state: pipeline.PipelineState, load_kv_commit_state: pipeline.PipelineState, copy_func: Callable, copy_in_flight_count: cutlass.Int32, load_q: bool, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32, ]: """Load one smem stage of cpasync. Reused for qkv load stages. Updates the load cpasync producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param load_pt_consumer_state: The load pt consumer state :type load_pt_consumer_state: pipeline.PipelineState :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_kv_producer_state: The load kv producer state :type load_kv_producer_state: pipeline.PipelineState :param load_kv_commit_state: The load kv commit state :type load_kv_commit_state: pipeline.PipelineState :param copy_func: The copy function :type copy_func: Callable :param copy_in_flight_count: The copy in-flight count :type copy_in_flight_count: cutlass.Int32 :param load_q: Whether to load q :type load_q: bool :return: The load q producer state, the load kv producer state, the load kv commit state, the copy in-flight count :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32] """ if cutlass.const_expr(load_q): common_params.load_q_pipeline.producer_acquire(load_q_producer_state) common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) producer_index = load_kv_producer_state.index copy_func(producer_index) cute.arch.cp_async_commit_group() if cutlass.const_expr(load_q): # directly commit the q producer state here, mma will wait for kv. common_params.load_q_pipeline.producer_commit(load_q_producer_state) load_q_producer_state.advance() load_kv_producer_state.advance() copy_in_flight_count += 1 # wait cpasync arrive until the last stage load_kv_pipeline = common_params.load_kv_pipeline if copy_in_flight_count == self.load_kv_stage: cute.arch.cp_async_wait_group(self.load_kv_stage - 1) load_kv_pipeline.producer_commit(load_kv_commit_state) load_kv_commit_state.advance() copy_in_flight_count -= 1 return ( load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) @cute.jit def load_cpasync_page_table_lookup_copy( self, tiled_copy: cute.TiledCopy, gKV: cute.Tensor, sKV: cute.Tensor, sPT: cute.Tensor, gKV_for_slice: cute.Tensor, k_index: cutlass.Int32, latent_idx: cutlass.Int32, qkv_stage_idx: cutlass.Int32, page_table_stage: cutlass.Int32, page_size: cutlass.Int32, transpose: bool = False, ): """Make page table lookup for KV cache latent/rope, then do atom copy of cpasync. :param tiled_copy: The tiled copy :type tiled_copy: cute.TiledCopy :param gKV: The global KV tensor :type gKV: cute.Tensor :param sKV: The sliced KV tensor :type sKV: cute.Tensor :param sPT: The sliced page table tensor :type sPT: cute.Tensor :param gKV_for_slice: The global KV for slice tensor :type gKV_for_slice: cute.Tensor :param k_index: The k index :type k_index: cutlass.Int32 :param latent_idx: The latent index :type latent_idx: cutlass.Int32 :param qkv_stage_idx: The qkv stage index :type qkv_stage_idx: cutlass.Int32 :param page_table_stage: The page table stage :type page_table_stage: cutlass.Int32 :param transpose: Whether to transpose the gKV_for_slice :type transpose: bool """ rest_modes_start = 1 rest_modes_end = 4 if cutlass.const_expr(transpose): gKV_grouped = cute.group_modes( gKV[None, None, None, None, latent_idx, k_index], rest_modes_start, rest_modes_end, ) else: gKV_grouped = cute.group_modes( gKV[None, None, None, None, k_index, latent_idx], rest_modes_start, rest_modes_end, ) sKV_grouped = cute.group_modes( sKV[None, None, None, None, qkv_stage_idx], rest_modes_start, rest_modes_end ) page_size_log2 = cute.arch.log2_of_pow2_int(page_size) page_per_tile = self.mma_qk_tiler[1] >> page_size_log2 gKV_for_copy_offsets = cute.make_rmem_tensor( cute.size(gKV_grouped.shape[1]), cute.cosize(gKV_for_slice.layout).dtype ) # unroll the rest of the loop to apply page table lookup. for i in cutlass.range_constexpr(cute.size(gKV_grouped.shape[1])): # get the coordinate of the gKV_for_slice coord = gKV_grouped[None, i].iterator if cutlass.const_expr(transpose): # fast path of mod & div here to avoid the division because of the page size is power of 2. page_coord = ((coord[1] & (page_size - 1)), coord[1] >> page_size_log2) new_coord = (coord[0], page_coord) new_coord_pt = new_coord[1][1] & (page_per_tile - 1) gKV_for_copy_offset = cute.crd2idx( ( new_coord[0], (new_coord[1][0], sPT[new_coord_pt, page_table_stage]), ), gKV_for_slice.layout, ) else: # fast path of mod & div here to avoid the division because of the page size is power of 2. page_coord = (coord[0] & (page_size - 1), coord[0] >> page_size_log2) new_coord = (page_coord, coord[1]) new_coord_pt = new_coord[0][1] & (page_per_tile - 1) gKV_for_copy_offset = cute.crd2idx( ( (new_coord[0][0], sPT[new_coord_pt, page_table_stage]), new_coord[1], ), gKV_for_slice.layout, ) gKV_for_copy_offsets[i] = gKV_for_copy_offset cpasync_bits = 128 for i in cutlass.range_constexpr(cute.size(gKV_grouped.shape[1])): # calculate the actual offset and apply. sKV_for_copy = sKV_grouped[None, i] gKV_for_copy_offset = cute.assume( gKV_for_copy_offsets[i], cpasync_bits // self.q_dtype.width ) gKV_for_copy_iter = gKV_for_slice.iterator + gKV_for_copy_offset gKV_for_copy = cute.make_tensor(gKV_for_copy_iter, sKV_for_copy.layout) cute.copy(tiled_copy, gKV_for_copy, sKV_for_copy) return @cute.jit def load_cpasync_qk_one_k_tile( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, k_index: cutlass.Int32, load_pt_consumer_state: pipeline.PipelineState, load_q_producer_state: pipeline.PipelineState, load_kv_producer_state: pipeline.PipelineState, load_kv_commit_state: pipeline.PipelineState, copy_in_flight_count: cutlass.Int32, load_q: bool, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32, ]: """Load one k tile of Q/K. Updates the load cpasync producer state. :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param load_pt_consumer_state: The load pt consumer state :type load_pt_consumer_state: pipeline.PipelineState :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_kv_producer_state: The load kv producer state :type load_kv_producer_state: pipeline.PipelineState :param load_kv_commit_state: The load kv commit state :type load_kv_commit_state: pipeline.PipelineState :param copy_in_flight_count: The copy stage count :type copy_in_flight_count: int :param load_q: Whether to load q :type load_q: bool :return: The load pt consumer state, the load q producer state, the load kv producer state, the load kv commit state, the copy stage count :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, int] """ common_params.load_pt_pipeline.consumer_wait(load_pt_consumer_state) page_table_stage = load_pt_consumer_state.index load_pt_consumer_state.advance() def copy_qk_latent(latent_idx, qkv_stage_idx): if load_q: cute.copy( qk_params.tiled_copy_q, qk_params.tQgQL[ None, None, None, None, 0, latent_idx, common_params.blk_coord[2], ], qk_params.tQsQ[None, None, None, None, latent_idx], ) # make sure the page table lookups first. self.load_cpasync_page_table_lookup_copy( qk_params.tiled_copy_kc, qk_params.tKCgCL, qk_params.tKCsKC, common_params.sPT, qk_params.mCL_for_slice, k_index, latent_idx, qkv_stage_idx, page_table_stage, common_params.page_size, ) def copy_qk_rope(latent_idx, qkv_stage_idx): if load_q: cute.copy( qk_params.tiled_copy_q, qk_params.tQgQR[ None, None, None, None, 0, latent_idx, common_params.blk_coord[2], ], qk_params.tQsQ[ None, None, None, None, self.iterations_qk_latent + latent_idx ], ) # make sure the page table lookups first. self.load_cpasync_page_table_lookup_copy( qk_params.tiled_copy_kc, qk_params.tKCgKR, qk_params.tKCsKC, common_params.sPT, qk_params.mKR_for_slice, k_index, latent_idx, qkv_stage_idx, page_table_stage, common_params.page_size, ) # use dynamic loop here to avoid instruction cache miss. for i in range(self.iterations_qk_latent): ( load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_one_smem_stage( common_params, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, partial(copy_qk_latent, i), copy_in_flight_count, load_q=load_q, ) for i in range(self.iterations_qk_rope): ( load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_one_smem_stage( common_params, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, partial(copy_qk_rope, i), copy_in_flight_count, load_q=load_q, ) return ( load_pt_consumer_state, load_q_producer_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) @cute.jit def load_cpasync_v_one_k_tile( self, common_params: SimpleNamespace, v_params: SimpleNamespace, k_index: cutlass.Int32, load_pt_release_state: pipeline.PipelineState, load_kv_producer_state: pipeline.PipelineState, load_kv_commit_state: pipeline.PipelineState, copy_in_flight_count: cutlass.Int32, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32, ]: """Load one k tile of V. Updates the load cpasync producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param v_params: The v parameters :type v_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param load_pt_release_state: The load pt release state :type load_pt_release_state: pipeline.PipelineState :param load_kv_producer_state: The load kv producer state :type load_kv_producer_state: pipeline.PipelineState :param load_kv_commit_state: The load kv commit state :type load_kv_commit_state: pipeline.PipelineState :param copy_in_flight_count: The copy in-flight count :type copy_in_flight_count: cutlass.Int32 :return: The load pt release state, the load kv producer state, the load kv commit state, the copy in-flight count :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32] """ page_table_stage = load_pt_release_state.index def copy_v_latent(iter_k_idx, latent_idx, qkv_stage_idx): # make sure the page table lookups first. self.load_cpasync_page_table_lookup_copy( v_params.tiled_copy_vc, v_params.tVCgCLT, v_params.tVCsVC, common_params.sPT, v_params.mCLT_for_slice, k_index * self.iterations_pv_k + iter_k_idx, latent_idx, qkv_stage_idx, page_table_stage, common_params.page_size, transpose=True, ) # use dynamic loop here to avoid instruction cache miss. for i in range(self.iterations_pv_k): for j in range(self.iterations_pv_n): ( _, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) = self.load_cpasync_one_smem_stage( common_params, None, load_kv_producer_state, load_kv_commit_state, partial(copy_v_latent, i, j), copy_in_flight_count, load_q=False, ) common_params.load_pt_pipeline.consumer_release(load_pt_release_state) load_pt_release_state.advance() return ( load_pt_release_state, load_kv_producer_state, load_kv_commit_state, copy_in_flight_count, ) @cute.jit def load_tma( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, v_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_q_producer_state: pipeline.PipelineState, load_kv_producer_state: pipeline.PipelineState, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: """Load wrap to load Q/C latent/rope tensors. Updates the load qkv producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param v_params: The v parameters :type v_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_kv_producer_state: The load kv producer state :type load_kv_producer_state: pipeline.PipelineState :return: The load q producer state and load kv producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] """ # page table mPT = None if cutlass.const_expr(self.use_page_table): mPT = common_params.mPT[None, common_params.blk_coord[2]] # Flatten divide and partition global tensors for QK TMA load # (bM, bK, rM, rK, rL) mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk) mma_qk_tiler_nk = cute.select(self.mma_qk_tiler, mode=[1, 2]) gCL = cute.flat_divide(qk_params.mCL, mma_qk_tiler_nk) gKR = cute.flat_divide(qk_params.mKR, mma_qk_tiler_nk) thr_mma_qk = qk_params.tiled_mma_qk.get_slice( common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) ) tSgQL = thr_mma_qk.partition_A(gQL) tSgQR = thr_mma_qk.partition_A(gQR) tSgCL = thr_mma_qk.partition_B(gCL) tSgKR = thr_mma_qk.partition_B(gKR) # tma partition for q, k latent/rope # smem: ((atom_v, rest_v), STAGE) # gmem: ((atom_v, rest_v), RestM, RestK, RestL) tQsQ, tQLgQL_mkl = cpasync.tma_partition( qk_params.tma_atom_q_latent, 0, cute.make_layout(1), cute.group_modes(qk_params.sQ, 0, 3), cute.group_modes(tSgQL, 0, 3), ) _, tQRgQR_mkl = cpasync.tma_partition( qk_params.tma_atom_q_rope, 0, cute.make_layout(1), cute.group_modes(qk_params.sQ, 0, 3), cute.group_modes(tSgQR, 0, 3), ) tKCsKC, tCLgCL = cpasync.tma_partition( qk_params.tma_atom_c_latent, 0, cute.make_layout(1), cute.group_modes(qk_params.sKC, 0, 3), cute.group_modes(tSgCL, 0, 3), ) _, tKRgKR = cpasync.tma_partition( qk_params.tma_atom_c_rope, 0, cute.make_layout(1), cute.group_modes(qk_params.sKC, 0, 3), cute.group_modes(tSgKR, 0, 3), ) tQLgQL = tQLgQL_mkl[None, None, None, common_params.blk_coord[2]] tQRgQR = tQRgQR_mkl[None, None, None, common_params.blk_coord[2]] # Flatten divide and partition global tensors for V TMA load mma_pv_tiler_nk = cute.select(self.mma_pv_tiler, mode=[1, 2]) gCLT = cute.flat_divide(v_params.mCLT, mma_pv_tiler_nk) thr_mma_pv = v_params.tiled_mma_pv.get_slice( common_params.blk_coord[0] % cute.size(v_params.tiled_mma_pv.thr_id) ) tOgCLT = thr_mma_pv.partition_B(gCLT) # tma partition for vc # smem: ((atom_v, rest_v), STAGE) # gmem: ((atom_v, rest_v), RestM, RestK, RestL) tVCsVC, tCLTgCLT = cpasync.tma_partition( v_params.tma_atom_c_latent_transpose, 0, cute.make_layout(1), cute.group_modes(v_params.sVC, 0, 3), cute.group_modes(tOgCLT, 0, 3), ) # set extra params common_params.mPT = mPT qk_params.tQLgQL = tQLgQL qk_params.tQRgQR = tQRgQR qk_params.tCLgCL = tCLgCL qk_params.tKRgKR = tKRgKR qk_params.tQsQ = tQsQ qk_params.tKCsKC = tKCsKC v_params.tCLTgCLT = tCLTgCLT v_params.tVCsVC = tVCsVC load_q_producer_state, load_kv_producer_state = self.load_tma_qk_one_k_tile( common_params, qk_params, k_index, k_tile_count, load_q_producer_state, load_kv_producer_state, load_q=True, ) k_index += 1 k_tile_count -= 1 while k_tile_count > 0: load_q_producer_state, load_kv_producer_state = self.load_tma_qk_one_k_tile( common_params, qk_params, k_index, k_tile_count, load_q_producer_state, load_kv_producer_state, load_q=False, ) load_kv_producer_state = self.load_tma_v_one_k_tile( common_params, v_params, k_index - 1, load_kv_producer_state, ) k_index += 1 k_tile_count -= 1 # load last v tile load_kv_producer_state = self.load_tma_v_one_k_tile( common_params, v_params, k_index - 1, load_kv_producer_state, ) return load_q_producer_state, load_kv_producer_state @cute.jit def load_tma_qk_one_k_tile( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_q_producer_state: pipeline.PipelineState, load_kv_producer_state: pipeline.PipelineState, load_q: bool, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_kv_producer_state: The load kv producer state :type load_kv_producer_state: pipeline.PipelineState :param load_q: Whether to load q :type load_q: bool :return: The load q producer state and load kv producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] """ k_idx = cute.make_rmem_tensor(cute.make_layout(2), cutlass.Int32) # prefetch next K load to keep busy while we transpose-load from cache kPrefetchDistance = 1 if cutlass.const_expr(self.use_page_table): k_idx[0] = common_params.mPT[k_index] k_idx[1] = common_params.mPT[k_index + kPrefetchDistance] else: k_idx[0] = common_params.blk_coord[2] k_idx[1] = common_params.blk_coord[2] for i in cutlass.range_constexpr(self.iterations_qk_latent): # load q once at first iteration if cutlass.const_expr(load_q): # get the mbar ptr from pipeline. tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier( load_q_producer_state ) # expect the extra bytes for q. common_params.load_q_pipeline.producer_acquire(load_q_producer_state) # load q latent cute.copy( qk_params.tma_atom_q_latent, qk_params.tQLgQL[None, 0, load_q_producer_state.index], qk_params.tQsQ[None, load_q_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) load_q_producer_state.advance() # get the mbar ptr from pipeline. tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier( load_kv_producer_state ) # expect the extra bytes for q. common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) # load k latent if cutlass.const_expr(self.use_page_table): cute.copy( qk_params.tma_atom_c_latent, qk_params.tCLgCL[None, 0, i, k_idx[0]], qk_params.tKCsKC[None, load_kv_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) else: cute.copy( qk_params.tma_atom_c_latent, qk_params.tCLgCL[None, k_index, i, k_idx[0]], qk_params.tKCsKC[None, load_kv_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) load_kv_producer_state.advance() for i in cutlass.range_constexpr(self.iterations_qk_rope): # load q rope once at first iteration if cutlass.const_expr(load_q): # get the mbar ptr from pipeline. tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier( load_q_producer_state ) # expect the extra bytes for q. common_params.load_q_pipeline.producer_acquire(load_q_producer_state) # load q rope cute.copy( qk_params.tma_atom_q_rope, qk_params.tQRgQR[None, 0, i], qk_params.tQsQ[None, i + self.iterations_qk_latent], tma_bar_ptr=tma_bar_ptr, ) load_q_producer_state.advance() # get the mbar ptr from pipeline. tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier( load_kv_producer_state ) # expect the extra bytes for q. common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) # load k rope if cutlass.const_expr(self.use_page_table): cute.copy( qk_params.tma_atom_c_rope, qk_params.tKRgKR[None, 0, i, k_idx[0]], qk_params.tKCsKC[None, load_kv_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) else: cute.copy( qk_params.tma_atom_c_rope, qk_params.tKRgKR[None, k_index, i, k_idx[0]], qk_params.tKCsKC[None, load_kv_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) load_kv_producer_state.advance() for i in cutlass.range_constexpr(self.iterations_qk_latent): if cutlass.const_expr(self.use_page_table): if k_tile_count > kPrefetchDistance: cute.prefetch( qk_params.tma_atom_c_latent, qk_params.tCLgCL[ None, k_index, i, k_idx[1], ], ) else: cute.prefetch( qk_params.tma_atom_c_latent, qk_params.tCLgCL[None, k_index + kPrefetchDistance, i, k_idx[1]], ) for i in cutlass.range_constexpr(self.iterations_qk_rope): if cutlass.const_expr(self.use_page_table): if k_tile_count > kPrefetchDistance: cute.prefetch( qk_params.tma_atom_c_rope, qk_params.tKRgKR[ None, k_index, i, k_idx[1], ], ) else: cute.prefetch( qk_params.tma_atom_c_rope, qk_params.tKRgKR[None, k_index + kPrefetchDistance, i, k_idx[1]], ) return load_q_producer_state, load_kv_producer_state @cute.jit def load_tma_v_one_k_tile( self, common_params: SimpleNamespace, v_params: SimpleNamespace, k_index: cutlass.Int32, load_kv_producer_state: pipeline.PipelineState, ) -> pipeline.PipelineState: """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param v_params: The load tma v parameters :type v_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param load_kv_producer_state: The load qkv producer state :type load_kv_producer_state: pipeline.PipelineState :return: The load qkv producer state :rtype: pipeline.PipelineState """ k_idx = cute.make_rmem_tensor(cute.make_layout(1), cutlass.Int32) if cutlass.const_expr(self.use_page_table): k_idx[0] = common_params.mPT[k_index] else: k_idx[0] = common_params.blk_coord[2] for i in cutlass.range_constexpr(self.iterations_pv_k): for j in cutlass.range_constexpr(self.iterations_pv_n): # get the mbar ptr from pipeline. tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier( load_kv_producer_state ) common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state) if cutlass.const_expr(self.use_page_table): cute.copy( v_params.tma_atom_c_latent_transpose, v_params.tCLTgCLT[None, j, i, k_idx[0]], v_params.tVCsVC[None, load_kv_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) else: cute.copy( v_params.tma_atom_c_latent_transpose, v_params.tCLTgCLT[ None, j, k_index * self.iterations_pv_k + i, k_idx[0], ], v_params.tVCsVC[None, load_kv_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) load_kv_producer_state.advance() return load_kv_producer_state @cute.jit def mma( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, pv_params: SimpleNamespace, k_tile_count: cutlass.Int32, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, load_q_consumer_state: pipeline.PipelineState, load_kv_consumer_state: pipeline.PipelineState, mma_s_producer_state: pipeline.PipelineState, p_mma_consumer_state: pipeline.PipelineState, mma_o_producer_state: pipeline.PipelineState, ) -> tuple[ cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. :param common_params: The common parameters for mma qk and pv :type common_params: SimpleNamespace :param qk_params: The mma qk parameters :type qk_params: SimpleNamespace :param pv_params: The mma pv parameters :type pv_params: SimpleNamespace :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param tiled_mma_qk: The tiled mma qk :type tiled_mma_qk: cute.TiledMma :param tiled_mma_pv: The tiled mma pv :type tiled_mma_pv: cute.TiledMma :param load_q_consumer_state: The load q consumer state :type load_q_consumer_state: pipeline.PipelineState :param load_kv_consumer_state: The load kv consumer state :type load_kv_consumer_state: pipeline.PipelineState :param mma_s_producer_state: The mma s producer state :type mma_s_producer_state: pipeline.PipelineState :param p_mma_consumer_state: The p mma consumer state :type p_mma_consumer_state: pipeline.PipelineState :param mma_o_producer_state: The mma o producer state :type mma_o_producer_state: pipeline.PipelineState :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load kv consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) tStS_shape = tiled_mma_qk.partition_shape_C( cute.select(self.mma_qk_tiler, mode=[0, 1]) ) tStS_staged_fake = tiled_mma_qk.make_fragment_C( cute.append(tStS_shape, self.mma_s_stage) ) # use real tmem ptr for tStS tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) tOtO_shape = tiled_mma_pv.partition_shape_C( cute.select(self.mma_pv_tiler, mode=[0, 1]) ) # mma O has 1 stage. assert self.mma_o_stage == 1, ( "mma O has 1 stage, otherwise the tmem usage exceeds the limit." ) tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) tOtO_layout = cute.append( tOtO.layout, cute.make_layout( common_params.L // self.mma_pv_tiler[1], stride=self.mma_pv_tiler[1] // self.warps_in_n, ), ) tOtO_staged = cute.make_tensor( tStS_staged.iterator + self.tmem_o_offset, tOtO_layout ) # set more parameters qk_params.tSrQ = tSrQ qk_params.tSrKC = tSrKC qk_params.tStS_staged = tStS_staged pv_params.tOrP = tOrP pv_params.tOrVC = tOrVC pv_params.tOtO_staged = tOtO_staged # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) load_q_pipeline = common_params.load_q_pipeline if common_params.is_leader_cta: load_q_release_state = load_q_consumer_state.clone() ( tiled_mma_qk, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, ) = self.mma_qk( common_params, qk_params, tiled_mma_qk, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, wait_q=True, ) k_tile_count -= 1 while k_tile_count > 0: ( tiled_mma_qk, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, ) = self.mma_qk( common_params, qk_params, tiled_mma_qk, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, wait_q=False, ) ( tiled_mma_pv, load_kv_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) = self.mma_pv( common_params, pv_params, tiled_mma_pv, load_kv_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) k_tile_count -= 1 # release q consumer states for i in cutlass.range_constexpr(self.iterations_qk): load_q_pipeline.consumer_release(load_q_release_state) load_q_release_state.advance() ( tiled_mma_pv, load_kv_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) = self.mma_pv( common_params, pv_params, tiled_mma_pv, load_kv_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) return ( tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) @cute.jit def mma_qk( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, tiled_mma_qk: cute.TiledMma, load_q_consumer_state: pipeline.PipelineState, load_kv_consumer_state: pipeline.PipelineState, mma_s_producer_state: pipeline.PipelineState, wait_q: bool, ) -> tuple[ cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param tiled_mma_qk: The tiled mma qk :type tiled_mma_qk: cute.TiledMma :param load_q_consumer_state: The load q consumer state :type load_q_consumer_state: pipeline.PipelineState :param load_kv_consumer_state: The load kv consumer state :type load_kv_consumer_state: pipeline.PipelineState :param mma_s_producer_state: The mma s producer state :type mma_s_producer_state: pipeline.PipelineState :return: The tiled mma qk, the load q consumer state, the load kv consumer state, and the mma s producer state :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) load_q_pipeline = common_params.load_q_pipeline load_kv_pipeline = common_params.load_kv_pipeline for q_stage in range(self.iterations_qk_latent): if cutlass.const_expr(wait_q): load_q_pipeline.consumer_wait(load_q_consumer_state) load_kv_pipeline.consumer_wait(load_kv_consumer_state) kc_stage = load_kv_consumer_state.index for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])): cute.gemm( tiled_mma_qk, tStS, qk_params.tSrQ[None, None, k_block, q_stage], qk_params.tSrKC[None, None, k_block, kc_stage], tStS, ) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) load_kv_pipeline.consumer_release(load_kv_consumer_state) load_kv_consumer_state.advance() if cutlass.const_expr(wait_q): load_q_consumer_state.advance() for q_stage in range(self.iterations_qk_rope): if cutlass.const_expr(wait_q): load_q_pipeline.consumer_wait(load_q_consumer_state) load_kv_pipeline.consumer_wait(load_kv_consumer_state) kc_stage = load_kv_consumer_state.index for k_block in cutlass.range_constexpr( self.rope_dim // tiled_mma_qk.shape_mnk[2] ): cute.gemm( tiled_mma_qk, tStS, qk_params.tSrQ[ None, None, k_block, q_stage + self.iterations_qk_latent ], qk_params.tSrKC[None, None, k_block, kc_stage], tStS, ) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) load_kv_pipeline.consumer_release(load_kv_consumer_state) load_kv_consumer_state.advance() if cutlass.const_expr(wait_q): load_q_consumer_state.advance() qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) mma_s_producer_state.advance() return ( tiled_mma_qk, load_q_consumer_state, load_kv_consumer_state, mma_s_producer_state, ) @cute.jit def mma_pv( self, common_params: SimpleNamespace, pv_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, load_kv_consumer_state: pipeline.PipelineState, p_mma_consumer_state: pipeline.PipelineState, mma_o_producer_state: pipeline.PipelineState, ) -> tuple[ cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. :param common_params: The common parameters :type common_params: SimpleNamespace :param pv_params: The pv parameters :type pv_params: SimpleNamespace :param tiled_mma_pv: The tiled mma pv :type tiled_mma_pv: cute.TiledMma :param load_kv_consumer_state: The load kv consumer state :type load_kv_consumer_state: pipeline.PipelineState :param p_mma_consumer_state: The P MMA consumer state :type p_mma_consumer_state: pipeline.PipelineState :param mma_o_producer_state: The MMA o producer state :type mma_o_producer_state: pipeline.PipelineState :return: The tiled mma pv, the load qkv consumer state, the P MMA consumer state, and the MMA o producer state :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ pv_params.mma_o_pipeline.producer_acquire(mma_o_producer_state) pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) load_kv_pipeline = common_params.load_kv_pipeline for p_stage in range(self.iterations_pv_k): accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) for acc_stage in range(self.iterations_pv_n): load_kv_pipeline.consumer_wait(load_kv_consumer_state) tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) vc_stage = load_kv_consumer_state.index tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]): cute.gemm( tiled_mma_pv, tOtO, pv_params.tOrP[ None, None, k_block, (p_stage, p_mma_consumer_state.index), ], pv_params.tOrVC[None, None, k_block, vc_stage], tOtO, ) tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) load_kv_pipeline.consumer_release(load_kv_consumer_state) load_kv_consumer_state.advance() pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) p_mma_consumer_state.advance() pv_params.mma_o_pipeline.producer_commit(mma_o_producer_state) mma_o_producer_state.advance() return ( tiled_mma_pv, load_kv_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) @cute.jit def compute( self, common_params: SimpleNamespace, softmax_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, mma_s_consumer_state: pipeline.PipelineState, p_mma_producer_state: pipeline.PipelineState, p_cor_producer_state: pipeline.PipelineState, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. :param common_params: The common parameters :type common_params: SimpleNamespace :param softmax_params: The softmax parameters :type softmax_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param k_tile_count: The number of k-tiles :type k_tile_count: cutlass.Int32 :param mma_s_consumer_state: The MMA s consumer state :type mma_s_consumer_state: pipeline.PipelineState :param p_mma_producer_state: The P MMA producer state :type p_mma_producer_state: pipeline.PipelineState :param p_cor_producer_state: The P correction producer state :type p_cor_producer_state: pipeline.PipelineState :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) row_max = -self.acc_dtype.inf row_sum = self.acc_dtype(0) correction_factor = self.acc_dtype(1) while k_tile_count > 0: ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) = self.softmax_dispatch_apply_mask( common_params, softmax_params, k_index, k_tile_total, mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) k_index = k_index + 1 k_tile_count = k_tile_count - 1 return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state @cute.jit def correction( self, common_params: SimpleNamespace, epilogue_params: SimpleNamespace, k_tile_count: cutlass.Int32, p_cor_consumer_state: pipeline.PipelineState, mma_o_consumer_state: pipeline.PipelineState, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. :param common_params: The common parameters :type common_params: SimpleNamespace :param epilogue_params: The epilogue parameters :type epilogue_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param k_tile_count: The number of k-tiles :type k_tile_count: cutlass.Int32 :param p_cor_consumer_state: The P correction consumer state :type p_cor_consumer_state: pipeline.PipelineState :param mma_o_consumer_state: The MMA o consumer state :type mma_o_consumer_state: pipeline.PipelineState :return: The P correction consumer state, and the MMA o consumer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] """ p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( self.get_correction_factor(common_params, p_cor_consumer_state) ) k_tile_count = k_tile_count - 1 while k_tile_count > 0: p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( self.get_correction_factor(common_params, p_cor_consumer_state) ) mma_o_consumer_state = self.rescale( common_params, mma_o_consumer_state, correction_factor, no_correction ) k_tile_count = k_tile_count - 1 mma_o_consumer_state = self.epilogue( common_params, epilogue_params, mma_o_consumer_state, row_sum, row_max ) return p_cor_consumer_state, mma_o_consumer_state @cute.jit def softmax_dispatch_apply_mask( self, common_params: SimpleNamespace, softmax_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_total: cutlass.Int32, mma_s_consumer_state: pipeline.PipelineState, p_mma_producer_state: pipeline.PipelineState, p_cor_producer_state: pipeline.PipelineState, row_max: cutlass.Float32, row_sum: cutlass.Float32, correction_factor: cutlass.Float32, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, ]: """Dispatch whether to apply mask for softmax for last k tile. :param common_params: The common parameters :type common_params: SimpleNamespace :param softmax_params: The softmax parameters :type softmax_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param k_tile_total: The total number of k-tiles :type k_tile_total: cutlass.Int32 :param mma_s_consumer_state: The MMA s consumer state :type mma_s_consumer_state: pipeline.PipelineState :param p_mma_producer_state: The P MMA producer state :type p_mma_producer_state: pipeline.PipelineState :param p_cor_producer_state: The P correction producer state :type p_cor_producer_state: pipeline.PipelineState :param row_max: The row max :type row_max: cutlass.Float32 :param row_sum: The row sum :type row_sum: cutlass.Float32 :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 :return: The MMA s consumer state, the P MMA producer state, the row max, the row sum, and the correction factor :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] """ if k_index == k_tile_total - 1: ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) = self.softmax( common_params, softmax_params, k_index, mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, True, ) else: ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) = self.softmax( common_params, softmax_params, k_index, mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, False, ) return ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) @cute.jit def softmax( self, common_params: SimpleNamespace, softmax_params: SimpleNamespace, k_index: cutlass.Int32, mma_s_consumer_state: pipeline.PipelineState, p_mma_producer_state: pipeline.PipelineState, p_cor_producer_state: pipeline.PipelineState, row_max: cutlass.Float32, row_sum: cutlass.Float32, correction_factor: cutlass.Float32, is_last_tile: bool, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, ]: """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. :param common_params: The common parameters :type common_params: SimpleNamespace :param softmax_params: The softmax parameters :type softmax_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param mma_s_consumer_state: The MMA s consumer state :type mma_s_consumer_state: pipeline.PipelineState :param p_mma_producer_state: The P MMA producer state :type p_mma_producer_state: pipeline.PipelineState :param p_cor_producer_state: The P correction producer state :type p_cor_producer_state: pipeline.PipelineState :param row_max: The row max :type row_max: cutlass.Float32 :param row_sum: The row sum :type row_sum: cutlass.Float32 :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 :param is_last_tile: Whether the last tile :type is_last_tile: bool :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] """ softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) # load S from tmem tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( cute.select(self.mma_qk_tiler, mode=[0, 1]) ) tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( cute.append(tStS_shape, self.mma_s_stage) ) tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] tAcc = tStS[(None, None), 0, 0] cta_qk_tiler = ( self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], self.mma_qk_tiler[1], self.mma_qk_tiler[2], ) cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype ) tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) tTR_tAcc = tmem_thr_copy.partition_S(tAcc) tTR_tS = tmem_thr_copy.partition_D(cS) tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) row_max_new = row_max for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): if cutlass.const_expr(is_last_tile): tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, common_params.K, ) else -self.acc_dtype.inf ) # update row_max row_max_new = cute.arch.fmax(row_max_new, tTR_rAcc[i]) # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) if cutlass.const_expr(self.warps_in_n == 2): common_params.smem_exchange[tidx] = row_max_new self.softmax_exchange_sync_bar.wait() row_max_new = cute.arch.fmax( row_max_new, common_params.smem_exchange[ (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) ], ) # find correction factor correction_factor = cute.math.exp2( (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True ) no_correction = cutlass.Int32(row_max == row_max_new) # softmax fma_b = (softmax_params.softmax_scale_log2, softmax_params.softmax_scale_log2) fma_c = ( (0.0 - row_max_new) * softmax_params.softmax_scale_log2, (0.0 - row_max_new) * softmax_params.softmax_scale_log2, ) for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.fma_packed_f32x2( (tTR_rAcc[i], tTR_rAcc[i + 1]), fma_b, fma_c ) tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) tTR_rAcc[i + 1] = cute.math.exp2(tTR_rAcc[i + 1], fastmath=True) tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) # quantize tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) # create sP sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] sP_mk_view = cute.make_tensor( sP.iterator, cute.make_layout( ( (sP.shape[0][0], sP.shape[1]), (sP.shape[0][1], sP.shape[2], sP.shape[3]), ), stride=( (sP.stride[0][0], sP.stride[1]), (sP.stride[0][1], sP.stride[2], sP.stride[3]), ), ), ) # change to PISL sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) swizzle_bits = ( int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 ) swizzle_base = 3 if self.q_dtype.width == 16 else 4 sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) sP_mk_view = cute.make_tensor( sP_wo_swizzle_iter, cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), ) universal_copy_bits = 128 smem_copy_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.q_dtype, num_bits_per_copy=universal_copy_bits, ) smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) smem_thr_copy = smem_tiled_copy.get_slice(tidx) rP_copy_view = smem_thr_copy.retile(tTR_rS) sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) # row_sum, using `add_packed_f32x2` to reduce the number of instructions row_sum = row_sum * correction_factor row_sum_vec = (0.0, 0.0) for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): row_sum_vec = cute.arch.add_packed_f32x2( row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) ) row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum # fence between tmem load and mma s cute.arch.fence_view_async_tmem_load() # fence between smem store and mma o cute.arch.fence_view_async_shared() softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) mma_s_consumer_state.advance() p_mma_producer_state.advance() # store correction factor/row_sum/row_max to tmem for correction warp common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) # pad for 4x32b corr_layout = cute.make_layout( (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), ) tCor = cute.make_tensor( common_params.tmem_ptr + self.correction_factor_offset, corr_layout, ) cCor = cute.make_identity_tensor(tCor.shape) corr_tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype ) corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) rCor = cute.make_fragment_like( cCor_for_copy[None, None, None, 0], self.acc_dtype ) rCor_int = cute.make_tensor( cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout ) rCor[0] = row_sum rCor[1] = row_max_new rCor[2] = correction_factor rCor_int[3] = no_correction cute.copy( corr_tmem_store_tiled_copy, rCor, tCor_for_copy[None, None, None, p_cor_producer_state.index], ) # fence between tmem store and correction warp cute.arch.fence_view_async_tmem_store() common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) p_cor_producer_state.advance() return ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max_new, row_sum, correction_factor, ) @cute.jit def _tmem_load_partition( self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int ) -> tuple[ cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma ]: """Tensor memory load partition for rescale and epilogue. :param common_params: The common parameters :type common_params: SimpleNamespace :param tiled_mma_pv: The tiled mma pv :type tiled_mma_pv: cute.TiledMma :param iter_n: The iteration number :type iter_n: int :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] """ tOtO_shape = tiled_mma_pv.partition_shape_C( cute.select(self.mma_pv_tiler, mode=[0, 1]) ) tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) tOtO_layout = cute.append( tOtO.layout, cute.make_layout( common_params.L // self.mma_pv_tiler[1], stride=self.mma_pv_tiler[1] // self.warps_in_n, ), ) tOtO = cute.make_tensor( common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout ) tOtO = tOtO[None, None, None, iter_n] tAcc = tOtO[(None, None), 0, 0] tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype ) tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( common_params.tidx % (self.num_compute_warps * self.threads_per_warp) ) cta_pv_tiler = ( self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], self.mma_pv_tiler[1], self.mma_pv_tiler[2], ) # Flatten divide and partition global tensors for O cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) gO = None if cutlass.const_expr(common_params.mAccO is not None): gO = cute.local_tile( common_params.mAccO[None, common_params.blk_coord[3], None, None], cta_pv_tiler_mn, (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), ) cO = cute.local_tile( cute.make_identity_tensor( common_params.mAccO[ None, common_params.blk_coord[3], None, None ].shape ), cta_pv_tiler_mn, (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), ) else: gO = cute.local_tile( common_params.mO, cta_pv_tiler_mn, (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), ) cO = cute.local_tile( cute.make_identity_tensor(common_params.mO.shape), cta_pv_tiler_mn, (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]), ) tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) tTR_gO = tmem_load_thr_copy.partition_D(gO) tTR_cO = tmem_load_thr_copy.partition_D(cO) tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc def get_correction_factor( self, common_params: SimpleNamespace, p_cor_consumer_state: pipeline.PipelineState, ) -> tuple[ pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32, ]: """Get the correction factor from the P correction consumer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param p_cor_consumer_state: The P correction consumer state :type p_cor_consumer_state: pipeline.PipelineState :return: The P correction consumer state, the row_sum, the row_max, and the correction factor :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] """ common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) # load correction factor _, tAcc, _, _, _, _ = self._tmem_load_partition( common_params, common_params.tiled_mma_pv, 0 ) corr_layout = cute.make_layout( (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), ) tCor = cute.make_tensor( common_params.tmem_ptr + self.correction_factor_offset, corr_layout ) cCor = cute.make_identity_tensor(tCor.shape) corr_tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype ) corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) rCor = cute.make_fragment_like( cCor_for_copy[None, None, None, 0], self.acc_dtype ) rCor_int = cute.make_tensor( cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout ) cute.copy( corr_tmem_load_tiled_copy, tCor_for_copy[None, None, None, p_cor_consumer_state.index], rCor, ) row_sum = rCor[0] row_max = rCor[1] correction_factor = rCor[2] no_correction = rCor_int[3] common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) p_cor_consumer_state.advance() return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction @cute.jit def rescale( self, common_params: SimpleNamespace, mma_o_consumer_state: pipeline.PipelineState, correction_factor: cutlass.Float32, no_correction: cutlass.Int32, ) -> pipeline.PipelineState: """Rescale for one k-tile. Updates the related pipeline state. :param common_params: The common parameters :type common_params: SimpleNamespace :param mma_o_consumer_state: The mma o consumer state :type mma_o_consumer_state: pipeline.PipelineState :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 :param no_correction: Whether to apply correction factor :type no_correction: cutlass.Int32 :return: The MMA o consumer state :rtype: pipeline.PipelineState """ common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) skip_correction = cute.arch.vote_all_sync(no_correction == 1) if not skip_correction: for iter_n in cutlass.range_constexpr(self.iterations_pv_n): # tmem load tiled copy and partition results. tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( self._tmem_load_partition( common_params, common_params.tiled_mma_pv, iter_n ) ) # tmem store tiled copy tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype ) tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) # load o cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) # rescale, using `mul_packed_f32x2` to reduce the number of instructions for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( ( tTR_rAcc[i], tTR_rAcc[i + 1], ), (correction_factor, correction_factor), ) # store o to tensor memory for next k tile cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) cute.arch.fence_view_async_tmem_store() common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) mma_o_consumer_state.advance() return mma_o_consumer_state @cute.jit def epilogue( self, common_params: SimpleNamespace, epilogue_params: SimpleNamespace, mma_o_consumer_state: pipeline.PipelineState, row_sum: cutlass.Float32, row_max: cutlass.Float32, ) -> pipeline.PipelineState: """Epilogue for one k-tile. Updates the related pipeline state. :param common_params: The common parameters :type common_params: SimpleNamespace :param epilogue_params: The epilogue parameters :type epilogue_params: SimpleNamespace :param mma_o_consumer_state: The mma o consumer state :type mma_o_consumer_state: pipeline.PipelineState :param row_sum: The row sum :type row_sum: cutlass.Float32 :param row_max: The row max :type row_max: cutlass.Float32 :return: The MMA o consumer state :rtype: pipeline.PipelineState """ # mma_o pipeline consumer wait common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) # exchange row_sum between warps (0, 1) and (2, 3) if cutlass.const_expr(self.warps_in_n == 2): common_params.smem_exchange[tidx] = row_sum self.epilogue_exchange_sync_bar.wait() # (64, 2) row_sum = ( row_sum + common_params.smem_exchange[ (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) ] ) for iter_n in cutlass.range_constexpr(self.iterations_pv_n): # tmem load tiled copy and partition results. tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( self._tmem_load_partition( common_params, common_params.tiled_mma_pv, iter_n ) ) # load o cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) # apply output scale and normalize by row_sum for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( (tTR_rAcc[i], tTR_rAcc[i + 1]), ( epilogue_params.output_scale * cute.arch.rcp_approx(row_sum), epilogue_params.output_scale * cute.arch.rcp_approx(row_sum), ), ) # store o to global memory tR2G_rO_src = None tR2G_rO_dst = tTR_gO if cutlass.const_expr(common_params.mAccO is None): tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) # using final output dtype for o tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) else: # using accumulate dtype for o tR2G_rO_src = tTR_rAcc if cute.elem_less(tTR_cO[0][0], common_params.H): cute.autovec_copy(tR2G_rO_src, tR2G_rO_dst) # store the lse to global memory cta_pv_tiler = ( self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], self.mma_pv_tiler[1], self.mma_pv_tiler[2], ) gLSE = None cLSE = None if cutlass.const_expr(epilogue_params.mAccLSE is None): gLSE = cute.local_tile( epilogue_params.mLSE, (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, None, 1), ) cLSE = cute.local_tile( cute.make_identity_tensor(epilogue_params.mLSE.shape), (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, None, 1), ) else: gLSE = cute.local_tile( epilogue_params.mAccLSE[None, common_params.blk_coord[3], None], (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, None, 1), ) cLSE = cute.local_tile( cute.make_identity_tensor( epilogue_params.mAccLSE[ None, common_params.blk_coord[3], None ].shape ), (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, None, 1), ) lse = ( cute.math.log2(row_sum, fastmath=True) + epilogue_params.softmax_scale_log2 * row_max ) if cutlass.const_expr(self.warps_in_n == 2): if cute.elem_less(cLSE[tidx][0], common_params.H): gLSE[tidx] = lse cute.arch.fence_view_async_tmem_load() common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) mma_o_consumer_state.advance() return mma_o_consumer_state def make_and_init_load_qkv_pipeline( self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count, is_cpasync ) -> pipeline.PipelineTmaUmma: """Create and initialize the tma load qkv pipeline. :param load_qkv_mbar_ptr: The load qkv mbar pointer :type load_qkv_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :param load_stages: The load stages :type load_stages: list[int] :param tx_count: The tx count :type tx_count: int :param is_cpasync: Whether to use cpasync :type is_cpasync: bool :return: The tma load qkv pipeline :rtype: pipeline.PipelineTmaUmma """ if is_cpasync: load_qkv_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len(self.load_cp_async_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], ) load_qkv_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) return pipeline.PipelineAsyncUmma.create( barrier_storage=load_qkv_mbar_ptr, num_stages=load_stages, producer_group=load_qkv_producer_group, consumer_group=load_qkv_consumer_group, cta_layout_vmnk=cta_layout_vmnk, ) else: load_qkv_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.load_tma_warp_id]) ) load_qkv_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) return pipeline.PipelineTmaUmma.create( barrier_storage=load_qkv_mbar_ptr, num_stages=load_stages, producer_group=load_qkv_producer_group, consumer_group=load_qkv_consumer_group, tx_count=tx_count, cta_layout_vmnk=cta_layout_vmnk, ) def make_and_init_mma_s_pipeline( self, mma_s_mbar_ptr, cta_layout_vmnk ) -> pipeline.PipelineUmmaAsync: """Create and initialize the mma s pipeline. :param mma_s_mbar_ptr: The mma s mbar pointer :type mma_s_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :return: The mma s pipeline :rtype: pipeline.PipelineUmmaAsync """ mma_s_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) consumer_thread_size = ( self.threads_per_warp * len(self.compute_warp_ids) * self.cluster_shape_mnk[0] ) mma_s_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_thread_size, ) return pipeline.PipelineUmmaAsync.create( barrier_storage=mma_s_mbar_ptr, num_stages=self.mma_s_stage, producer_group=mma_s_producer_group, consumer_group=mma_s_consumer_group, cta_layout_vmnk=cta_layout_vmnk, ) def make_and_init_p_mma_pipeline( self, p_mma_mbar_ptr, cta_layout_vmnk ) -> pipeline.PipelineAsyncUmma: """Create and initialize the p mma pipeline. :param p_mma_mbar_ptr: The p mma mbar pointer :type p_mma_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :return: The p mma pipeline :rtype: pipeline.PipelineAsyncUmma """ producer_thread_size = ( self.threads_per_warp * len(self.compute_warp_ids) * self.cluster_shape_mnk[0] ) p_mma_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, producer_thread_size, ) p_mma_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) return pipeline.PipelineAsyncUmma.create( barrier_storage=p_mma_mbar_ptr, num_stages=self.p_mma_stage, producer_group=p_mma_producer_group, consumer_group=p_mma_consumer_group, cta_layout_vmnk=cta_layout_vmnk, ) def make_and_init_p_cor_pipeline( self, p_cor_mbar_ptr ) -> pipeline.PipelineAsyncUmma: """Create and initialize the p correction pipeline. :param p_cor_mbar_ptr: The p correction mbar pointer :type p_cor_mbar_ptr: cute.Tensor :return: The p correction pipeline :rtype: pipeline.PipelineAsyncUmma """ producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) p_cor_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, producer_thread_size, ) p_cor_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, producer_thread_size, ) return pipeline.PipelineAsync.create( barrier_storage=p_cor_mbar_ptr, num_stages=self.p_cor_stage, producer_group=p_cor_producer_group, consumer_group=p_cor_consumer_group, ) def make_and_init_mma_o_pipeline( self, mma_o_mbar_ptr, cta_layout_vmnk ) -> pipeline.PipelineUmmaAsync: """Create and initialize the mma o pipeline. :param mma_o_mbar_ptr: The mma o mbar pointer :type mma_o_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :return: The mma o pipeline :rtype: pipeline.PipelineUmmaAsync """ mma_o_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) consumer_thread_size = ( self.threads_per_warp * len(self.compute_warp_ids) * self.cluster_shape_mnk[0] ) mma_o_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_thread_size, ) return pipeline.PipelineUmmaAsync.create( barrier_storage=mma_o_mbar_ptr, num_stages=self.mma_o_stage, producer_group=mma_o_producer_group, consumer_group=mma_o_consumer_group, cta_layout_vmnk=cta_layout_vmnk, ) def make_and_init_load_pt_pipeline(self, load_pt_mbar_ptr): """Create and initialize the load page table pipeline. :param load_pt_mbar_ptr: The load page table mbar pointer :type load_pt_mbar_ptr: cute.Tensor :return: The load page table pipeline :rtype: pipeline.PipelineAsync """ load_pt_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, self.threads_per_warp * len([self.load_pt_warp_id]), ) load_pt_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, self.threads_per_warp * len(self.load_cp_async_warp_ids), ) return pipeline.PipelineCpAsync.create( barrier_storage=load_pt_mbar_ptr, num_stages=self.load_pt_stage, producer_group=load_pt_producer_group, consumer_group=load_pt_consumer_group, ) @staticmethod def _compute_grid( o: cute.Tensor, split_kv: cutlass.Int32, cluster_shape_mnk: Tuple[int, int, int], max_active_clusters: int, is_persistent: bool, ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: """Compute grid shape for the output tensor C. :param c: The output tensor C :type c: cute.Tensor :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. :type cta_tile_shape_mnk: tuple[int, int, int] :param cluster_shape_mn: Shape of each cluster in M, N dimensions. :type cluster_shape_mn: tuple[int, int] :return: Tile scheduler parameters and grid shape. :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] """ o_shape = o.shape tile_sched_params = create_mla_static_tile_scheduler_params( is_persistent, cute.size(o_shape[2]), cluster_shape_mnk, split_kv, ) grid = MLAStaticTileScheduler.get_grid_shape( tile_sched_params, max_active_clusters ) return tile_sched_params, grid @staticmethod def get_workspace_size( H: int, D: int, B: int, split_kv: int, acc_dtype: Type[cutlass.Numeric], ) -> int: """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. :param H: The height of the output tensor C :type H: int :param D: The depth of the output tensor C :type D: int :param B: The batch size of the output tensor C :type B: int :param split_kv: The split key-value of the output tensor C :type split_kv: int :param acc_dtype: The data type of the output tensor C :type acc_dtype: Type[cutlass.Numeric] :return: The workspace size for the MLA kernel :rtype: int """ if split_kv == 1: return 0 return B * H * split_kv * (D + 1) * acc_dtype.width // 8 @cute.jit def initialize_workspace( self, H: cutlass.Int32, D: cutlass.Int32, B: cutlass.Int32, split_kv: cutlass.Int32, acc_dtype: Type[cutlass.Numeric], workspace: cute.Tensor, ) -> tuple[cute.Tensor, cute.Tensor]: """Initialize the workspace for the MLA kernel. Construct the intermediate tensors acc_o and acc_lse. :param H: The height of the output tensor C :type H: cutlass.Int32 :param D: The depth of the output tensor C :type D: cutlass.Int32 :param B: The batch size of the output tensor C :type B: cutlass.Int32 :param split_kv: The split key-value of the output tensor C :type split_kv: cutlass.Int32 :param acc_dtype: The data type of the output tensor C :type acc_dtype: Type[cutlass.Numeric] :param workspace: The workspace tensor :type workspace: cute.Tensor :return: The output tensor C and the workspace tensor :rtype: tuple[cute.Tensor, cute.Tensor] """ acc_o, acc_lse = None, None if cutlass.const_expr(workspace is not None): align = 128 // self.q_dtype.width acc_o_layout = cute.make_layout( (H, split_kv, D, B), stride=( cute.assume(split_kv * D, align), cute.assume(D, align), 1, cute.assume(H * split_kv * D, align), ), ) acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) acc_lse_layout = cute.make_layout( (H, split_kv, B), stride=(split_kv, 1, H * split_kv) ) acc_lse_iter = cute.recast_ptr( workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, dtype=acc_dtype, ) acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) return acc_o, acc_lse @staticmethod def can_implement( B: int, K: int, H: int, L: int, R: int, in_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], split_kv: int, is_persistent: bool, is_cpasync: bool, is_var_seq: bool, is_var_split_kv: bool, use_page_table: bool, page_size: int, ) -> bool: """Check if the MLA kernel can be implemented. :param H: The height of the output tensor C :type H: int :param K: The width of the output tensor C :type K: int :param L: The length of the output tensor C :type L: int :param R: The row of the output tensor C :type R: int :param B: The batch size of the output tensor C :type B: int :param in_dtype: The data type of the input tensor :type in_dtype: Type[cutlass.Numeric] :param out_dtype: The data type of the output tensor :type out_dtype: Type[cutlass.Numeric] :param acc_dtype: The data type of the accumulator :type acc_dtype: Type[cutlass.Numeric] :param lse_dtype: The data type of the log-sum-exp :type lse_dtype: Type[cutlass.Numeric] :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication :type mma_qk_tiler_mn: Tuple[int, int] :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication :type mma_pv_tiler_mn: Tuple[int, int] :param split_kv: The split key-value of the output tensor C :type split_kv: int :param is_persistent: Whether to use persistent kernel optimization :type is_persistent: bool :param is_cpasync: Whether to use cpasync :type is_cpasync: bool :param is_var_seq: Whether to use variable sequence length :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split_kv :type is_var_split_kv: bool :param use_page_table: Whether to use page table :type use_page_table: bool :param page_size: The page size of the page table :type page_size: int :return: Whether the MLA kernel can be implemented :rtype: bool """ if L != 512 or R != 64: return False if in_dtype not in [cutlass.Float8E4M3FN, cutlass.Float16]: return False if out_dtype not in [cutlass.Float8E4M3FN, cutlass.Float16]: return False if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: return False if is_cpasync: if not use_page_table: return False if page_size & (page_size - 1) != 0: return False if page_size > mma_qk_tiler_mn[1]: return False else: if use_page_table and page_size != mma_qk_tiler_mn[1]: return False if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: return False if is_var_split_kv and (not use_page_table or not is_var_seq): return False if is_var_seq and not use_page_table: return False if not is_cpasync and (H > 128 or (H < 128 and split_kv != 1)): return False if is_cpasync and H != 128: return False if K <= 0: return False return True def ceil_div(a: int, b: int) -> int: return (a + b - 1) // b def run( batch_size: int, seq_len: int, num_heads: int, latent_dim: int, rope_dim: int, in_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], split_kv: int, is_persistent: bool, is_cpasync: bool, is_var_seq: bool, is_var_split_kv: bool, use_page_table: bool, page_size: int, softmax_scale: float, output_scale: float, tolerance: float, warmup_iterations: int, iterations: int, skip_ref_check: bool, use_cold_l2: bool, **kwargs, ): """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. This function creates random input tensors for query latent/rope, compressed latent/rope, and value, then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference implementation or run multiple times for performance measurement. :param batch_size: Batch size :type batch_size: int :param seq_len: Sequence length :type seq_len: int :param num_heads: Number of heads :type num_heads: int :param latent_dim: dimension of query/compressed latent :type latent_dim: int :param rope_dim: dimension of query/compressed rope :type rope_dim: int :param in_dtype: Input data type for query/compressed latent/rope tensors :type in_dtype: Type[cutlass.Numeric] :param out_dtype: Output data type for attention output :type out_dtype: Type[cutlass.Numeric] :param acc_dtype: Accumulator data type for query-key matrix multiplication :type acc_dtype: Type[cutlass.Numeric] :param lse_dtype: Accumulator data type for log-sum-exp :type lse_dtype: Type[cutlass.Numeric] :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication :type mma_qk_tiler_mn: Tuple[int, int] :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication :type mma_pv_tiler_mn: Tuple[int, int] :param split_kv: Split key-value :type split_kv: int :param is_persistent: Whether to use persistent kernel optimization :type is_persistent: bool :param is_cpasync: Whether to use cpasync :type is_cpasync: bool :param is_var_seq: Whether to use variable sequence length :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split_kv :type is_var_split_kv: bool :param use_page_table: Whether to use page table :type use_page_table: bool :param page_size: Page size of the page table :type page_size: int :param softmax_scale: Attention score scaling factor :type softmax_scale: float :param output_scale: Output scaling factor :type output_scale: float :param tolerance: Maximum acceptable error for validation :type tolerance: float :param warmup_iterations: Number of warmup iterations :type warmup_iterations: int :param iterations: Number of iterations to run for performance testing :type iterations: int :param skip_ref_check: Skip validation against reference implementation :type skip_ref_check: bool :param use_cold_l2: Whether to use cold L2 cache :type use_cold_l2: bool :raises ValueError: If input shapes are incompatible or head dimension is unsupported :raises RuntimeError: If GPU is unavailable for computation """ print("Running Blackwell MLA test with:") print(f" batch_size: {batch_size}") print(f" seq_len: {seq_len}") print(f" num_heads: {num_heads}") print(f" latent_dim: {latent_dim}") print(f" rope_dim: {rope_dim}") print(f" in_dtype: {in_dtype}") print(f" out_dtype: {out_dtype}") print(f" acc_dtype: {acc_dtype}") print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") print(f" split_kv: {split_kv}") print(f" is_persistent: {is_persistent}") print(f" is_cpasync: {is_cpasync}") print(f" is_var_seq: {is_var_seq}") print(f" is_var_split_kv: {is_var_split_kv}") print(f" use_page_table: {use_page_table}") print(f" page_size: {page_size}") print(f" softmax_scale: {softmax_scale}") print(f" output_scale: {output_scale}") print(f" tolerance: {tolerance}") print(f" warmup_iterations: {warmup_iterations}") print(f" iterations: {iterations}") print(f" skip_ref_check: {skip_ref_check}") print(f" use_cold_l2: {use_cold_l2}") # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") if not BlackwellMultiHeadLatentAttentionForward.can_implement( batch_size, seq_len, num_heads, latent_dim, rope_dim, in_dtype, out_dtype, acc_dtype, lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, split_kv, is_persistent, is_cpasync, is_var_seq, is_var_split_kv, use_page_table, page_size, ): raise TypeError( f"Unsupported testcase {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_cpasync}, {is_var_seq}, {is_var_split_kv}, {use_page_table}, {page_size}" ) torch.manual_seed(1111) def create_data_tensor( B, HK, D, dtype, is_dynamic_layout=True, page_table=None, cache_seqs=None, is_lse=False, ): shape = (B, HK, D) if page_table is not None: if cache_seqs is not None: max_seq_len = torch.max(cache_seqs) shape = (B * ceil_div(max_seq_len, page_size), page_size, D) else: shape = (B * ceil_div(HK, page_size), page_size, D) permute_order = (1, 2, 0) stride_order = (2, 0, 1) leading_dim = 1 if is_lse: shape = (B, HK) permute_order = (1, 0) stride_order = (1, 0) leading_dim = 0 init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) torch_dtype = ( cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 ) # Create dtype torch tensor (cpu) torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( shape, torch_dtype, permute_order=permute_order, init_type=cutlass.torch.TensorInitType.RANDOM, init_config=init_config, ) # Create dtype torch tensor (gpu) torch_tensor_gpu = torch_tensor_cpu.cuda() # Create f32 torch tensor (cpu) f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) # Create dtype cute tensor (gpu) # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. cute_tensor = from_dlpack( torch_tensor_gpu, assumed_align=16, use_32bit_stride=True ) cute_tensor.element_type = dtype if is_dynamic_layout: cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) if not is_lse: cute_tensor = cute_tensor.mark_compact_shape_dynamic( mode=leading_dim, stride_order=stride_order, divisibility=(128 // dtype.width), ) cute_tensor = cutlass_torch.convert_cute_tensor( f32_torch_tensor, cute_tensor, dtype, is_dynamic_layout=is_dynamic_layout, ) return f32_torch_tensor, cute_tensor, torch_tensor_gpu def create_cache_seqs(batch_size, seq_len, is_var_seq): cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len cache_seqs_gpu = cache_seqs_ref.cuda() # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. cache_seqs = from_dlpack( cache_seqs_gpu, assumed_align=16, use_32bit_stride=True ).mark_layout_dynamic() if is_var_seq: max_seq_len = seq_len min_seq_len = int(seq_len * 0.8) cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( (batch_size,), torch.int32, init_type=cutlass.torch.TensorInitType.RANDOM, init_config=cutlass.torch.RandomInitConfig( min_val=min_seq_len, max_val=max_seq_len + 1 ), ) cache_seqs_gpu = cache_seqs_ref.cuda() cache_seqs = from_dlpack( # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. cache_seqs_gpu, assumed_align=16, use_32bit_stride=True, ).mark_layout_dynamic() return cache_seqs_ref, cache_seqs, cache_seqs_gpu def create_page_table(batch_size, seq_len, is_var_seq, use_page_table, page_size): page_table_ref, page_table, page_table_gpu = None, None, None if use_page_table: max_seq_len = seq_len if not is_var_seq else torch.max(cache_seqs_ref) page_count = ceil_div(max_seq_len, page_size) page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. for b in range(batch_size): for j in range(page_count): page_table_ref[b, j] = b + j * batch_size page_table_gpu = page_table_ref.permute(1, 0).cuda() # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. page_table = from_dlpack( page_table_gpu, assumed_align=16, use_32bit_stride=True ).mark_layout_dynamic(leading_dim=0) return page_table_ref, page_table, page_table_gpu def create_block_split_kvs( batch_size, split_kv, cache_seqs_ref, is_var_split_kv, mma_qk_tiler_mn, cluster_shape_mnk, max_active_clusters, ): block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None # check if split_kv is valid otherwise do auto setting of split_kv if is_var_split_kv: block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) for b in range(batch_size): block_split_kvs_ref[b] = ( BlackwellMultiHeadLatentAttentionForward.get_split_kv( batch_size, cache_seqs_ref[b].item(), mma_qk_tiler_mn, max_active_clusters * cluster_shape_mnk[0], ) ) split_kv = torch.max(block_split_kvs_ref).item() block_split_kvs_gpu = block_split_kvs_ref.cuda() # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. block_split_kvs = from_dlpack( block_split_kvs_gpu, assumed_align=16, use_32bit_stride=True ).mark_layout_dynamic() elif split_kv <= 0: split_kv = BlackwellMultiHeadLatentAttentionForward.get_split_kv( batch_size, cache_seqs_ref[0].item(), mma_qk_tiler_mn, max_active_clusters * cluster_shape_mnk[0], ) return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu def create_workspace(num_heads, latent_dim, batch_size, split_kv, acc_dtype): workspace_size = BlackwellMultiHeadLatentAttentionForward.get_workspace_size( num_heads, latent_dim, batch_size, split_kv, acc_dtype, ) workspace, workspace_torch = None, None if workspace_size > 0: workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. workspace = from_dlpack( workspace_torch, assumed_align=16, use_32bit_stride=True ) return workspace, workspace_torch cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( batch_size, seq_len, is_var_seq ) page_table_ref, page_table, page_table_torch = create_page_table( batch_size, seq_len, is_var_seq, use_page_table, page_size ) cluster_shape_mnk = (2, 1, 1) hardware_info = utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mnk[0] * cluster_shape_mnk[1] ) split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( create_block_split_kvs( batch_size, split_kv, cache_seqs_ref, is_var_split_kv, mma_qk_tiler_mn, cluster_shape_mnk, max_active_clusters, ) ) q_latent_ref, q_latent, q_latent_torch = create_data_tensor( batch_size, num_heads, latent_dim, in_dtype, is_dynamic_layout=True, ) q_rope_ref, q_rope, q_rope_torch = create_data_tensor( batch_size, num_heads, rope_dim, in_dtype, is_dynamic_layout=True, ) c_latent_ref, c_latent, c_latent_torch = create_data_tensor( batch_size, seq_len, latent_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) c_rope_ref, c_rope, c_rope_torch = create_data_tensor( batch_size, seq_len, rope_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) o_ref, o, o_torch = create_data_tensor( batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True ) lse_ref, lse, lse_torch = create_data_tensor( batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True ) workspace, workspace_torch = create_workspace( num_heads, latent_dim, batch_size, split_kv, acc_dtype ) mla = BlackwellMultiHeadLatentAttentionForward( acc_dtype, lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, max_active_clusters, is_persistent, is_cpasync, use_page_table, is_var_seq, is_var_split_kv, ) # Get current CUDA stream from PyTorch torch_stream = torch.cuda.current_stream() # Get the raw stream pointer as a CUstream stream = cuda.CUstream(torch_stream.cuda_stream) # compile mla kernel compiled_mla = cute.compile( mla, q_latent, q_rope, c_latent, c_rope, page_table, o, lse, workspace, split_kv, cache_seqs, block_split_kvs, softmax_scale, output_scale, stream, ) def torch_reference_mla( q_latent, q_rope, c_latent, c_rope, page_table, cache_seqs, softmax_scale=1.0, output_scale=1.0, ): # expand and concat q_latent and q_rope to have the dimension of sequence length for q q_ref = torch.cat([q_latent, q_rope], dim=1).permute(2, 0, 1).unsqueeze(2) # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v if use_page_table: page_count = page_table_ref.shape[1] k_ref_paged = ( torch.cat([c_latent, c_rope], dim=1) .permute(2, 0, 1) .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) ) v_ref_paged = c_latent.permute(2, 0, 1).reshape( batch_size * page_count, page_size, latent_dim ) if is_var_seq: max_seq_len = torch.max(cache_seqs_ref) else: max_seq_len = seq_len k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) k_ref = torch.index_select( k_ref_paged, 0, torch.flatten(page_table_ref) ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] v_ref = torch.index_select( v_ref_paged, 0, torch.flatten(page_table_ref) ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] for b in range(batch_size): k_ref[b, :, cache_seqs_ref[b] :, :] = 0 v_ref[b, :, cache_seqs_ref[b] :, :] = 0 else: k_ref = torch.cat([c_latent, c_rope], dim=1).permute(2, 0, 1).unsqueeze(1) v_ref = c_latent.permute(2, 0, 1).unsqueeze(1) o_ref = F.scaled_dot_product_attention( q_ref, k_ref, v_ref, attn_mask=None, dropout_p=0.0, scale=softmax_scale, is_causal=False, ) s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) s_ref_max = torch.max(s_ref, dim=-1, keepdim=True).values softmax_scale_log2 = LOG2_E * softmax_scale s_ref_sum = torch.sum( torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True ) lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) lse_ref = lse_ref.squeeze(3).squeeze(2).permute(1, 0) o_ref = o_ref * output_scale o_ref = o_ref.squeeze(2).permute(1, 2, 0) return o_ref, lse_ref if not skip_ref_check: # Execute kernel once for reference checking compiled_mla( q_latent, q_rope, c_latent, c_rope, page_table, o, lse, workspace, split_kv, cache_seqs, block_split_kvs, softmax_scale, output_scale, stream, ) torch.cuda.synchronize() print("Verifying results...") if in_dtype == cutlass.Float8E4M3FN: tolerance = 0.13 o_ref, lse_ref = torch_reference_mla( q_latent_ref, q_rope_ref, c_latent_ref, c_rope_ref, page_table, cache_seqs, softmax_scale, output_scale, ) if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: # convert o back to f32 for comparison o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( torch.empty(*o_torch.shape, dtype=torch.float32), cutlass.Float32, is_dynamic_layout=True, assumed_align=16, ) cute.testing.convert(o, o_fp32) o = o_fp32_torch.cpu() ref_fp8, _ = cutlass_torch.cute_tensor_like( torch.empty(*o_ref.permute(2, 0, 1).shape, dtype=torch.uint8).permute( 1, 2, 0 ), out_dtype, is_dynamic_layout=True, assumed_align=16, ) o_ref_gpu = o_ref.cuda() # Set use_32bit_stride to True for small problem size(cosize(layout) <= Int32_max) for better performance. o_ref_f32 = from_dlpack( o_ref_gpu, use_32bit_stride=True ).mark_layout_dynamic(leading_dim=1) # convert ref : f32 -> fp8 -> f32 cute.testing.convert(o_ref_f32, ref_fp8) cute.testing.convert(ref_fp8, o_ref_f32) o_ref = o_ref_gpu.cpu() else: o = o_torch.cpu().to(torch.float32) lse = lse_torch.cpu() lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) # Assert close results torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) print("Results verified successfully!") def generate_tensors(): _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len, is_var_seq) _, page_table, _ = create_page_table( batch_size, seq_len, is_var_seq, use_page_table, page_size ) _split_kv, _, block_split_kvs, _ = create_block_split_kvs( batch_size, split_kv, cache_seqs_ref, is_var_split_kv, mma_qk_tiler_mn, cluster_shape_mnk, max_active_clusters, ) _, q_latent, _ = create_data_tensor( batch_size, num_heads, latent_dim, in_dtype, is_dynamic_layout=True, ) _, q_rope, _ = create_data_tensor( batch_size, num_heads, rope_dim, in_dtype, is_dynamic_layout=True, ) _, c_latent, _ = create_data_tensor( batch_size, seq_len, latent_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) _, c_rope, _ = create_data_tensor( batch_size, seq_len, rope_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) _, o, _ = create_data_tensor( batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True ) _, lse, _ = create_data_tensor( batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True ) workspace, workspace_torch = create_workspace( num_heads, latent_dim, batch_size, _split_kv, acc_dtype ) return testing.JitArguments( q_latent, q_rope, c_latent, c_rope, page_table, o, lse, workspace, _split_kv, cache_seqs, block_split_kvs, softmax_scale, output_scale, stream, ) workspace_count = 1 if use_cold_l2: one_workspace_bytes = ( q_latent_torch.numel() * q_latent_torch.element_size() + q_rope_torch.numel() * q_rope_torch.element_size() + c_latent_torch.numel() * c_latent_torch.element_size() + c_rope_torch.numel() * c_rope_torch.element_size() + o_torch.numel() * o_torch.element_size() + lse_torch.numel() * lse_torch.element_size() + cache_seqs_torch.numel() * cache_seqs_torch.element_size() ) if use_page_table: one_workspace_bytes += ( page_table_torch.numel() * page_table_torch.element_size() ) if is_var_split_kv: one_workspace_bytes += ( block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() ) if workspace_torch is not None: one_workspace_bytes += ( workspace_torch.numel() * workspace_torch.element_size() ) workspace_count = testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations ) avg_time_us = testing.benchmark( compiled_mla, workspace_generator=generate_tensors, workspace_count=workspace_count, stream=stream, warmup_iterations=warmup_iterations, iterations=iterations, ) return avg_time_us # Return execution time in microseconds if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." ) def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: ret = parse_comma_separated_ints(s) if len(ret) != 2: raise argparse.ArgumentTypeError( "Invalid format. Expected 2 comma-separated integers." ) return (ret[0], ret[1]) parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") parser.add_argument( "--in_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN, help="Input data type", ) parser.add_argument( "--out_dtype", type=cutlass.dtype, default=cutlass.Float16, help="Output data type", ) parser.add_argument( "--acc_dtype", type=cutlass.dtype, default=cutlass.Float32, help="Accumulator data type", ) parser.add_argument( "--lse_dtype", type=cutlass.dtype, default=cutlass.Float32, help="LSE data type", ) parser.add_argument( "--mma_qk_tiler_mn", type=parse_mma_tiler, default=(128, 128), help="MMA tile shape (H, K)", ) parser.add_argument( "--mma_pv_tiler_mn", type=parse_mma_tiler, default=(128, 256), help="MMA tile shape (H, D)", ) parser.add_argument( "--is_persistent", action="store_true", help="Is persistent", ) parser.add_argument( "--batch_size", type=int, default=1, help="Batch size", ) parser.add_argument( "--seq_len", type=int, default=128, help="Sequence length of K/V", ) parser.add_argument( "--num_heads", type=int, default=128, help="Number of heads of Q", ) parser.add_argument( "--latent_dim", type=int, default=512, help="Latent dimension of Q/C", ) parser.add_argument( "--rope_dim", type=int, default=64, help="Rope dimension of Q/C", ) parser.add_argument( "--is_cpasync", action="store_true", help="Use cpasync for load or not", ) parser.add_argument( "--is_var_seq", action="store_true", help="Use variable length of sequence length or not", ) parser.add_argument( "--is_var_split_kv", action="store_true", help="Use variable length of split kv or not", ) parser.add_argument( "--use_page_table", action="store_true", help="Use page table or not, must be True when is_cpasync is True", ) parser.add_argument( "--page_size", type=int, default=128, help="Page size of page table", ) parser.add_argument( "--split_kv", type=int, default=-1, help="Split KV setting", ) parser.add_argument( "--softmax_scale", type=float, default=1.0, help="Scaling factor to scale softmax", ) parser.add_argument( "--output_scale", type=float, default=1.0, help="Scaling factor to scale output", ) parser.add_argument( "--tolerance", type=float, default=1e-02, help="Tolerance for validation" ) parser.add_argument( "--warmup_iterations", type=int, default=0, help="Number of iterations for warmup", ) parser.add_argument( "--iterations", type=int, default=1, help="Number of iterations after warmup", ) parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference check", ) parser.add_argument( "--use_cold_l2", action="store_true", help="Use cold L2 cache", ) args = parser.parse_args() run( args.batch_size, args.seq_len, args.num_heads, args.latent_dim, args.rope_dim, args.in_dtype, args.out_dtype, args.acc_dtype, args.lse_dtype, args.mma_qk_tiler_mn, args.mma_pv_tiler_mn, args.split_kv, args.is_persistent, args.is_cpasync, args.is_var_seq, args.is_var_split_kv, args.use_page_table, args.page_size, args.softmax_scale, args.output_scale, args.tolerance, args.warmup_iterations, args.iterations, args.skip_ref_check, args.use_cold_l2, ) print("PASS")