# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import math
from typing import Type, Tuple, Optional, Callable
from types import SimpleNamespace
from functools import partial

import torch
import torch.nn.functional as F
import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.cute.nvgpu.tcgen05 as tcgen05
import cutlass.cute.nvgpu.cpasync as cpasync
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack

"""
|
|
A Multi-Head Latent Attention (MLA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL
|
|
|
|
This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell
|
|
SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T
|
|
matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel.
|
|
The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting
|
|
functionality to minimize latency when processing long KV sequences.
|
|
|
|
The kernel implements key optimizations including:
|
|
- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue)
|
|
- Pipeline stages between different warps for overlapping computation and memory access
|
|
- Support for different precision data types
|
|
- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/blackwell/mla.py \
|
|
--batch_size 4 --latent_dim 512 --rope_dim 64 \
|
|
--num_heads 128 --seq_len 1024 \
|
|
--in_dtype Float8E4M3FN --out_dtype Float16 \
|
|
--acc_dtype Float32 --lse_dtype Float32 \
|
|
--use_page_table --is_var_seq --is_var_split_kv \
|
|
--is_persistent
|
|
|
|
The above example runs Multi-Head Latent Attention (MLA) with the following configuration:
|
|
- Batch size: 4
|
|
- Sequence length: 1024
|
|
- Latent dimension: 512
|
|
- RoPE dimension: 64
|
|
- Number of heads: 128
|
|
- Data types: Float8E4M3FN (input), Float16 (output), Float32 (accumulation and LSE)
|
|
|
|
It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences
|
|
and variable split KV processing with persistent scheduling.
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/blackwell/mla.py \
|
|
--batch_size 4 --latent_dim 512 --rope_dim 64 \
|
|
--num_heads 128 --seq_len 1024 \
|
|
--in_dtype Float8E4M3FN --out_dtype Float16 \
|
|
--acc_dtype Float32 --lse_dtype Float32 \
|
|
--use_page_table --is_var_seq --is_var_split_kv \
|
|
--is_persistent --warmup_iterations 3 \
|
|
--iterations 10 --skip_ref_check
|
|
|
|
Constraints for this example:
|
|
* Data type requirements:
|
|
- Input/output: Float8E4M3FN or Float16
|
|
- Accumulation and LSE: Float32
|
|
* Fixed architecture parameters:
|
|
- Number of attention heads: 128
|
|
- Latent dimension: 512
|
|
- RoPE dimension: 64
|
|
* Input query modes should be (NumHeads, LatentDim/RopeDim, BatchSize)
|
|
* Input kv latent/rope modes should be (SeqLen, LatentDim/RopeDim, BatchSize)
|
|
* Query sequence length must be 1
|
|
* Only supports 2-CTA instructions
|
|
* Variable sequence length requires page table storage enabled
|
|
"""
|
|
|
|
|
|
class MLAStaticTileSchedulerParams:
    def __init__(
        self,
        is_persistent: bool,
        problem_shape_b: cute.Int32,
        cluster_shape_mnk: cute.Shape,
        split_kv: cutlass.Int32,
        *,
        loc=None,
        ip=None,
    ):
        """The static tile scheduler parameters prepared for the MLA static tile scheduler.

        :param is_persistent: Whether to use persistent kernel mode
        :type is_persistent: bool
        :param problem_shape_b: The batch dimension of the problem
        :type problem_shape_b: cute.Int32
        :param cluster_shape_mnk: The shape of the cluster
        :type cluster_shape_mnk: cute.Shape
        :param split_kv: The scalar factor for split KV
        :type split_kv: cutlass.Int32
        """
        self.is_persistent = is_persistent
        self.problem_shape_b = problem_shape_b
        self.cluster_shape_mnk = cluster_shape_mnk
        self.split_kv = split_kv
        self.loc = loc
        self.ip = ip

    def __extract_mlir_values__(self):
        values = cutlass.extract_mlir_values(self.problem_shape_b)
        values += cutlass.extract_mlir_values(self.split_kv)
        return values

    def __new_from_mlir_values__(self, values):
        problem_shape_b = cutlass.new_from_mlir_values(
            self.problem_shape_b, (values[0],)
        )
        split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[1],))
        return MLAStaticTileSchedulerParams(
            self.is_persistent,
            problem_shape_b,
            self.cluster_shape_mnk,
            split_kv,
            loc=self.loc,
        )


def create_mla_static_tile_scheduler_params(
    is_persistent: bool,
    problem_shape_b: cute.Int32,
    cluster_shape_mnk: cute.Shape,
    split_kv: cutlass.Int32,
) -> MLAStaticTileSchedulerParams:
    return MLAStaticTileSchedulerParams(
        is_persistent, problem_shape_b, cluster_shape_mnk, split_kv
    )


class MLAStaticTileScheduler:
    def __init__(
        self,
        params: MLAStaticTileSchedulerParams,
        current_work_linear_idx: cutlass.Int32,
        blk_coord: cute.Coord,
        grid_shape: cute.Shape,
        *,
        is_valid: bool = True,
        loc=None,
        ip=None,
    ):
        """The static tile scheduler for the MLA split KV kernel.
        Based on `is_persistent`, it provides two modes of use:
        - Persistent mode: launch a fixed number of blocks and reschedule work tiles onto them.
        - Non-persistent mode: launch one block per work tile and exit when the current work is done.

        :param params: The static tile scheduler parameters
        :type params: MLAStaticTileSchedulerParams
        :param current_work_linear_idx: The linear index of the current work
        :type current_work_linear_idx: cutlass.Int32
        :param blk_coord: The coordinate of the current work
        :type blk_coord: cute.Coord
        :param grid_shape: The shape of the grid
        :type grid_shape: cute.Shape
        :param is_valid: Whether the current work is valid
        :type is_valid: bool
        """
        self.params = params
        self.blk_coord = blk_coord
        self.grid_shape = grid_shape
        self.current_work_linear_idx = current_work_linear_idx
        if params.is_persistent:
            self.persistent_blk_layout = cute.make_layout(
                (
                    params.cluster_shape_mnk[0],
                    1,
                    params.problem_shape_b,
                    params.split_kv,
                ),
                loc=loc,
                ip=ip,
            )
            self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip)
            # Used for persistent scheduling
            self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)
        else:
            self.is_valid = is_valid
        self.loc = loc
        self.ip = ip

    @staticmethod
    def get_grid_shape(
        params: MLAStaticTileSchedulerParams,
        max_active_clusters: int,
        *,
        loc=None,
        ip=None,
    ) -> cute.Shape:
        # called by host
        grid_shape = (
            params.cluster_shape_mnk[0],
            params.problem_shape_b,
            params.split_kv,
        )
        if params.is_persistent:
            return (
                cutlass.min(
                    max_active_clusters * cute.size(params.cluster_shape_mnk),
                    cute.size(grid_shape, loc=loc, ip=ip),
                ),
                1,
                1,
            )
        else:
            return grid_shape

    def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo:
        is_valid = (
            self.current_work_linear_idx < self.num_blocks
            if self.params.is_persistent
            else self.is_valid
        )

        if self.params.is_persistent:
            blk_coord = self.persistent_blk_layout.get_hier_coord(
                self.current_work_linear_idx, loc=loc, ip=ip
            )
        else:
            blk_coord = (self.blk_coord[0], 0, self.blk_coord[1], self.blk_coord[2])

        return utils.WorkTileInfo(blk_coord, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
        if self.params.is_persistent:
            self.current_work_linear_idx += advance_count * self.num_persistent_sm
        else:
            self.is_valid = False

    def __extract_mlir_values__(self):
        values = cutlass.extract_mlir_values(self.params)
        values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx))
        values.extend(cutlass.extract_mlir_values(self.blk_coord))
        values.extend(cutlass.extract_mlir_values(self.grid_shape))
        return values

    def __new_from_mlir_values__(self, values):
        assert len(values) == 9
        new_params = cutlass.new_from_mlir_values(self.params, values[0:2])
        new_current_work_linear_idx = cutlass.new_from_mlir_values(
            self.current_work_linear_idx, [values[2]]
        )
        new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[3:6])
        new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[6:])
        return MLAStaticTileScheduler(
            new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape
        )


def create_mla_static_tile_scheduler(
    params: MLAStaticTileSchedulerParams,
    blk_coord: cute.Coord,
    grid_shape: cute.Shape,
) -> MLAStaticTileScheduler:
    return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape)
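

# A worked illustration of the two scheduling modes (hypothetical numbers, not
# defaults): with cluster_shape_mnk = (2, 1, 1), batch size B = 4 and
# split_kv = 8, the persistent work layout is (2, 1, 4, 8), i.e. 64 work tiles.
# Non-persistent mode launches the grid as (2, 4, 8), one CTA per tile.
# Persistent mode launches min(max_active_clusters * 2, 64) CTAs, and each CTA
# walks the 64 linear indices with a stride equal to the grid size via
# advance_to_next_work().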


LOG2_E = 1.4426950408889634074
# Cap on the number of KV splits; keeps the per-split arrays statically sized
# and avoids register indexing on the array.
MAX_SPLITS = 256
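
# The device code evaluates softmax with base-2 exponentials (exp2/log2 map
# directly to fast hardware instructions), so the natural-base scale is folded
# in once on the host as softmax_scale_log2 = softmax_scale * LOG2_E, giving
# exp(s * softmax_scale) == exp2(s * softmax_scale_log2). For example, for a
# hypothetical scale of 1 / sqrt(512 + 64) = 1 / 24 ~= 0.041667, the folded
# base-2 scale is ~= 0.060112.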


class BlackwellMultiHeadLatentAttentionForward:
    def __init__(
        self,
        acc_dtype: Type[cutlass.Numeric],
        lse_dtype: Type[cutlass.Numeric],
        mma_qk_tiler_mn: Tuple[int, int],
        mma_pv_tiler_mn: Tuple[int, int],
        max_active_clusters: int,
        is_persistent: bool,
        is_cpasync: bool,
        use_page_table: bool,
        is_var_seq: bool,
        is_var_split_kv: bool,
    ):
        """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel.

        :param acc_dtype: Data type for accumulation S and O
        :type acc_dtype: Type[cutlass.Numeric]
        :param lse_dtype: Data type for output LSE
        :type lse_dtype: Type[cutlass.Numeric]
        :param mma_qk_tiler_mn: The (H, K) tile shape of the MMA instruction for S
        :type mma_qk_tiler_mn: Tuple[int, int]
        :param mma_pv_tiler_mn: The (H, D) tile shape of the MMA instruction for P*V
        :type mma_pv_tiler_mn: Tuple[int, int]
        :param max_active_clusters: Maximum number of active clusters
        :type max_active_clusters: int
        :param is_persistent: Whether to use persistent kernel mode
        :type is_persistent: bool
        :param is_cpasync: Whether to use CP async mode
        :type is_cpasync: bool
        :param use_page_table: Whether to use page table
        :type use_page_table: bool
        :param is_var_seq: Whether to use variable sequence length
        :type is_var_seq: bool
        :param is_var_split_kv: Whether to use variable split KV
        :type is_var_split_kv: bool
        """

        self.latent_dim = 512
        self.rope_dim = 64
        self.acc_dtype = acc_dtype
        self.lse_dtype = lse_dtype
        self.mma_qk_tiler_mn = mma_qk_tiler_mn
        self.mma_pv_tiler_mn = mma_pv_tiler_mn
        self.max_active_clusters = max_active_clusters
        self.is_persistent = is_persistent
        self.is_cpasync = is_cpasync
        self.use_page_table = use_page_table
        self.is_var_seq = is_var_seq
        self.is_var_split_kv = is_var_split_kv
        self.cluster_shape_mnk = (2, 1, 1)
        self.use_2cta_instrs = True
        # When using 2 CTAs with m=128: warps 0-1 handle accumulation for the first half [0, n/2),
        # while warps 2-3 handle accumulation for the second half [n/2, n)
        self.warps_in_n = 2
        self.num_compute_warps = 4
        self.threads_per_warp = 32
        self.num_load_warps = 2 if self.is_cpasync else 1
        mma_qk_tiler_k = self.rope_dim if self.is_cpasync else self.rope_dim * 2
        self.mma_qk_tiler = (
            self.mma_qk_tiler_mn[0],
            self.mma_qk_tiler_mn[1],
            mma_qk_tiler_k,
        )
        self.mma_pv_tiler = (
            self.mma_pv_tiler_mn[0],
            self.mma_pv_tiler_mn[1],
            self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1],
        )
        self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2]
        self.iterations_qk_rope = mma_qk_tiler_k // self.mma_qk_tiler[2]
        self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope
        self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2]
        self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1]
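        # Illustrative numbers (assuming the TMA, i.e. non-cp.async, path where
        # mma_qk_tiler_k = rope_dim * 2 = 128) for mma_qk_tiler_mn = (128, 128)
        # and mma_pv_tiler_mn = (128, 256):
        #   mma_qk_tiler = (128, 128, 128)
        #     -> iterations_qk_latent = 512 // 128 = 4,
        #        iterations_qk_rope = 128 // 128 = 1, iterations_qk = 5
        #   mma_pv_tiler = (128, 256, 128 * 128 // 256) = (128, 256, 64)
        #     -> iterations_pv_k = 128 // 64 = 2, iterations_pv_n = 512 // 256 = 2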

        # Set specialized warp ids
        self.compute_warp_ids = (0, 1, 2, 3)
        self.correction_warp_ids = (4, 5, 6, 7)
        self.mma_warp_id = 8
        if self.is_cpasync:
            self.load_cp_async_warp_ids = (9, 10)
            self.load_pt_warp_id = 11
            self.threads_per_cta = self.threads_per_warp * len(
                (
                    self.mma_warp_id,
                    *self.load_cp_async_warp_ids,
                    self.load_pt_warp_id,
                    *self.compute_warp_ids,
                    *self.correction_warp_ids,
                )
            )
        else:
            self.load_tma_warp_id = 9
            self.empty_warp_ids = (10, 11)
            self.threads_per_cta = self.threads_per_warp * len(
                (
                    self.mma_warp_id,
                    self.load_tma_warp_id,
                    *self.compute_warp_ids,
                    *self.correction_warp_ids,
                    *self.empty_warp_ids,
                )
            )

        # register settings
        self.softmax_reg_num = 192
        self.correction_reg_num = 192
        self.other_reg_num = 112
        # Named barriers
        self.tmem_ptr_sync_bar = pipeline.NamedBarrier(
            barrier_id=1,
            num_threads=(
                self.threads_per_warp
                + self.threads_per_warp * self.num_compute_warps * 2
            ),
        )
        self.softmax_exchange_sync_bar = pipeline.NamedBarrier(
            barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps)
        )
        self.epilogue_exchange_sync_bar = pipeline.NamedBarrier(
            barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps)
        )

    def _setup_attributes(self):
        """Set up configurations and parameters for the MLA kernel operation.

        This method initializes and configures various attributes required for the
        execution of the multi-head latent attention kernel, primarily the pipeline stages:

        - Sets up staging parameters for Q, K, V inputs and accumulator data
        - Configures pipeline stages for softmax, correction, and epilogue operations
        """

        self.load_q_stage = self.iterations_qk
        self.load_kv_stage = (24 if self.is_cpasync else 12) // (
            self.k_dtype.width // 8
        )
        self.mma_s_stage = 2
        self.p_mma_stage = 2
        self.p_cor_stage = 2
        self.mma_o_stage = 1
        self.load_pt_stage = self.load_kv_stage if self.is_cpasync else 1

        self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n
        self.correction_factor_offset = (
            self.tmem_o_offset + self.latent_dim // self.warps_in_n
        )
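        # A sketch of the implied tensor-memory column layout (my reading of the
        # offsets above, with assumed mma_qk_tiler = (128, 128, 128) and
        # warps_in_n = 2): the S stages occupy columns [0, 2 * 128 // 2) = [0, 128),
        # the O accumulator starts at tmem_o_offset = 128, and the correction
        # factors at correction_factor_offset = 128 + 512 // 2 = 384.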

    @cute.jit
    def __call__(
        self,
        q_latent: cute.Tensor,
        q_rope: cute.Tensor,
        c_latent: cute.Tensor,
        c_rope: cute.Tensor,
        page_table: cute.Tensor,
        o: cute.Tensor,
        lse: cute.Tensor,
        workspace: cute.Tensor,
        split_kv: cutlass.Int32,
        cache_seqs: Optional[cute.Tensor],
        block_split_kvs: Optional[cute.Tensor],
        softmax_scale: cutlass.Float32,
        output_scale: cutlass.Float32,
        stream: cuda.CUstream,
    ):
        """Execute the Multi-Head Latent Attention operation on the provided tensors.

        The method handles:
        1. Initialization of workspace for temporary split KV buffers
        2. Validation of tensor data types
        3. Initialization of hardware-specific parameters and memory layouts
        4. Configuration of TMA (Tensor Memory Access) operations
        5. Grid and work scheduling computation
        6. Kernel launch (split KV kernel and reduction kernel) with appropriate parameters

        :param q_latent: The query tensor with shape [num_head, latent_dim, batch_size]
        :type q_latent: cute.Tensor
        :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, batch_size]
        :type q_rope: cute.Tensor
        :param c_latent: The key tensor with shape [seq_len, latent_dim, batch_size]
        :type c_latent: cute.Tensor
        :param c_rope: The key RoPE tensor with shape [seq_len, rope_dim, batch_size]
        :type c_rope: cute.Tensor
        :param page_table: The page table tensor with shape [page_count, batch_size]
        :type page_table: cute.Tensor
        :param o: The output tensor with shape [num_head, latent_dim, batch_size]
        :type o: cute.Tensor
        :param lse: The LSE tensor with shape [num_head, batch_size]
        :type lse: cute.Tensor
        :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse
        :type workspace: cute.Tensor
        :param split_kv: The scalar factor for split KV
        :type split_kv: cutlass.Int32
        :param cache_seqs: The cache sequences tensor with shape [batch_size]
        :type cache_seqs: cute.Tensor
        :param block_split_kvs: The block split KV tensor with shape [batch_size]
        :type block_split_kvs: cute.Tensor
        :param softmax_scale: The scale factor for softmax
        :type softmax_scale: cutlass.Float32
        :param output_scale: The scale factor for the output
        :type output_scale: cutlass.Float32
        :param stream: The CUDA stream to execute the kernel on
        :type stream: cuda.CUstream

        :raises TypeError: If tensor data types don't match or aren't supported
        """

        # setup static attributes before smem/grid/tma computation
        self.q_dtype = q_latent.element_type
        self.k_dtype = c_latent.element_type
        self.v_dtype = c_latent.element_type
        self.o_dtype = o.element_type

        # check type consistency
        if cutlass.const_expr(
            self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype
        ):
            raise TypeError(
                f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}"
            )
        # check leading dimensions of input/output
        if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1):
            raise ValueError("q_latent and q_rope must have leading dimension 1")
        if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1):
            raise ValueError("c_latent and c_rope must have leading dimension 1")
        if cutlass.const_expr(o.stride[1] != 1):
            raise ValueError("o must have leading dimension 1")
        if cutlass.const_expr(lse.stride[0] != 1):
            raise ValueError("lse must have leading dimension 0")

        acc_o, acc_lse = self.initialize_workspace(
            q_latent.shape[0],
            q_latent.shape[1],
            q_latent.shape[2],
            split_kv,
            self.acc_dtype,
            workspace,
        )

        c_latent_transpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2])
        c_latent_transpose = cute.make_tensor(
            c_latent.iterator, c_latent_transpose_layout
        )

        self.q_major_mode = tcgen05.OperandMajorMode.K
        self.k_major_mode = tcgen05.OperandMajorMode.K
        self.v_major_mode = tcgen05.OperandMajorMode.MN

        self._setup_attributes()

        cta_group = tcgen05.CtaGroup.TWO
        # the intermediate tensor p is from smem & k-major
        p_major_mode = tcgen05.OperandMajorMode.K
        qk_tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_major_mode,
            self.k_major_mode,
            self.acc_dtype,
            cta_group,
            self.mma_qk_tiler[:2],
        )
        pv_tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.v_dtype,
            p_major_mode,
            self.v_major_mode,
            self.acc_dtype,
            cta_group,
            self.mma_pv_tiler[:2],
        )

        cta_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk),
            (qk_tiled_mma.thr_id.shape,),
        )

        self.epi_tile = self.mma_pv_tiler[:2]

        q_smem_layout_staged = sm100_utils.make_smem_layout_a(
            qk_tiled_mma,
            self.mma_qk_tiler,
            self.q_dtype,
            self.load_q_stage,
        )
        kc_smem_layout_staged = sm100_utils.make_smem_layout_b(
            qk_tiled_mma,
            self.mma_qk_tiler,
            self.k_dtype,
            self.load_kv_stage,
        )
        p_smem_layout_staged = sm100_utils.make_smem_layout_a(
            pv_tiled_mma,
            self.mma_pv_tiler,
            self.q_dtype,
            (self.iterations_pv_k * self.p_mma_stage),
        )
        p_smem_layout_staged = cute.logical_divide(
            p_smem_layout_staged, (None, None, None, self.iterations_pv_k)
        )
        vc_smem_layout_staged = sm100_utils.make_smem_layout_b(
            pv_tiled_mma,
            self.mma_pv_tiler,
            self.v_dtype,
            self.load_kv_stage,
        )
        if cutlass.const_expr(not self.is_cpasync):
            # TMA load for Q latent and rope
            tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group)

            q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2])
            tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A(
                tma_load_op,
                q_latent,
                q_smem_layout,
                self.mma_qk_tiler,
                qk_tiled_mma,
                cta_layout_vmnk.shape,
            )
            tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A(
                tma_load_op,
                q_rope,
                q_smem_layout,
                self.mma_qk_tiler,
                qk_tiled_mma,
                cta_layout_vmnk.shape,
            )
            # TMA load for c latent and k rope
            kc_smem_layout = cute.select(kc_smem_layout_staged, mode=[0, 1, 2])
            tma_atom_c_latent, tma_tensor_c_latent = cute.nvgpu.make_tiled_tma_atom_B(
                tma_load_op,
                c_latent,
                kc_smem_layout,
                self.mma_qk_tiler,
                qk_tiled_mma,
                cta_layout_vmnk.shape,
            )
            tma_atom_c_rope, tma_tensor_c_rope = cute.nvgpu.make_tiled_tma_atom_B(
                tma_load_op,
                c_rope,
                kc_smem_layout,
                self.mma_qk_tiler,
                qk_tiled_mma,
                cta_layout_vmnk.shape,
            )
            # TMA load for c latent transpose
            vc_smem_layout = cute.select(vc_smem_layout_staged, mode=[0, 1, 2])
            tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = (
                cute.nvgpu.make_tiled_tma_atom_B(
                    tma_load_op,
                    c_latent_transpose,
                    vc_smem_layout,
                    self.mma_pv_tiler,
                    pv_tiled_mma,
                    cta_layout_vmnk.shape,
                )
            )

            q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) * cute.size(
                qk_tiled_mma.thr_id.shape
            )
            kc_copy_size = cute.size_in_bytes(self.k_dtype, kc_smem_layout) * cute.size(
                qk_tiled_mma.thr_id.shape
            )
            vc_copy_size = cute.size_in_bytes(self.v_dtype, vc_smem_layout) * cute.size(
                pv_tiled_mma.thr_id.shape
            )
            assert kc_copy_size == vc_copy_size, (
                "kc_copy_size and vc_copy_size must be the same"
            )

            self.tma_copy_q_bytes = q_copy_size
            self.tma_copy_kc_bytes = kc_copy_size
        else:
            self.tma_copy_q_bytes = 0
            self.tma_copy_kc_bytes = 0

        tile_sched_params, grid = self._compute_grid(
            o,
            split_kv,
            self.cluster_shape_mnk,
            self.max_active_clusters,
            self.is_persistent,
        )

        @cute.struct
        class SplitKVKernelSharedStorage:
            # Pipeline barriers
            load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2]
            load_kv_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.load_kv_stage * 2
            ]
            mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2]
            p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2]
            p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2]
            mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2]
            load_pt_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.load_pt_stage * 2
            ]

            # Smem tensors
            softmax_smem_exchange: cute.struct.MemRange[
                self.acc_dtype, self.num_compute_warps * self.threads_per_warp
            ]
            epilogue_smem_exchange: cute.struct.MemRange[
                self.acc_dtype, self.num_compute_warps * self.threads_per_warp
            ]

            smem_page_table: cute.struct.MemRange[
                cutlass.Int32, self.load_pt_stage * self.mma_qk_tiler[1]
            ]
            smem_q: cute.struct.Align[
                cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)],
                1024,
            ]
            smem_kc: cute.struct.Align[
                cute.struct.MemRange[self.k_dtype, cute.cosize(kc_smem_layout_staged)],
                1024,
            ]
            smem_p: cute.struct.Align[
                cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)],
                1024,
            ]
            # Tmem dealloc cluster barrier
            tmem_dealloc_mbar_ptr: cutlass.Int64

            # Tmem holding buffer
            tmem_holding_buf: cutlass.Int32

        softmax_scale_log2 = softmax_scale * LOG2_E
        # Launch the kernel synchronously
        if cutlass.const_expr(self.is_cpasync):
            self.split_kv_kernel(
                qk_tiled_mma,
                pv_tiled_mma,
                None,
                q_latent,
                None,
                q_rope,
                None,
                c_latent,
                None,
                c_rope,
                None,
                c_latent_transpose,
                page_table,
                o,
                lse,
                acc_o,
                acc_lse,
                split_kv,
                cache_seqs,
                block_split_kvs,
                softmax_scale_log2,
                output_scale,
                q_smem_layout_staged,
                kc_smem_layout_staged,
                p_smem_layout_staged,
                vc_smem_layout_staged,
                cta_layout_vmnk,
                tile_sched_params,
                SplitKVKernelSharedStorage,
            ).launch(
                grid=grid,
                block=[self.threads_per_cta, 1, 1],
                cluster=self.cluster_shape_mnk,
                smem=SplitKVKernelSharedStorage.size_in_bytes(),
                stream=stream,
                min_blocks_per_mp=1,
            )
        else:
            self.split_kv_kernel(
                qk_tiled_mma,
                pv_tiled_mma,
                tma_atom_q_latent,
                tma_tensor_q_latent,
                tma_atom_q_rope,
                tma_tensor_q_rope,
                tma_atom_c_latent,
                tma_tensor_c_latent,
                tma_atom_c_rope,
                tma_tensor_c_rope,
                tma_atom_c_latent_transpose,
                tma_tensor_c_latent_transpose,
                page_table,
                o,
                lse,
                acc_o,
                acc_lse,
                split_kv,
                cache_seqs,
                block_split_kvs,
                softmax_scale_log2,
                output_scale,
                q_smem_layout_staged,
                kc_smem_layout_staged,
                p_smem_layout_staged,
                vc_smem_layout_staged,
                cta_layout_vmnk,
                tile_sched_params,
                SplitKVKernelSharedStorage,
            ).launch(
                grid=grid,
                block=[self.threads_per_cta, 1, 1],
                cluster=self.cluster_shape_mnk,
                smem=SplitKVKernelSharedStorage.size_in_bytes(),
                stream=stream,
                min_blocks_per_mp=1,
            )
        if cutlass.const_expr(acc_o is not None):
            self.reduction_kernel(
                o,
                lse,
                acc_o,
                acc_lse,
                split_kv,
                cache_seqs,
                block_split_kvs,
            ).launch(
                grid=(q_latent.shape[0], 1, q_latent.shape[2]),
                block=[self.threads_per_warp * self.num_compute_warps, 1, 1],
                smem=MAX_SPLITS * self.acc_dtype.width // 8,
                stream=stream,
                min_blocks_per_mp=1,
            )

    @cute.kernel
    def split_kv_kernel(
        self,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        tma_atom_q_latent: Optional[cute.CopyAtom],
        mQL: cute.Tensor,
        tma_atom_q_rope: Optional[cute.CopyAtom],
        mQR: cute.Tensor,
        tma_atom_c_latent: Optional[cute.CopyAtom],
        mCL: cute.Tensor,
        tma_atom_c_rope: Optional[cute.CopyAtom],
        mKR: cute.Tensor,
        tma_atom_c_latent_transpose: Optional[cute.CopyAtom],
        mCLT: cute.Tensor,
        mPT: cute.Tensor,
        mO: Optional[cute.Tensor],
        mLSE: Optional[cute.Tensor],
        mAccO: Optional[cute.Tensor],
        mAccLSE: Optional[cute.Tensor],
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
        softmax_scale_log2: cutlass.Float32,
        output_scale: cutlass.Float32,
        q_smem_layout_staged: cute.ComposedLayout,
        kc_smem_layout_staged: cute.ComposedLayout,
        p_smem_layout_staged: cute.ComposedLayout,
        vc_smem_layout_staged: cute.ComposedLayout,
        cta_layout_vmnk: cute.Layout,
        tile_sched_params: MLAStaticTileSchedulerParams,
        SharedStorage: cutlass.Constexpr,
    ):
        """The device split_kv kernel implementation of the Multi-Head Latent Attention.

        This kernel coordinates multiple specialized warps to perform different phases of the MLA computation:
        1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA
        2. MMA warp: Performs matrix multiplications (Q*K^T and P*V)
        3. Compute warps: Compute softmax, rescale the accumulators, and store the intermediate/final results
           to global memory

        The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter.
        When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results
        that will later be combined by a reduction kernel.

        The kernel implements a complex pipeline with overlapping computation and memory operations,
        using tensor memory access (TMA) for efficient data loading and warp specialization for different
        computation phases.

        :param tiled_mma_qk: Tiled MMA for Q*K^T
        :type tiled_mma_qk: cute.TiledMma
        :param tiled_mma_pv: Tiled MMA for P*V
        :type tiled_mma_pv: cute.TiledMma
        :param tma_atom_q_latent: TMA copy atom for query latent tensor
        :type tma_atom_q_latent: cute.CopyAtom
        :param mQL: Query latent tensor
        :type mQL: cute.Tensor
        :param tma_atom_q_rope: TMA copy atom for query rope tensor
        :type tma_atom_q_rope: cute.CopyAtom
        :param mQR: Query rope tensor
        :type mQR: cute.Tensor
        :param tma_atom_c_latent: TMA copy atom for compressed latent tensor
        :type tma_atom_c_latent: cute.CopyAtom
        :param mCL: Compressed latent tensor
        :type mCL: cute.Tensor
        :param tma_atom_c_rope: TMA copy atom for compressed rope tensor
        :type tma_atom_c_rope: cute.CopyAtom
        :param mKR: Compressed rope tensor
        :type mKR: cute.Tensor
        :param tma_atom_c_latent_transpose: TMA copy atom for compressed latent transpose tensor
        :type tma_atom_c_latent_transpose: cute.CopyAtom
        :param mCLT: Compressed latent transpose tensor
        :type mCLT: cute.Tensor
        :param mPT: Page table tensor
        :type mPT: cute.Tensor
        :param mO: Output tensor
        :type mO: cute.Tensor
        :param mLSE: Log-sum-exp tensor
        :type mLSE: cute.Tensor
        :param mAccO: Intermediate accumulator output tensor
        :type mAccO: cute.Tensor
        :param mAccLSE: Intermediate accumulator log-sum-exp tensor
        :type mAccLSE: cute.Tensor
        :param split_kv: The split_kv parameter
        :type split_kv: cutlass.Int32
        :param cache_seqs: The variable sequence length tensor
        :type cache_seqs: cute.Tensor
        :param block_split_kvs: The per-block split_kv values tensor
        :type block_split_kvs: cute.Tensor
        :param softmax_scale_log2: The log2 scale factor for softmax
        :type softmax_scale_log2: cutlass.Float32
        :param output_scale: The scale factor for the output
        :type output_scale: cutlass.Float32
        :param q_smem_layout_staged: Shared memory layout for query tensor
        :type q_smem_layout_staged: cute.ComposedLayout
        :param kc_smem_layout_staged: Shared memory layout for key tensor
        :type kc_smem_layout_staged: cute.ComposedLayout
        :param p_smem_layout_staged: Shared memory layout for probability matrix
        :type p_smem_layout_staged: cute.ComposedLayout
        :param vc_smem_layout_staged: Shared memory layout for value tensor
        :type vc_smem_layout_staged: cute.ComposedLayout
        :param cta_layout_vmnk: Layout for compute threads
        :type cta_layout_vmnk: cute.Layout
        :param tile_sched_params: Scheduling parameters for work distribution
        :type tile_sched_params: MLAStaticTileSchedulerParams
        :param SharedStorage: Shared storage for the kernel
        :type SharedStorage: cutlass.Constexpr
        """

        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape)
        is_leader_cta = mma_tile_coord_v == 0

        # Coords inside cluster
        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )

        # Prefetch tma descriptor
        if cutlass.const_expr(not self.is_cpasync):
            if warp_idx == self.mma_warp_id:
                cpasync.prefetch_descriptor(tma_atom_q_latent)
                cpasync.prefetch_descriptor(tma_atom_q_rope)
                cpasync.prefetch_descriptor(tma_atom_c_latent)
                cpasync.prefetch_descriptor(tma_atom_c_rope)
                cpasync.prefetch_descriptor(tma_atom_c_latent_transpose)

        # Alloc
        smem = utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)

        tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
        tmem_holding_buf = storage.tmem_holding_buf

        # Tensor memory dealloc barrier init
        if warp_idx == self.mma_warp_id:
            num_tmem_dealloc_threads = self.threads_per_warp * self.num_compute_warps
            with cute.arch.elect_one():
                cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
        cute.arch.mbarrier_init_fence()

        load_q_pipeline = self.make_and_init_load_qkv_pipeline(
            storage.load_q_mbar_ptr.data_ptr(),
            cta_layout_vmnk,
            self.load_q_stage,
            self.tma_copy_q_bytes,
            self.is_cpasync,
        )
        load_kv_pipeline = self.make_and_init_load_qkv_pipeline(
            storage.load_kv_mbar_ptr.data_ptr(),
            cta_layout_vmnk,
            self.load_kv_stage,
            self.tma_copy_kc_bytes,
            self.is_cpasync,
        )
        mma_s_pipeline = self.make_and_init_mma_s_pipeline(
            storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk
        )
        p_mma_pipeline = self.make_and_init_p_mma_pipeline(
            storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk
        )
        p_cor_pipeline = self.make_and_init_p_cor_pipeline(
            storage.p_cor_mbar_ptr.data_ptr()
        )
        mma_o_pipeline = self.make_and_init_mma_o_pipeline(
            storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk
        )
        if cutlass.const_expr(self.is_cpasync):
            load_pt_pipeline = self.make_and_init_load_pt_pipeline(
                storage.load_pt_mbar_ptr.data_ptr()
            )

        # Cluster arrive after barrier init
        if cutlass.const_expr(cute.size(self.cluster_shape_mnk) > 1):
            cute.arch.cluster_arrive_relaxed()

        # Generate smem tensor Q/KC/VC/exchange
        # (MMA, MMA_H, MMA_R, PIPE)
        sQ = storage.smem_q.get_tensor(
            q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner
        )
        # (MMA, MMA_K, MMA_R, PIPE)
        sKC = storage.smem_kc.get_tensor(
            kc_smem_layout_staged.outer, swizzle=kc_smem_layout_staged.inner
        )
        # (MMA, MMA_D, MMA_K, PIPE)
        # reuse smem
        sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner)
        sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer)
        # (MMA, MMA_H, MMA_K)
        sP = storage.smem_p.get_tensor(
            p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner
        )
        # (compute_threads,)
        softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor(
            cute.make_layout(self.num_compute_warps * self.threads_per_warp)
        )
        epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor(
            cute.make_layout(self.num_compute_warps * self.threads_per_warp)
        )

        #
        # Cluster wait before tensor memory alloc
        #
        if cutlass.const_expr(cute.size(self.cluster_shape_mnk) > 1):
            cute.arch.cluster_wait()
        else:
            pipeline.sync(barrier_id=4)

        # ///////////////////////////////////////////////////////////////////////////////
        # Load warps, including page table and data tensors
        # ///////////////////////////////////////////////////////////////////////////////
        if cutlass.const_expr(self.is_cpasync):
            sPT = storage.smem_page_table.get_tensor(
                cute.make_layout((self.mma_qk_tiler[1], self.load_pt_stage))
            )
            # Load the page table when is_cpasync is true
            if warp_idx == self.load_pt_warp_id:
                cute.arch.warpgroup_reg_dealloc(self.other_reg_num)
                load_pt_producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.load_pt_stage
                )
                tile_sched = create_mla_static_tile_scheduler(
                    tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
                )
                work_tile = tile_sched.initial_work_tile_info()
                while work_tile.is_valid_tile:
                    blk_coord = work_tile.tile_idx
                    k_index, k_tile_count, local_split_kv = self.get_k_tile_count(
                        split_kv,
                        cache_seqs,
                        block_split_kvs,
                        blk_coord,
                    )
                    if k_tile_count > 0:
                        load_pt_common_params = SimpleNamespace(
                            blk_coord=blk_coord,
                            load_pt_pipeline=load_pt_pipeline,
                            mPT=mPT,
                            sPT=sPT,
                            tidx=tidx,
                            page_size=mCL.shape[0],
                        )
                        load_pt_producer_state = self.load_page_table(
                            load_pt_common_params,
                            k_index,
                            k_tile_count,
                            load_pt_producer_state,
                        )
                    tile_sched.advance_to_next_work()
                    work_tile = tile_sched.get_current_work()
                load_pt_pipeline.producer_tail(load_pt_producer_state)

            if (
                warp_idx == self.load_cp_async_warp_ids[0]
                or warp_idx == self.load_cp_async_warp_ids[1]
            ):
                cute.arch.warpgroup_reg_dealloc(self.other_reg_num)
                load_pt_consumer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Consumer, self.load_pt_stage
                )
                load_pt_release_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Consumer, self.load_pt_stage
                )
                load_q_producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.load_q_stage
                )
                load_kv_producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.load_kv_stage
                )
                load_kv_commit_state = load_kv_producer_state.clone()
                tile_sched = create_mla_static_tile_scheduler(
                    tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
                )
                work_tile = tile_sched.initial_work_tile_info()
                while work_tile.is_valid_tile:
                    blk_coord = work_tile.tile_idx
                    k_index, k_tile_count, local_split_kv = self.get_k_tile_count(
                        split_kv,
                        cache_seqs,
                        block_split_kvs,
                        blk_coord,
                    )
                    if k_tile_count > 0:
                        load_cpasync_common_params = SimpleNamespace(
                            blk_coord=blk_coord,
                            load_pt_pipeline=load_pt_pipeline,
                            load_q_pipeline=load_q_pipeline,
                            load_kv_pipeline=load_kv_pipeline,
                            sPT=sPT,
                            tidx=tidx,
                            page_size=mCL.shape[0],
                        )
                        load_cpasync_qk_params = SimpleNamespace(
                            tiled_mma_qk=tiled_mma_qk,
                            mQL=mQL,
                            mQR=mQR,
                            mCL=mCL,
                            mKR=mKR,
                            sQ=sQ,
                            sKC=sKC,
                        )
                        load_cpasync_v_params = SimpleNamespace(
                            tiled_mma_pv=tiled_mma_pv,
                            mCLT=mCLT,
                            sVC=sVC,
                        )
                        (
                            load_pt_consumer_state,
                            load_pt_release_state,
                            load_q_producer_state,
                            load_kv_producer_state,
                            load_kv_commit_state,
                        ) = self.load_cpasync(
                            load_cpasync_common_params,
                            load_cpasync_qk_params,
                            load_cpasync_v_params,
                            k_index,
                            k_tile_count,
                            load_pt_consumer_state,
                            load_pt_release_state,
                            load_q_producer_state,
                            load_kv_producer_state,
                            load_kv_commit_state,
                        )
                    tile_sched.advance_to_next_work()
                    work_tile = tile_sched.get_current_work()
                load_q_pipeline.producer_tail(load_q_producer_state)
                load_kv_pipeline.producer_tail(load_kv_producer_state)
        else:
            if (
                warp_idx >= self.empty_warp_ids[0]
                and warp_idx <= self.empty_warp_ids[-1]
            ):
                cute.arch.warpgroup_reg_dealloc(self.other_reg_num)
            if warp_idx == self.load_tma_warp_id:
                cute.arch.warpgroup_reg_dealloc(self.other_reg_num)
                load_q_producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.load_q_stage
                )
                load_kv_producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.load_kv_stage
                )
                tile_sched = create_mla_static_tile_scheduler(
                    tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
                )
                work_tile = tile_sched.initial_work_tile_info()
                while work_tile.is_valid_tile:
                    blk_coord = work_tile.tile_idx
                    k_index, k_tile_count, local_split_kv = self.get_k_tile_count(
                        split_kv,
                        cache_seqs,
                        block_split_kvs,
                        blk_coord,
                    )
                    if k_tile_count > 0:
                        # Construct fixed common/tma_qk/tma_pv params for load_tma
                        tma_common_params = SimpleNamespace(
                            blk_coord=blk_coord,
                            local_split_kv=local_split_kv,
                            load_q_pipeline=load_q_pipeline,
                            load_kv_pipeline=load_kv_pipeline,
                            mPT=mPT,
                        )
                        tma_qk_params = SimpleNamespace(
                            tiled_mma_qk=tiled_mma_qk,
                            tma_atom_q_latent=tma_atom_q_latent,
                            tma_atom_q_rope=tma_atom_q_rope,
                            tma_atom_c_latent=tma_atom_c_latent,
                            tma_atom_c_rope=tma_atom_c_rope,
                            mQL=mQL,
                            mQR=mQR,
                            mCL=mCL,
                            mKR=mKR,
                            sQ=sQ,
                            sKC=sKC,
                        )
                        tma_pv_params = SimpleNamespace(
                            tiled_mma_pv=tiled_mma_pv,
                            tma_atom_c_latent_transpose=tma_atom_c_latent_transpose,
                            mCL=mCL,
                            mKR=mKR,
                            mCLT=mCLT,
                            sVC=sVC,
                        )
                        # Load tma
                        load_q_producer_state, load_kv_producer_state = self.load_tma(
                            tma_common_params,
                            tma_qk_params,
                            tma_pv_params,
                            k_index,
                            k_tile_count,
                            load_q_producer_state,
                            load_kv_producer_state,
                        )
                    tile_sched.advance_to_next_work()
                    work_tile = tile_sched.get_current_work()

                load_q_pipeline.producer_tail(load_q_producer_state)
                load_kv_pipeline.producer_tail(load_kv_producer_state)

        # ///////////////////////////////////////////////////////////////////////////////
        # MMA warp
        # ///////////////////////////////////////////////////////////////////////////////
        if warp_idx == self.mma_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.other_reg_num)
            # Alloc tensor memory buffer
            cute.arch.alloc_tmem(
                cute.arch.SM100_TMEM_CAPACITY_COLUMNS,
                tmem_holding_buf,
                is_two_cta=self.use_2cta_instrs,
            )

            # sync with compute warp before tmem ptr is retrieved
            self.tmem_ptr_sync_bar.arrive()

            # Retrieve the tensor memory ptr and make the accumulator tensor
            tmem_ptr = cute.arch.retrieve_tmem_ptr(
                self.acc_dtype,
                alignment=16,
                ptr_to_buffer_holding_addr=tmem_holding_buf,
            )

            load_q_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.load_q_stage
            )
            load_kv_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.load_kv_stage
            )
            mma_s_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.mma_s_stage
            )
            p_mma_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.p_mma_stage
            )
            mma_o_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.mma_o_stage
            )
            tile_sched = create_mla_static_tile_scheduler(
                tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
            )
            work_tile = tile_sched.initial_work_tile_info()
            while work_tile.is_valid_tile:
                blk_coord = work_tile.tile_idx
                k_index, k_tile_count, local_split_kv = self.get_k_tile_count(
                    split_kv, cache_seqs, block_split_kvs, blk_coord
                )
                if k_tile_count > 0:
                    mma_common_params = SimpleNamespace(
                        blk_coord=blk_coord,
                        local_split_kv=local_split_kv,
                        load_q_pipeline=load_q_pipeline,
                        load_kv_pipeline=load_kv_pipeline,
                        tmem_ptr=tmem_ptr,
                        is_leader_cta=is_leader_cta,
                        L=mCL.shape[1],
                    )
                    mma_qk_params = SimpleNamespace(
                        mma_s_pipeline=mma_s_pipeline,
                        sQ=sQ,
                        sKC=sKC,
                    )
                    mma_pv_params = SimpleNamespace(
                        p_mma_pipeline=p_mma_pipeline,
                        mma_o_pipeline=mma_o_pipeline,
                        sP=sP,
                        sVC=sVC,
                    )
                    (
                        tiled_mma_qk,
                        tiled_mma_pv,
                        load_q_consumer_state,
                        load_kv_consumer_state,
                        mma_s_producer_state,
                        p_mma_consumer_state,
                        mma_o_producer_state,
                    ) = self.mma(
                        mma_common_params,
                        mma_qk_params,
                        mma_pv_params,
                        k_tile_count,
                        tiled_mma_qk,
                        tiled_mma_pv,
                        load_q_consumer_state,
                        load_kv_consumer_state,
                        mma_s_producer_state,
                        p_mma_consumer_state,
                        mma_o_producer_state,
                    )
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()

            mma_s_pipeline.producer_tail(mma_s_producer_state)
            mma_o_pipeline.producer_tail(mma_o_producer_state)

            cute.arch.relinquish_tmem_alloc_permit(is_two_cta=self.use_2cta_instrs)
            # Dealloc the tensor memory buffer
            cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)

            cute.arch.dealloc_tmem(
                tmem_ptr,
                cute.arch.SM100_TMEM_CAPACITY_COLUMNS,
                is_two_cta=self.use_2cta_instrs,
            )

        # ///////////////////////////////////////////////////////////////////////////////
        # Compute warp
        # ///////////////////////////////////////////////////////////////////////////////
        if (
            warp_idx >= self.compute_warp_ids[0]
            and warp_idx <= self.compute_warp_ids[-1]
        ):
            cute.arch.warpgroup_reg_alloc(self.softmax_reg_num)
            mma_s_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.mma_s_stage
            )
            p_mma_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.p_mma_stage
            )
            p_cor_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.p_cor_stage
            )
            mma_o_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.mma_o_stage
            )
            # sync with mma warp before retrieving tmem ptr
            self.tmem_ptr_sync_bar.wait()

            tmem_ptr = cute.arch.retrieve_tmem_ptr(
                self.acc_dtype,
                alignment=16,
                ptr_to_buffer_holding_addr=tmem_holding_buf,
            )

            tile_sched = create_mla_static_tile_scheduler(
                tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
            )
            work_tile = tile_sched.initial_work_tile_info()
            while work_tile.is_valid_tile:
                blk_coord = work_tile.tile_idx
                k_index, k_tile_count, local_split_kv = self.get_k_tile_count(
                    split_kv, cache_seqs, block_split_kvs, blk_coord
                )
                if k_tile_count > 0:
                    compute_common_params = SimpleNamespace(
                        blk_coord=blk_coord,
                        split_kv=split_kv,
                        local_split_kv=local_split_kv,
                        smem_exchange=softmax_smem_exchange,
                        mAccO=mAccO,
                        mO=mO,
                        K=cache_seqs[blk_coord[2]],
                        L=mCL.shape[1],
                        tmem_ptr=tmem_ptr,
                        tidx=tidx,
                        p_cor_pipeline=p_cor_pipeline,
                    )
                    compute_softmax_params = SimpleNamespace(
                        tiled_mma_qk=tiled_mma_qk,
                        sP=sP,
                        mma_s_pipeline=mma_s_pipeline,
                        p_mma_pipeline=p_mma_pipeline,
                        softmax_scale_log2=softmax_scale_log2,
                    )
                    mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = (
                        self.compute(
                            compute_common_params,
                            compute_softmax_params,
                            k_index=k_index,
                            k_tile_count=k_tile_count,
                            mma_s_consumer_state=mma_s_consumer_state,
                            p_mma_producer_state=p_mma_producer_state,
                            p_cor_producer_state=p_cor_producer_state,
                        )
                    )
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()
            p_cor_pipeline.producer_tail(p_cor_producer_state)

        # ///////////////////////////////////////////////////////////////////////////////
        # Correction warp
        # ///////////////////////////////////////////////////////////////////////////////
        if (
            warp_idx >= self.correction_warp_ids[0]
            and warp_idx <= self.correction_warp_ids[-1]
        ):
            cute.arch.warpgroup_reg_alloc(self.correction_reg_num)
            p_cor_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.p_cor_stage
            )
            mma_o_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.mma_o_stage
            )
            # sync with mma warp before retrieving tmem ptr
            self.tmem_ptr_sync_bar.wait()

            tmem_ptr = cute.arch.retrieve_tmem_ptr(
                self.acc_dtype,
                alignment=16,
                ptr_to_buffer_holding_addr=tmem_holding_buf,
            )

            tile_sched = create_mla_static_tile_scheduler(
                tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
            )
            work_tile = tile_sched.initial_work_tile_info()
            while work_tile.is_valid_tile:
                blk_coord = work_tile.tile_idx
                k_index, k_tile_count, local_split_kv = self.get_k_tile_count(
                    split_kv, cache_seqs, block_split_kvs, blk_coord
                )
                if k_tile_count > 0:
                    compute_common_params = SimpleNamespace(
                        blk_coord=blk_coord,
                        split_kv=split_kv,
                        local_split_kv=local_split_kv,
                        smem_exchange=epilogue_smem_exchange,
                        mAccO=mAccO,
                        mO=mO,
                        K=cache_seqs[blk_coord[2]],
                        L=mCL.shape[1],
                        H=mQL.shape[0],
                        tmem_ptr=tmem_ptr,
                        tidx=tidx,
                        tiled_mma_pv=tiled_mma_pv,
                        p_cor_pipeline=p_cor_pipeline,
                        mma_o_pipeline=mma_o_pipeline,
                    )
                    compute_epilogue_params = SimpleNamespace(
                        output_scale=output_scale,
                        softmax_scale_log2=softmax_scale_log2,
                        mAccLSE=mAccLSE,
                        mLSE=mLSE,
                    )
                    p_cor_consumer_state, mma_o_consumer_state = self.correction(
                        compute_common_params,
                        compute_epilogue_params,
                        k_tile_count=k_tile_count,
                        p_cor_consumer_state=p_cor_consumer_state,
                        mma_o_consumer_state=mma_o_consumer_state,
                    )
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()
            # Arrive for the tensor memory deallocation barrier
            cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)

        return

    @cute.kernel
    def reduction_kernel(
        self,
        mO: cute.Tensor,
        mLSE: cute.Tensor,
        mAccO: cute.Tensor,
        mAccLSE: cute.Tensor,
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
    ):
        """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results
        from multiple split_kv blocks into the final outputs.

        :param mO: Output tensor for storing final results
        :type mO: cute.Tensor
        :param mLSE: Log-sum-exp tensor for storing final LSE values
        :type mLSE: cute.Tensor
        :param mAccO: Accumulated output tensor from split_kv blocks
        :type mAccO: cute.Tensor
        :param mAccLSE: Accumulated LSE tensor from split_kv blocks
        :type mAccLSE: cute.Tensor
        :param split_kv: Number of split_kv blocks
        :type split_kv: cutlass.Int32
        :param cache_seqs: Cache sequence lengths tensor
        :type cache_seqs: cute.Tensor
        :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv)
        :type block_split_kvs: cute.Tensor
        """
        bidx, _, bidz = cute.arch.block_idx()
        tidx, _, _ = cute.arch.thread_idx()
        blk_coord = (bidx, 0, bidz)
        local_split_kv = (
            block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv
        )
        k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1])
        k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv)
        local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta)

        # Alloc shared memory
        smem = utils.SmemAllocator()
        storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16)
        lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype)
        smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS))

        gLSE = mAccLSE[blk_coord[0], None, blk_coord[2]]
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        if warp_idx == 0:
            # calculate the global lse and the per-split scale exp2(local_lse - global_lse)
            lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp)

            local_lse = cute.make_rmem_tensor(
                cute.make_layout(lse_per_thread), self.lse_dtype
            )
            lse_max = -self.lse_dtype.inf
            # find the max lse
            for i in cutlass.range_constexpr(lse_per_thread):
                split_kv_idx = tidx + i * self.threads_per_warp
                local_lse[i] = (
                    gLSE[split_kv_idx]
                    if cute.elem_less(split_kv_idx, local_split_kv)
                    else -self.lse_dtype.inf
                )
                # reduce the local lse
                lse_max = cute.arch.fmax(lse_max, local_lse[i])
            lse_max = cute.arch.warp_reduction_max(lse_max)
            lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0
            # calculate sum_lse
            sum_lse = 0.0
            for i in cutlass.range_constexpr(lse_per_thread):
                sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True)
            sum_lse = cute.arch.warp_reduction_sum(sum_lse)
            # calculate the global_lse
            global_lse = (
                lse_max + cute.math.log2(sum_lse, fastmath=True)
                if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse
                else self.lse_dtype.inf
            )
            if tidx == 0:
                mLSE[blk_coord[0], blk_coord[2]] = global_lse
            # store the scale to shared memory
            for i in cutlass.range_constexpr(lse_per_thread):
                split_kv_idx = tidx + i * self.threads_per_warp
                if cute.elem_less(split_kv_idx, local_split_kv):
                    smem_lse_scale[split_kv_idx] = cute.math.exp2(
                        local_lse[i] - global_lse, fastmath=True
                    )

        pipeline.sync(barrier_id=4)

        elements_per_thread = cute.ceil_div(
            self.latent_dim, self.threads_per_warp * self.num_compute_warps
        )
        gAccO = mAccO[blk_coord[0], None, None, blk_coord[2]]
        rAccO = cute.make_rmem_tensor(
            cute.make_layout(elements_per_thread), self.acc_dtype
        )
        rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype)
        rAccO.fill(0.0)
        for i in range(local_split_kv):
            for j in cutlass.range_constexpr(elements_per_thread):
                element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps
                rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i]
        rO.store(rAccO.load().to(self.o_dtype))
        for j in cutlass.range_constexpr(elements_per_thread):
            element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps
            mO[blk_coord[0], element_idx, blk_coord[2]] = rO[j]
        return
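
    # The reduction above merges per-split partial results with the standard
    # log-sum-exp identity, kept in base 2 to match the split KV kernel:
    #   m = max_i lse_i
    #   global_lse = m + log2(sum_i exp2(lse_i - m))
    #   O = sum_i exp2(lse_i - global_lse) * O_i
    # i.e. each split's partial output is rescaled by its share of the total
    # softmax mass before being summed.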

    @staticmethod
    def get_split_kv(
        B: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int
    ) -> int:
        """Get the proper split_kv value for the MLA kernel based on the problem size.

        :param B: Batch size
        :type B: int
        :param K: Sequence length
        :type K: int
        :param mma_qk_tiler_mn: The (M, N) tile shape of the QK MMA
        :type mma_qk_tiler_mn: tuple
        :param max_active_blocks: Maximum number of active blocks
        :type max_active_blocks: int
        :return: Split_kv value
        :rtype: int
        """
        max_splits = ceil_div(K, mma_qk_tiler_mn[1])
        blocks_per_batch = max(1, max_active_blocks // B)
        split_heur = min(max_splits, blocks_per_batch)
        k_waves = ceil_div(max_splits, split_heur)
        split_wave_aware = ceil_div(max_splits, k_waves)
        return split_wave_aware
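
    # A worked example with hypothetical numbers: B = 4, K = 8192,
    # mma_qk_tiler_mn = (128, 128), max_active_blocks = 148 gives
    # max_splits = ceil(8192 / 128) = 64, blocks_per_batch = 148 // 4 = 37,
    # split_heur = min(64, 37) = 37, k_waves = ceil(64 / 37) = 2, and
    # split_wave_aware = ceil(64 / 2) = 32: the heuristic rounds the split
    # count down so the 64 K-tiles spread evenly across the two waves.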

    @cute.jit
    def get_k_tile_count(
        self,
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
        blk_coord: cute.Coord,
    ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
        """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel.

        :param split_kv: Split_kv value
        :type split_kv: cutlass.Int32
        :param cache_seqs: Cache sequence lengths tensor
        :type cache_seqs: cute.Tensor
        :param block_split_kvs: Per-block split_kv values tensor
        :type block_split_kvs: cute.Tensor
        :param blk_coord: Block coordinate
        :type blk_coord: cute.Coord
        :return: k_index, k_tile_count, split_kv
        :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]
        """
        K = cache_seqs[blk_coord[2]]
        if cutlass.const_expr(self.is_var_split_kv):
            split_kv = block_split_kvs[blk_coord[2]]

        k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1])
        k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv)
        k_index = blk_coord[3] * k_tile_per_cta
        k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index)
        return k_index, k_tile_count, split_kv
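
    # Example of the resulting per-block split: with K = 1000,
    # mma_qk_tiler[1] = 128 and split_kv = 3, k_tile_total = ceil(1000 / 128) = 8
    # and k_tile_per_cta = ceil(8 / 3) = 3, so the blocks along the split
    # dimension (blk_coord[3] = 0, 1, 2) start at k_index = 0, 3, 6 and process
    # 3, 3, and 2 K-tiles respectively.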

    @cute.jit
    def load_page_table(
        self,
        common_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_pt_producer_state: pipeline.PipelineState,
    ) -> pipeline.PipelineState:
        """Load warp to load the page table. Updates the load pt producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_pt_producer_state: The load pt producer state
        :type load_pt_producer_state: pipeline.PipelineState

        :return: The load pt producer state
        :rtype: pipeline.PipelineState
        """
        mPT = common_params.mPT[None, common_params.blk_coord[2]]
        page_per_tile = self.mma_qk_tiler[1] >> cute.arch.log2_of_pow2_int(
            common_params.page_size
        )
        tidx = common_params.tidx % self.threads_per_warp

        load_pt_pipeline = common_params.load_pt_pipeline
        while k_tile_count > 0:
            load_pt_pipeline.producer_acquire(load_pt_producer_state)

            elem_per_thread = cute.ceil_div(page_per_tile, self.threads_per_warp)

            # atom_async_copy: async copy atom for page table load
            atom_async_copy = cute.make_copy_atom(
                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
                cutlass.Int32,
                num_bits_per_copy=cutlass.Int32.width,
            )
            mPT_for_copy = cute.flat_divide(mPT, (1,))
            sPT_for_copy = cute.flat_divide(common_params.sPT, (1,))
            # elem_per_thread is a dynamic value that depends on the page_size setting.
            for i in range(elem_per_thread):
                idx = i * self.threads_per_warp + tidx
                if cute.elem_less(
                    k_index * page_per_tile + idx, mPT.shape[0]
                ) and cute.elem_less(idx, page_per_tile):
                    cute.copy(
                        atom_async_copy,
                        mPT_for_copy[None, k_index * page_per_tile + idx],
                        sPT_for_copy[None, idx, load_pt_producer_state.index],
                    )
                else:
                    sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0)
            mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state)
            load_pt_pipeline.producer_commit(load_pt_producer_state)
            load_pt_producer_state.advance()
            k_index += 1
            k_tile_count -= 1

        return load_pt_producer_state
|
|
|
|
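    # Software-pipeline shape of the cp.async load warp below: the Q/K tensors
    # of k-tile i are issued first, then (inside the mainloop) the V tensors of
    # k-tile i - 1, so V loads overlap the QK consumption of the next tile; the
    # two drain loops at the end pad and flush the remaining cp.async groups.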
    @cute.jit
    def load_cpasync(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        v_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_pt_consumer_state: pipeline.PipelineState,
        load_pt_release_state: pipeline.PipelineState,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_kv_commit_state: pipeline.PipelineState,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """Load warp to load Q/K/V via cpasync. Updates the load cpasync producer states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param v_params: The v parameters
        :type v_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_pt_consumer_state: The load pt consumer state
        :type load_pt_consumer_state: pipeline.PipelineState
        :param load_pt_release_state: The load pt release state
        :type load_pt_release_state: pipeline.PipelineState
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_kv_commit_state: The load kv commit state
        :type load_kv_commit_state: pipeline.PipelineState

        :return: The load pt consumer state, the load pt release state, the load q producer state, the load kv producer state, the load kv commit state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        tidx = (
            common_params.tidx - self.threads_per_warp * self.load_cp_async_warp_ids[0]
        )

        # slice views of the global tensors for cpasync; their coords come from a counting tensor coord.
        mCL_for_slice = cute.make_tensor(
            qk_params.mCL.iterator,
            cute.make_layout(
                (
                    (qk_params.mCL.shape[0], qk_params.mCL.shape[2]),
                    qk_params.mCL.shape[1],
                ),
                stride=(
                    (qk_params.mCL.stride[0], qk_params.mCL.stride[2]),
                    qk_params.mCL.stride[1],
                ),
            ),
        )
        mKR_for_slice = cute.make_tensor(
            qk_params.mKR.iterator,
            cute.make_layout(
                (
                    (qk_params.mKR.shape[0], qk_params.mKR.shape[2]),
                    qk_params.mKR.shape[1],
                ),
                stride=(
                    (qk_params.mKR.stride[0], qk_params.mKR.stride[2]),
                    qk_params.mKR.stride[1],
                ),
            ),
        )
        mCLT_for_slice = cute.make_tensor(
            v_params.mCLT.iterator,
            cute.make_layout(
                (
                    v_params.mCLT.shape[0],
                    (v_params.mCLT.shape[1], v_params.mCLT.shape[2]),
                ),
                stride=(
                    v_params.mCLT.stride[0],
                    (v_params.mCLT.stride[1], v_params.mCLT.stride[2]),
                ),
            ),
        )

        # make identity tensor for partition
        mCL_for_partition = cute.make_identity_tensor(
            (qk_params.mCL.shape[0] * qk_params.mCL.shape[2], qk_params.mCL.shape[1])
        )
        mKR_for_partition = cute.make_identity_tensor(
            (qk_params.mKR.shape[0] * qk_params.mKR.shape[2], qk_params.mKR.shape[1])
        )
        mCLT_for_partition = cute.make_identity_tensor(
            (v_params.mCLT.shape[0], v_params.mCLT.shape[1] * v_params.mCLT.shape[2])
        )

        # Flatten divide and partition global tensors for the QK load
        # (bM, bK, rM, rK, rL)
        mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2])
        gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk)
        gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk)

        mma_qk_tiler_nk = cute.select(self.mma_qk_tiler, mode=[1, 2])
        gCL = cute.flat_divide(mCL_for_partition, mma_qk_tiler_nk)
        gKR = cute.flat_divide(mKR_for_partition, mma_qk_tiler_nk)

        thr_mma_qk = qk_params.tiled_mma_qk.get_slice(
            common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id)
        )
        tSgQL = thr_mma_qk.partition_A(gQL)
        tSgQR = thr_mma_qk.partition_A(gQR)

        tSgCL = thr_mma_qk.partition_B(gCL)
        tSgKR = thr_mma_qk.partition_B(gKR)

        # create cpasync tiled copy qk
        cpasync_bits = 128
        # thread for copy
        thread = self.threads_per_warp * self.num_load_warps
        # value for copy
        value = cpasync_bits // self.q_dtype.width
        cpasync_q_tiled_copy = cute.make_cotiled_copy(
            cute.make_copy_atom(
                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
                self.q_dtype,
                num_bits_per_copy=cpasync_bits,
            ),
            cute.make_ordered_layout((thread, value), (1, 0)),
            qk_params.sQ[None, None, None, 0].layout,
        )
        cpasync_kc_tiled_copy = cute.make_cotiled_copy(
            cute.make_copy_atom(
                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
                self.q_dtype,
                num_bits_per_copy=cpasync_bits,
            ),
            cute.make_ordered_layout((thread, value), (1, 0)),
            qk_params.sKC[None, None, None, 0].layout,
        )
        cpasync_q_thr_copy = cpasync_q_tiled_copy.get_slice(tidx)
        cpasync_kc_thr_copy = cpasync_kc_tiled_copy.get_slice(tidx)
        # copy async partition
        tQgQL = cpasync_q_thr_copy.partition_S(tSgQL)
        tQgQR = cpasync_q_thr_copy.partition_S(tSgQR)
        tQsQ = cpasync_q_thr_copy.partition_D(qk_params.sQ)

        tKCgCL = cpasync_kc_thr_copy.partition_S(tSgCL)
        tKCgKR = cpasync_kc_thr_copy.partition_S(tSgKR)
        tKCsKC = cpasync_kc_thr_copy.partition_D(qk_params.sKC)

        gCLT = cute.flat_divide(
            mCLT_for_partition, cute.select(self.mma_pv_tiler, mode=[1, 2])
        )
        thr_mma_pv = v_params.tiled_mma_pv.get_slice(
            common_params.blk_coord[0] % cute.size(v_params.tiled_mma_pv.thr_id)
        )
        tOgCLT = thr_mma_pv.partition_B(gCLT)

        cpasync_v_tiled_copy = cute.make_cotiled_copy(
            cute.make_copy_atom(
                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
                self.q_dtype,
                num_bits_per_copy=cpasync_bits,
            ),
            cute.make_ordered_layout((thread, value), (1, 0)),
            v_params.sVC[None, None, None, 0].layout,
        )
        cpasync_v_thr_copy = cpasync_v_tiled_copy.get_slice(tidx)
        tVCgCLT = cpasync_v_thr_copy.partition_S(tOgCLT)
        tVCsVC = cpasync_v_thr_copy.partition_D(v_params.sVC)

        # Records the in-flight cpasync stage count; we wait and producer-commit
        # whenever more than `load_kv_stage - 1` cpasync groups are in flight.
        copy_in_flight_count = cutlass.Int32(0)

        qk_params.tiled_copy_q = cpasync_q_tiled_copy
        qk_params.tiled_copy_kc = cpasync_kc_tiled_copy
        qk_params.mCL_for_slice = mCL_for_slice
        qk_params.mKR_for_slice = mKR_for_slice
        qk_params.tQgQL = tQgQL
        qk_params.tQgQR = tQgQR
        qk_params.tQsQ = tQsQ
        qk_params.tKCgCL = tKCgCL
        qk_params.tKCgKR = tKCgKR
        qk_params.tKCsKC = tKCsKC

        v_params.tiled_copy_vc = cpasync_v_tiled_copy
        v_params.tVCgCLT = tVCgCLT
        v_params.tVCsVC = tVCsVC
        v_params.mCLT_for_slice = mCLT_for_slice

        # first load qk latent/rope
        (
            load_pt_consumer_state,
            load_q_producer_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
        ) = self.load_cpasync_qk_one_k_tile(
            common_params,
            qk_params,
            k_index,
            load_pt_consumer_state,
            load_q_producer_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
            load_q=True,
        )

        k_index += 1
        k_tile_count -= 1

        # mainloop, load qk and v
        while k_tile_count > 0:
            (
                load_pt_consumer_state,
                load_q_producer_state,
                load_kv_producer_state,
                load_kv_commit_state,
                copy_in_flight_count,
            ) = self.load_cpasync_qk_one_k_tile(
                common_params,
                qk_params,
                k_index,
                load_pt_consumer_state,
                load_q_producer_state,
                load_kv_producer_state,
                load_kv_commit_state,
                copy_in_flight_count,
                load_q=False,
            )
            (
                load_pt_release_state,
                load_kv_producer_state,
                load_kv_commit_state,
                copy_in_flight_count,
            ) = self.load_cpasync_v_one_k_tile(
                common_params,
                v_params,
                k_index - 1,
                load_pt_release_state,
                load_kv_producer_state,
                load_kv_commit_state,
                copy_in_flight_count,
            )
            k_index += 1
            k_tile_count -= 1

        # load last tile of v
        (
            load_pt_release_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
        ) = self.load_cpasync_v_one_k_tile(
            common_params,
            v_params,
            k_index - 1,
            load_pt_release_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
        )

        padding_in_flight = 0
        while copy_in_flight_count + padding_in_flight < self.load_kv_stage - 1:
            padding_in_flight += 1
            cute.arch.cp_async_commit_group()
        # wait for the previous cpasync groups to arrive
        load_kv_pipeline = common_params.load_kv_pipeline
        while copy_in_flight_count > 0:
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(self.load_kv_stage - 1)
            load_kv_pipeline.producer_commit(load_kv_commit_state)
            load_kv_commit_state.advance()
            copy_in_flight_count -= 1

        # wait for all cpasync to arrive
        cute.arch.cp_async_wait_group(0)
        return (
            load_pt_consumer_state,
            load_pt_release_state,
            load_q_producer_state,
            load_kv_producer_state,
            load_kv_commit_state,
        )

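    # cp.async group accounting used below: each smem stage is committed as one
    # cp.async group. Once load_kv_stage groups are in flight,
    # cp_async_wait_group(load_kv_stage - 1) guarantees the oldest group has
    # landed, so its kv pipeline stage can be producer-committed and reused.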
    @cute.jit
    def load_cpasync_one_smem_stage(
        self,
        common_params: SimpleNamespace,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_kv_commit_state: pipeline.PipelineState,
        copy_func: Callable,
        copy_in_flight_count: cutlass.Int32,
        load_q: bool,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        cutlass.Int32,
    ]:
        """Load one smem stage with cpasync. Reused for the q/k/v load stages. Updates the load cpasync producer states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_kv_commit_state: The load kv commit state
        :type load_kv_commit_state: pipeline.PipelineState
        :param copy_func: The copy function
        :type copy_func: Callable
        :param copy_in_flight_count: The copy in-flight count
        :type copy_in_flight_count: cutlass.Int32
        :param load_q: Whether to load q
        :type load_q: bool

        :return: The load q producer state, the load kv producer state, the load kv commit state, the copy in-flight count
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32]
        """
        if cutlass.const_expr(load_q):
            common_params.load_q_pipeline.producer_acquire(load_q_producer_state)
        common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state)
        producer_index = load_kv_producer_state.index
        copy_func(producer_index)
        cute.arch.cp_async_commit_group()

        if cutlass.const_expr(load_q):
            # directly commit the q producer state here; mma will wait for kv.
            common_params.load_q_pipeline.producer_commit(load_q_producer_state)
            load_q_producer_state.advance()
        load_kv_producer_state.advance()
        copy_in_flight_count += 1

        # wait for the oldest cpasync group once all stages are in flight
        load_kv_pipeline = common_params.load_kv_pipeline
        if copy_in_flight_count == self.load_kv_stage:
            cute.arch.cp_async_wait_group(self.load_kv_stage - 1)
            load_kv_pipeline.producer_commit(load_kv_commit_state)
            load_kv_commit_state.advance()
            copy_in_flight_count -= 1

        return (
            load_q_producer_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
        )

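    # Power-of-two page arithmetic used below (illustrative numbers only): with
    # page_size=64, a row coordinate of 200 splits into
    # (200 & 63, 200 >> 6) = (8, 3), i.e. offset 8 inside logical page 3, and
    # the physical page index is then looked up in the smem page table sPT. The
    # masks and shifts stand in for the div/mod that a general page size would
    # require.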
    @cute.jit
    def load_cpasync_page_table_lookup_copy(
        self,
        tiled_copy: cute.TiledCopy,
        gKV: cute.Tensor,
        sKV: cute.Tensor,
        sPT: cute.Tensor,
        gKV_for_slice: cute.Tensor,
        k_index: cutlass.Int32,
        latent_idx: cutlass.Int32,
        qkv_stage_idx: cutlass.Int32,
        page_table_stage: cutlass.Int32,
        page_size: cutlass.Int32,
        transpose: bool = False,
    ):
        """Make the page table lookup for the KV cache latent/rope, then do the cpasync atom copy.

        :param tiled_copy: The tiled copy
        :type tiled_copy: cute.TiledCopy
        :param gKV: The global KV tensor
        :type gKV: cute.Tensor
        :param sKV: The shared memory KV tensor
        :type sKV: cute.Tensor
        :param sPT: The shared memory page table tensor
        :type sPT: cute.Tensor
        :param gKV_for_slice: The global KV for slice tensor
        :type gKV_for_slice: cute.Tensor
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param latent_idx: The latent index
        :type latent_idx: cutlass.Int32
        :param qkv_stage_idx: The qkv stage index
        :type qkv_stage_idx: cutlass.Int32
        :param page_table_stage: The page table stage
        :type page_table_stage: cutlass.Int32
        :param page_size: The page size (a power of two)
        :type page_size: cutlass.Int32
        :param transpose: Whether to transpose the gKV_for_slice
        :type transpose: bool
        """
        rest_modes_start = 1
        rest_modes_end = 4
        if cutlass.const_expr(transpose):
            gKV_grouped = cute.group_modes(
                gKV[None, None, None, None, latent_idx, k_index],
                rest_modes_start,
                rest_modes_end,
            )
        else:
            gKV_grouped = cute.group_modes(
                gKV[None, None, None, None, k_index, latent_idx],
                rest_modes_start,
                rest_modes_end,
            )
        sKV_grouped = cute.group_modes(
            sKV[None, None, None, None, qkv_stage_idx], rest_modes_start, rest_modes_end
        )
        page_size_log2 = cute.arch.log2_of_pow2_int(page_size)
        page_per_tile = self.mma_qk_tiler[1] >> page_size_log2
        gKV_for_copy_offsets = cute.make_rmem_tensor(
            cute.size(gKV_grouped.shape[1]), cute.cosize(gKV_for_slice.layout).dtype
        )
        # unroll the rest of the loop to apply the page table lookup.
        for i in cutlass.range_constexpr(cute.size(gKV_grouped.shape[1])):
            # get the coordinate of the gKV_for_slice
            coord = gKV_grouped[None, i].iterator
            if cutlass.const_expr(transpose):
                # fast path for mod & div: the page size is a power of 2, so the
                # division can be replaced by mask & shift.
                page_coord = ((coord[1] & (page_size - 1)), coord[1] >> page_size_log2)
                new_coord = (coord[0], page_coord)
                new_coord_pt = new_coord[1][1] & (page_per_tile - 1)
                gKV_for_copy_offset = cute.crd2idx(
                    (
                        new_coord[0],
                        (new_coord[1][0], sPT[new_coord_pt, page_table_stage]),
                    ),
                    gKV_for_slice.layout,
                )
            else:
                # fast path for mod & div: the page size is a power of 2, so the
                # division can be replaced by mask & shift.
                page_coord = (coord[0] & (page_size - 1), coord[0] >> page_size_log2)
                new_coord = (page_coord, coord[1])
                new_coord_pt = new_coord[0][1] & (page_per_tile - 1)
                gKV_for_copy_offset = cute.crd2idx(
                    (
                        (new_coord[0][0], sPT[new_coord_pt, page_table_stage]),
                        new_coord[1],
                    ),
                    gKV_for_slice.layout,
                )
            gKV_for_copy_offsets[i] = gKV_for_copy_offset
        cpasync_bits = 128
        for i in cutlass.range_constexpr(cute.size(gKV_grouped.shape[1])):
            # calculate the actual offset and apply it.
            sKV_for_copy = sKV_grouped[None, i]
            gKV_for_copy_offset = cute.assume(
                gKV_for_copy_offsets[i], cpasync_bits // self.q_dtype.width
            )
            gKV_for_copy_iter = gKV_for_slice.iterator + gKV_for_copy_offset
            gKV_for_copy = cute.make_tensor(gKV_for_copy_iter, sKV_for_copy.layout)
            cute.copy(tiled_copy, gKV_for_copy, sKV_for_copy)
        return

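    # Note on the structure below: the per-stage copy bodies are defined as
    # local closures (copy_qk_latent / copy_qk_rope) and bound with
    # functools.partial, so the one generic load_cpasync_one_smem_stage routine
    # drives q latent, q rope, and (in the V loader) v with the same pipeline
    # logic.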
    @cute.jit
    def load_cpasync_qk_one_k_tile(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        k_index: cutlass.Int32,
        load_pt_consumer_state: pipeline.PipelineState,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_kv_commit_state: pipeline.PipelineState,
        copy_in_flight_count: cutlass.Int32,
        load_q: bool,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        cutlass.Int32,
    ]:
        """Load one k tile of Q/K. Updates the load cpasync producer states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param load_pt_consumer_state: The load pt consumer state
        :type load_pt_consumer_state: pipeline.PipelineState
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_kv_commit_state: The load kv commit state
        :type load_kv_commit_state: pipeline.PipelineState
        :param copy_in_flight_count: The copy in-flight count
        :type copy_in_flight_count: cutlass.Int32
        :param load_q: Whether to load q
        :type load_q: bool

        :return: The load pt consumer state, the load q producer state, the load kv producer state, the load kv commit state, the copy in-flight count
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32]
        """
        common_params.load_pt_pipeline.consumer_wait(load_pt_consumer_state)
        page_table_stage = load_pt_consumer_state.index
        load_pt_consumer_state.advance()

        def copy_qk_latent(latent_idx, qkv_stage_idx):
            if load_q:
                cute.copy(
                    qk_params.tiled_copy_q,
                    qk_params.tQgQL[
                        None,
                        None,
                        None,
                        None,
                        0,
                        latent_idx,
                        common_params.blk_coord[2],
                    ],
                    qk_params.tQsQ[None, None, None, None, latent_idx],
                )
            # make sure the page table lookup happens first.
            self.load_cpasync_page_table_lookup_copy(
                qk_params.tiled_copy_kc,
                qk_params.tKCgCL,
                qk_params.tKCsKC,
                common_params.sPT,
                qk_params.mCL_for_slice,
                k_index,
                latent_idx,
                qkv_stage_idx,
                page_table_stage,
                common_params.page_size,
            )

        def copy_qk_rope(latent_idx, qkv_stage_idx):
            if load_q:
                cute.copy(
                    qk_params.tiled_copy_q,
                    qk_params.tQgQR[
                        None,
                        None,
                        None,
                        None,
                        0,
                        latent_idx,
                        common_params.blk_coord[2],
                    ],
                    qk_params.tQsQ[
                        None, None, None, None, self.iterations_qk_latent + latent_idx
                    ],
                )
            # make sure the page table lookup happens first.
            self.load_cpasync_page_table_lookup_copy(
                qk_params.tiled_copy_kc,
                qk_params.tKCgKR,
                qk_params.tKCsKC,
                common_params.sPT,
                qk_params.mKR_for_slice,
                k_index,
                latent_idx,
                qkv_stage_idx,
                page_table_stage,
                common_params.page_size,
            )

        # use a dynamic loop here to avoid instruction cache misses.
        for i in range(self.iterations_qk_latent):
            (
                load_q_producer_state,
                load_kv_producer_state,
                load_kv_commit_state,
                copy_in_flight_count,
            ) = self.load_cpasync_one_smem_stage(
                common_params,
                load_q_producer_state,
                load_kv_producer_state,
                load_kv_commit_state,
                partial(copy_qk_latent, i),
                copy_in_flight_count,
                load_q=load_q,
            )
        for i in range(self.iterations_qk_rope):
            (
                load_q_producer_state,
                load_kv_producer_state,
                load_kv_commit_state,
                copy_in_flight_count,
            ) = self.load_cpasync_one_smem_stage(
                common_params,
                load_q_producer_state,
                load_kv_producer_state,
                load_kv_commit_state,
                partial(copy_qk_rope, i),
                copy_in_flight_count,
                load_q=load_q,
            )

        return (
            load_pt_consumer_state,
            load_q_producer_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
        )

    @cute.jit
    def load_cpasync_v_one_k_tile(
        self,
        common_params: SimpleNamespace,
        v_params: SimpleNamespace,
        k_index: cutlass.Int32,
        load_pt_release_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_kv_commit_state: pipeline.PipelineState,
        copy_in_flight_count: cutlass.Int32,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        cutlass.Int32,
    ]:
        """Load one k tile of V. Updates the load cpasync producer states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param v_params: The v parameters
        :type v_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param load_pt_release_state: The load pt release state
        :type load_pt_release_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_kv_commit_state: The load kv commit state
        :type load_kv_commit_state: pipeline.PipelineState
        :param copy_in_flight_count: The copy in-flight count
        :type copy_in_flight_count: cutlass.Int32

        :return: The load pt release state, the load kv producer state, the load kv commit state, the copy in-flight count
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Int32]
        """
        page_table_stage = load_pt_release_state.index

        def copy_v_latent(iter_k_idx, latent_idx, qkv_stage_idx):
            # make sure the page table lookup happens first.
            self.load_cpasync_page_table_lookup_copy(
                v_params.tiled_copy_vc,
                v_params.tVCgCLT,
                v_params.tVCsVC,
                common_params.sPT,
                v_params.mCLT_for_slice,
                k_index * self.iterations_pv_k + iter_k_idx,
                latent_idx,
                qkv_stage_idx,
                page_table_stage,
                common_params.page_size,
                transpose=True,
            )

        # use a dynamic loop here to avoid instruction cache misses.
        for i in range(self.iterations_pv_k):
            for j in range(self.iterations_pv_n):
                (
                    _,
                    load_kv_producer_state,
                    load_kv_commit_state,
                    copy_in_flight_count,
                ) = self.load_cpasync_one_smem_stage(
                    common_params,
                    None,
                    load_kv_producer_state,
                    load_kv_commit_state,
                    partial(copy_v_latent, i, j),
                    copy_in_flight_count,
                    load_q=False,
                )
        common_params.load_pt_pipeline.consumer_release(load_pt_release_state)
        load_pt_release_state.advance()
        return (
            load_pt_release_state,
            load_kv_producer_state,
            load_kv_commit_state,
            copy_in_flight_count,
        )

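    # TMA load path below: unlike the cp.async path, the page table is not
    # staged through smem; the load warp reads mPT directly (one entry per
    # k-tile) and feeds the page index to the TMA descriptors as a tensor
    # coordinate.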
    @cute.jit
    def load_tma(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        v_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]:
        """Load warp to load Q/C latent/rope tensors. Updates the load qkv producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param v_params: The v parameters
        :type v_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState

        :return: The load q producer state and load kv producer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState]
        """
        # page table
        mPT = None
        if cutlass.const_expr(self.use_page_table):
            mPT = common_params.mPT[None, common_params.blk_coord[2]]

        # Flatten divide and partition global tensors for QK TMA load
        # (bM, bK, rM, rK, rL)
        mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2])
        gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk)
        gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk)

        mma_qk_tiler_nk = cute.select(self.mma_qk_tiler, mode=[1, 2])
        gCL = cute.flat_divide(qk_params.mCL, mma_qk_tiler_nk)
        gKR = cute.flat_divide(qk_params.mKR, mma_qk_tiler_nk)

        thr_mma_qk = qk_params.tiled_mma_qk.get_slice(
            common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id)
        )
        tSgQL = thr_mma_qk.partition_A(gQL)
        tSgQR = thr_mma_qk.partition_A(gQR)

        tSgCL = thr_mma_qk.partition_B(gCL)
        tSgKR = thr_mma_qk.partition_B(gKR)

        # tma partition for q, k latent/rope

        # smem: ((atom_v, rest_v), STAGE)
        # gmem: ((atom_v, rest_v), RestM, RestK, RestL)
        tQsQ, tQLgQL_mkl = cpasync.tma_partition(
            qk_params.tma_atom_q_latent,
            0,
            cute.make_layout(1),
            cute.group_modes(qk_params.sQ, 0, 3),
            cute.group_modes(tSgQL, 0, 3),
        )

        _, tQRgQR_mkl = cpasync.tma_partition(
            qk_params.tma_atom_q_rope,
            0,
            cute.make_layout(1),
            cute.group_modes(qk_params.sQ, 0, 3),
            cute.group_modes(tSgQR, 0, 3),
        )

        tKCsKC, tCLgCL = cpasync.tma_partition(
            qk_params.tma_atom_c_latent,
            0,
            cute.make_layout(1),
            cute.group_modes(qk_params.sKC, 0, 3),
            cute.group_modes(tSgCL, 0, 3),
        )

        _, tKRgKR = cpasync.tma_partition(
            qk_params.tma_atom_c_rope,
            0,
            cute.make_layout(1),
            cute.group_modes(qk_params.sKC, 0, 3),
            cute.group_modes(tSgKR, 0, 3),
        )

        tQLgQL = tQLgQL_mkl[None, None, None, common_params.blk_coord[2]]
        tQRgQR = tQRgQR_mkl[None, None, None, common_params.blk_coord[2]]

        # Flatten divide and partition global tensors for V TMA load
        mma_pv_tiler_nk = cute.select(self.mma_pv_tiler, mode=[1, 2])
        gCLT = cute.flat_divide(v_params.mCLT, mma_pv_tiler_nk)

        thr_mma_pv = v_params.tiled_mma_pv.get_slice(
            common_params.blk_coord[0] % cute.size(v_params.tiled_mma_pv.thr_id)
        )
        tOgCLT = thr_mma_pv.partition_B(gCLT)

        # tma partition for vc
        # smem: ((atom_v, rest_v), STAGE)
        # gmem: ((atom_v, rest_v), RestM, RestK, RestL)
        tVCsVC, tCLTgCLT = cpasync.tma_partition(
            v_params.tma_atom_c_latent_transpose,
            0,
            cute.make_layout(1),
            cute.group_modes(v_params.sVC, 0, 3),
            cute.group_modes(tOgCLT, 0, 3),
        )

        # set extra params
        common_params.mPT = mPT
        qk_params.tQLgQL = tQLgQL
        qk_params.tQRgQR = tQRgQR
        qk_params.tCLgCL = tCLgCL
        qk_params.tKRgKR = tKRgKR
        qk_params.tQsQ = tQsQ
        qk_params.tKCsKC = tKCsKC
        v_params.tCLTgCLT = tCLTgCLT
        v_params.tVCsVC = tVCsVC

        load_q_producer_state, load_kv_producer_state = self.load_tma_qk_one_k_tile(
            common_params,
            qk_params,
            k_index,
            k_tile_count,
            load_q_producer_state,
            load_kv_producer_state,
            load_q=True,
        )
        k_index += 1
        k_tile_count -= 1
        while k_tile_count > 0:
            load_q_producer_state, load_kv_producer_state = self.load_tma_qk_one_k_tile(
                common_params,
                qk_params,
                k_index,
                k_tile_count,
                load_q_producer_state,
                load_kv_producer_state,
                load_q=False,
            )
            load_kv_producer_state = self.load_tma_v_one_k_tile(
                common_params,
                v_params,
                k_index - 1,
                load_kv_producer_state,
            )
            k_index += 1
            k_tile_count -= 1

        # load last v tile
        load_kv_producer_state = self.load_tma_v_one_k_tile(
            common_params,
            v_params,
            k_index - 1,
            load_kv_producer_state,
        )
        return load_q_producer_state, load_kv_producer_state

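    # Prefetch scheme below: with a page table, k_idx[0] holds the page of the
    # current k-tile and k_idx[1] the page kPrefetchDistance (= 1) tiles ahead;
    # after issuing the current tile's TMA copies, the next tile's K
    # latent/rope regions are prefetched to warm the cache for the following
    # iteration.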
    @cute.jit
    def load_tma_qk_one_k_tile(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_q: bool,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]:
        """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_q: Whether to load q
        :type load_q: bool

        :return: The load q producer state and load kv producer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState]
        """
        k_idx = cute.make_rmem_tensor(cute.make_layout(2), cutlass.Int32)
        # prefetch the next K load to keep busy while we transpose-load from cache
        kPrefetchDistance = 1
        if cutlass.const_expr(self.use_page_table):
            k_idx[0] = common_params.mPT[k_index]
            k_idx[1] = common_params.mPT[k_index + kPrefetchDistance]
        else:
            k_idx[0] = common_params.blk_coord[2]
            k_idx[1] = common_params.blk_coord[2]
        for i in cutlass.range_constexpr(self.iterations_qk_latent):
            # load q once at the first iteration
            if cutlass.const_expr(load_q):
                # get the mbar ptr from the pipeline.
                tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier(
                    load_q_producer_state
                )
                # expect the transaction bytes for q.
                common_params.load_q_pipeline.producer_acquire(load_q_producer_state)
                # load q latent
                cute.copy(
                    qk_params.tma_atom_q_latent,
                    qk_params.tQLgQL[None, 0, load_q_producer_state.index],
                    qk_params.tQsQ[None, load_q_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
                load_q_producer_state.advance()
            # get the mbar ptr from the pipeline.
            tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier(
                load_kv_producer_state
            )
            # expect the transaction bytes for k latent.
            common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state)
            # load k latent
            if cutlass.const_expr(self.use_page_table):
                cute.copy(
                    qk_params.tma_atom_c_latent,
                    qk_params.tCLgCL[None, 0, i, k_idx[0]],
                    qk_params.tKCsKC[None, load_kv_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
            else:
                cute.copy(
                    qk_params.tma_atom_c_latent,
                    qk_params.tCLgCL[None, k_index, i, k_idx[0]],
                    qk_params.tKCsKC[None, load_kv_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
            load_kv_producer_state.advance()

        for i in cutlass.range_constexpr(self.iterations_qk_rope):
            # load q rope once at the first iteration
            if cutlass.const_expr(load_q):
                # get the mbar ptr from the pipeline.
                tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier(
                    load_q_producer_state
                )
                # expect the transaction bytes for q.
                common_params.load_q_pipeline.producer_acquire(load_q_producer_state)
                # load q rope
                cute.copy(
                    qk_params.tma_atom_q_rope,
                    qk_params.tQRgQR[None, 0, i],
                    qk_params.tQsQ[None, i + self.iterations_qk_latent],
                    tma_bar_ptr=tma_bar_ptr,
                )
                load_q_producer_state.advance()
            # get the mbar ptr from the pipeline.
            tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier(
                load_kv_producer_state
            )
            # expect the transaction bytes for k rope.
            common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state)
            # load k rope
            if cutlass.const_expr(self.use_page_table):
                cute.copy(
                    qk_params.tma_atom_c_rope,
                    qk_params.tKRgKR[None, 0, i, k_idx[0]],
                    qk_params.tKCsKC[None, load_kv_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
            else:
                cute.copy(
                    qk_params.tma_atom_c_rope,
                    qk_params.tKRgKR[None, k_index, i, k_idx[0]],
                    qk_params.tKCsKC[None, load_kv_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
            load_kv_producer_state.advance()

        for i in cutlass.range_constexpr(self.iterations_qk_latent):
            if cutlass.const_expr(self.use_page_table):
                if k_tile_count > kPrefetchDistance:
                    cute.prefetch(
                        qk_params.tma_atom_c_latent,
                        qk_params.tCLgCL[
                            None,
                            k_index,
                            i,
                            k_idx[1],
                        ],
                    )
            else:
                cute.prefetch(
                    qk_params.tma_atom_c_latent,
                    qk_params.tCLgCL[None, k_index + kPrefetchDistance, i, k_idx[1]],
                )

        for i in cutlass.range_constexpr(self.iterations_qk_rope):
            if cutlass.const_expr(self.use_page_table):
                if k_tile_count > kPrefetchDistance:
                    cute.prefetch(
                        qk_params.tma_atom_c_rope,
                        qk_params.tKRgKR[
                            None,
                            k_index,
                            i,
                            k_idx[1],
                        ],
                    )
            else:
                cute.prefetch(
                    qk_params.tma_atom_c_rope,
                    qk_params.tKRgKR[None, k_index + kPrefetchDistance, i, k_idx[1]],
                )
        return load_q_producer_state, load_kv_producer_state

    @cute.jit
    def load_tma_v_one_k_tile(
        self,
        common_params: SimpleNamespace,
        v_params: SimpleNamespace,
        k_index: cutlass.Int32,
        load_kv_producer_state: pipeline.PipelineState,
    ) -> pipeline.PipelineState:
        """Load one k-tile of the compressed latent transpose tensor (V). Updates the load kv producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param v_params: The load tma v parameters
        :type v_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState

        :return: The load kv producer state
        :rtype: pipeline.PipelineState
        """
        k_idx = cute.make_rmem_tensor(cute.make_layout(1), cutlass.Int32)
        if cutlass.const_expr(self.use_page_table):
            k_idx[0] = common_params.mPT[k_index]
        else:
            k_idx[0] = common_params.blk_coord[2]
        for i in cutlass.range_constexpr(self.iterations_pv_k):
            for j in cutlass.range_constexpr(self.iterations_pv_n):
                # get the mbar ptr from the pipeline.
                tma_bar_ptr = common_params.load_kv_pipeline.producer_get_barrier(
                    load_kv_producer_state
                )
                common_params.load_kv_pipeline.producer_acquire(load_kv_producer_state)
                if cutlass.const_expr(self.use_page_table):
                    cute.copy(
                        v_params.tma_atom_c_latent_transpose,
                        v_params.tCLTgCLT[None, j, i, k_idx[0]],
                        v_params.tVCsVC[None, load_kv_producer_state.index],
                        tma_bar_ptr=tma_bar_ptr,
                    )
                else:
                    cute.copy(
                        v_params.tma_atom_c_latent_transpose,
                        v_params.tCLTgCLT[
                            None,
                            j,
                            k_index * self.iterations_pv_k + i,
                            k_idx[0],
                        ],
                        v_params.tVCsVC[None, load_kv_producer_state.index],
                        tma_bar_ptr=tma_bar_ptr,
                    )
                load_kv_producer_state.advance()
        return load_kv_producer_state

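    # MMA warp below: S = Q @ K^T accumulates into a tmem stage per k-tile,
    # while O = P @ V accumulates across all k-tiles into a single tmem buffer
    # (hence the mma_o_stage == 1 assert); QK of tile i is interleaved with PV
    # of tile i - 1, and the Q smem stages are released only after the last QK
    # tile has used them.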
    @cute.jit
    def mma(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        pv_params: SimpleNamespace,
        k_tile_count: cutlass.Int32,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        load_q_consumer_state: pipeline.PipelineState,
        load_kv_consumer_state: pipeline.PipelineState,
        mma_s_producer_state: pipeline.PipelineState,
        p_mma_consumer_state: pipeline.PipelineState,
        mma_o_producer_state: pipeline.PipelineState,
    ) -> tuple[
        cute.TiledMma,
        cute.TiledMma,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states.

        :param common_params: The common parameters for mma qk and pv
        :type common_params: SimpleNamespace
        :param qk_params: The mma qk parameters
        :type qk_params: SimpleNamespace
        :param pv_params: The mma pv parameters
        :type pv_params: SimpleNamespace
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param tiled_mma_qk: The tiled mma qk
        :type tiled_mma_qk: cute.TiledMma
        :param tiled_mma_pv: The tiled mma pv
        :type tiled_mma_pv: cute.TiledMma
        :param load_q_consumer_state: The load q consumer state
        :type load_q_consumer_state: pipeline.PipelineState
        :param load_kv_consumer_state: The load kv consumer state
        :type load_kv_consumer_state: pipeline.PipelineState
        :param mma_s_producer_state: The mma s producer state
        :type mma_s_producer_state: pipeline.PipelineState
        :param p_mma_consumer_state: The p mma consumer state
        :type p_mma_consumer_state: pipeline.PipelineState
        :param mma_o_producer_state: The mma o producer state
        :type mma_o_producer_state: pipeline.PipelineState

        :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load kv consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state
        :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ)
        tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC)
        tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP)
        tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC)

        tStS_shape = tiled_mma_qk.partition_shape_C(
            cute.select(self.mma_qk_tiler, mode=[0, 1])
        )
        tStS_staged_fake = tiled_mma_qk.make_fragment_C(
            cute.append(tStS_shape, self.mma_s_stage)
        )
        # use the real tmem ptr for tStS
        tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout)
        tOtO_shape = tiled_mma_pv.partition_shape_C(
            cute.select(self.mma_pv_tiler, mode=[0, 1])
        )
        # mma O has 1 stage.
        assert self.mma_o_stage == 1, (
            "mma O has 1 stage, otherwise the tmem usage exceeds the limit."
        )
        tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
        tOtO_layout = cute.append(
            tOtO.layout,
            cute.make_layout(
                common_params.L // self.mma_pv_tiler[1],
                stride=self.mma_pv_tiler[1] // self.warps_in_n,
            ),
        )
        tOtO_staged = cute.make_tensor(
            tStS_staged.iterator + self.tmem_o_offset, tOtO_layout
        )

        # set more parameters
        qk_params.tSrQ = tSrQ
        qk_params.tSrKC = tSrKC
        qk_params.tStS_staged = tStS_staged
        pv_params.tOrP = tOrP
        pv_params.tOrVC = tOrVC
        pv_params.tOtO_staged = tOtO_staged

        # mma O accumulates on K, so the accumulate flag is set to False once before all K blocks.
        tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False)
        load_q_pipeline = common_params.load_q_pipeline
        if common_params.is_leader_cta:
            load_q_release_state = load_q_consumer_state.clone()
            (
                tiled_mma_qk,
                load_q_consumer_state,
                load_kv_consumer_state,
                mma_s_producer_state,
            ) = self.mma_qk(
                common_params,
                qk_params,
                tiled_mma_qk,
                load_q_consumer_state,
                load_kv_consumer_state,
                mma_s_producer_state,
                wait_q=True,
            )
            k_tile_count -= 1

            while k_tile_count > 0:
                (
                    tiled_mma_qk,
                    load_q_consumer_state,
                    load_kv_consumer_state,
                    mma_s_producer_state,
                ) = self.mma_qk(
                    common_params,
                    qk_params,
                    tiled_mma_qk,
                    load_q_consumer_state,
                    load_kv_consumer_state,
                    mma_s_producer_state,
                    wait_q=False,
                )
                (
                    tiled_mma_pv,
                    load_kv_consumer_state,
                    p_mma_consumer_state,
                    mma_o_producer_state,
                ) = self.mma_pv(
                    common_params,
                    pv_params,
                    tiled_mma_pv,
                    load_kv_consumer_state,
                    p_mma_consumer_state,
                    mma_o_producer_state,
                )
                k_tile_count -= 1
            # release the q consumer states
            for i in cutlass.range_constexpr(self.iterations_qk):
                load_q_pipeline.consumer_release(load_q_release_state)
                load_q_release_state.advance()
            (
                tiled_mma_pv,
                load_kv_consumer_state,
                p_mma_consumer_state,
                mma_o_producer_state,
            ) = self.mma_pv(
                common_params,
                pv_params,
                tiled_mma_pv,
                load_kv_consumer_state,
                p_mma_consumer_state,
                mma_o_producer_state,
            )

        return (
            tiled_mma_qk,
            tiled_mma_pv,
            load_q_consumer_state,
            load_kv_consumer_state,
            mma_s_producer_state,
            p_mma_consumer_state,
            mma_o_producer_state,
        )

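    # In mma_qk below, ACCUMULATE is reset to False once per tile so the first
    # gemm overwrites the S stage, then set to True so the remaining latent and
    # rope k-blocks accumulate; the rope loop runs
    # rope_dim // tiled_mma_qk.shape_mnk[2] k-blocks since the rope slice is
    # narrower than a full latent stage.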
    @cute.jit
    def mma_qk(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        tiled_mma_qk: cute.TiledMma,
        load_q_consumer_state: pipeline.PipelineState,
        load_kv_consumer_state: pipeline.PipelineState,
        mma_s_producer_state: pipeline.PipelineState,
        wait_q: bool,
    ) -> tuple[
        cute.TiledMma,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param tiled_mma_qk: The tiled mma qk
        :type tiled_mma_qk: cute.TiledMma
        :param load_q_consumer_state: The load q consumer state
        :type load_q_consumer_state: pipeline.PipelineState
        :param load_kv_consumer_state: The load kv consumer state
        :type load_kv_consumer_state: pipeline.PipelineState
        :param mma_s_producer_state: The mma s producer state
        :type mma_s_producer_state: pipeline.PipelineState
        :param wait_q: Whether to wait for the q load
        :type wait_q: bool

        :return: The tiled mma qk, the load q consumer state, the load kv consumer state, and the mma s producer state
        :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """
        tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index]

        qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state)
        tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False)
        load_q_pipeline = common_params.load_q_pipeline
        load_kv_pipeline = common_params.load_kv_pipeline
        for q_stage in range(self.iterations_qk_latent):
            if cutlass.const_expr(wait_q):
                load_q_pipeline.consumer_wait(load_q_consumer_state)
            load_kv_pipeline.consumer_wait(load_kv_consumer_state)
            kc_stage = load_kv_consumer_state.index
            for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])):
                cute.gemm(
                    tiled_mma_qk,
                    tStS,
                    qk_params.tSrQ[None, None, k_block, q_stage],
                    qk_params.tSrKC[None, None, k_block, kc_stage],
                    tStS,
                )
                tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True)
            load_kv_pipeline.consumer_release(load_kv_consumer_state)
            load_kv_consumer_state.advance()
            if cutlass.const_expr(wait_q):
                load_q_consumer_state.advance()
        for q_stage in range(self.iterations_qk_rope):
            if cutlass.const_expr(wait_q):
                load_q_pipeline.consumer_wait(load_q_consumer_state)
            load_kv_pipeline.consumer_wait(load_kv_consumer_state)
            kc_stage = load_kv_consumer_state.index
            for k_block in cutlass.range_constexpr(
                self.rope_dim // tiled_mma_qk.shape_mnk[2]
            ):
                cute.gemm(
                    tiled_mma_qk,
                    tStS,
                    qk_params.tSrQ[
                        None, None, k_block, q_stage + self.iterations_qk_latent
                    ],
                    qk_params.tSrKC[None, None, k_block, kc_stage],
                    tStS,
                )
                tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True)
            load_kv_pipeline.consumer_release(load_kv_consumer_state)
            load_kv_consumer_state.advance()
            if cutlass.const_expr(wait_q):
                load_q_consumer_state.advance()

        qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state)
        mma_s_producer_state.advance()
        return (
            tiled_mma_qk,
            load_q_consumer_state,
            load_kv_consumer_state,
            mma_s_producer_state,
        )

    @cute.jit
    def mma_pv(
        self,
        common_params: SimpleNamespace,
        pv_params: SimpleNamespace,
        tiled_mma_pv: cute.TiledMma,
        load_kv_consumer_state: pipeline.PipelineState,
        p_mma_consumer_state: pipeline.PipelineState,
        mma_o_producer_state: pipeline.PipelineState,
    ) -> tuple[
        cute.TiledMma,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param pv_params: The pv parameters
        :type pv_params: SimpleNamespace
        :param tiled_mma_pv: The tiled mma pv
        :type tiled_mma_pv: cute.TiledMma
        :param load_kv_consumer_state: The load kv consumer state
        :type load_kv_consumer_state: pipeline.PipelineState
        :param p_mma_consumer_state: The P MMA consumer state
        :type p_mma_consumer_state: pipeline.PipelineState
        :param mma_o_producer_state: The MMA o producer state
        :type mma_o_producer_state: pipeline.PipelineState

        :return: The tiled mma pv, the load kv consumer state, the P MMA consumer state, and the MMA o producer state
        :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        pv_params.mma_o_pipeline.producer_acquire(mma_o_producer_state)
        pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state)
        load_kv_pipeline = common_params.load_kv_pipeline
        for p_stage in range(self.iterations_pv_k):
            accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE)
            for acc_stage in range(self.iterations_pv_n):
                load_kv_pipeline.consumer_wait(load_kv_consumer_state)
                tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag)
                vc_stage = load_kv_consumer_state.index
                tOtO = pv_params.tOtO_staged[None, None, None, acc_stage]
                for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]):
                    cute.gemm(
                        tiled_mma_pv,
                        tOtO,
                        pv_params.tOrP[
                            None,
                            None,
                            k_block,
                            (p_stage, p_mma_consumer_state.index),
                        ],
                        pv_params.tOrVC[None, None, k_block, vc_stage],
                        tOtO,
                    )
                    tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True)
                load_kv_pipeline.consumer_release(load_kv_consumer_state)
                load_kv_consumer_state.advance()
        pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state)
        p_mma_consumer_state.advance()
        pv_params.mma_o_pipeline.producer_commit(mma_o_producer_state)
        mma_o_producer_state.advance()

        return (
            tiled_mma_pv,
            load_kv_consumer_state,
            p_mma_consumer_state,
            mma_o_producer_state,
        )

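    # Compute (softmax) warp below: classic online-softmax bookkeeping. row_max
    # starts at -inf, and each k-tile updates the running maximum, the running
    # row sum, and a correction factor exp2((old_max - new_max) * scale) that
    # the correction warp later applies to the partial output.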
    @cute.jit
    def compute(
        self,
        common_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        mma_s_consumer_state: pipeline.PipelineState,
        p_mma_producer_state: pipeline.PipelineState,
        p_cor_producer_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]:
        """Compute warp to run the softmax over the assigned k-tiles. Updates the related pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param softmax_params: The softmax parameters
        :type softmax_params: SimpleNamespace
        :param k_index: The index of the k-tile
        :type k_index: cutlass.Int32
        :param k_tile_count: The number of k-tiles
        :type k_tile_count: cutlass.Int32
        :param mma_s_consumer_state: The MMA s consumer state
        :type mma_s_consumer_state: pipeline.PipelineState
        :param p_mma_producer_state: The P MMA producer state
        :type p_mma_producer_state: pipeline.PipelineState
        :param p_cor_producer_state: The P correction producer state
        :type p_cor_producer_state: pipeline.PipelineState

        :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1])

        row_max = -self.acc_dtype.inf
        row_sum = self.acc_dtype(0)
        correction_factor = self.acc_dtype(1)
        while k_tile_count > 0:
            (
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            ) = self.softmax_dispatch_apply_mask(
                common_params,
                softmax_params,
                k_index,
                k_tile_total,
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            )
            k_index = k_index + 1
            k_tile_count = k_tile_count - 1

        return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state

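    # Correction warp below: it drains one correction factor per k-tile and
    # rescales the partial O accumulated so far (no_correction lets the rescale
    # be skipped when the row maximum did not move); the row_sum / row_max from
    # the final tile feed the epilogue's normalization.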
    @cute.jit
    def correction(
        self,
        common_params: SimpleNamespace,
        epilogue_params: SimpleNamespace,
        k_tile_count: cutlass.Int32,
        p_cor_consumer_state: pipeline.PipelineState,
        mma_o_consumer_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]:
        """Correction warp to rescale the partial output and run the epilogue. Updates the related pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param epilogue_params: The epilogue parameters
        :type epilogue_params: SimpleNamespace
        :param k_tile_count: The number of k-tiles
        :type k_tile_count: cutlass.Int32
        :param p_cor_consumer_state: The P correction consumer state
        :type p_cor_consumer_state: pipeline.PipelineState
        :param mma_o_consumer_state: The MMA o consumer state
        :type mma_o_consumer_state: pipeline.PipelineState

        :return: The P correction consumer state, and the MMA o consumer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState]
        """

        p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = (
            self.get_correction_factor(common_params, p_cor_consumer_state)
        )
        k_tile_count = k_tile_count - 1
        while k_tile_count > 0:
            p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = (
                self.get_correction_factor(common_params, p_cor_consumer_state)
            )
            mma_o_consumer_state = self.rescale(
                common_params, mma_o_consumer_state, correction_factor, no_correction
            )
            k_tile_count = k_tile_count - 1

        mma_o_consumer_state = self.epilogue(
            common_params, epilogue_params, mma_o_consumer_state, row_sum, row_max
        )
        return p_cor_consumer_state, mma_o_consumer_state

    @cute.jit
    def softmax_dispatch_apply_mask(
        self,
        common_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_total: cutlass.Int32,
        mma_s_consumer_state: pipeline.PipelineState,
        p_mma_producer_state: pipeline.PipelineState,
        p_cor_producer_state: pipeline.PipelineState,
        row_max: cutlass.Float32,
        row_sum: cutlass.Float32,
        correction_factor: cutlass.Float32,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Float32,
    ]:
        """Dispatch whether to apply the mask in the softmax for the last k tile.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param softmax_params: The softmax parameters
        :type softmax_params: SimpleNamespace
        :param k_index: The index of the k-tile
        :type k_index: cutlass.Int32
        :param k_tile_total: The total number of k-tiles
        :type k_tile_total: cutlass.Int32
        :param mma_s_consumer_state: The MMA s consumer state
        :type mma_s_consumer_state: pipeline.PipelineState
        :param p_mma_producer_state: The P MMA producer state
        :type p_mma_producer_state: pipeline.PipelineState
        :param p_cor_producer_state: The P correction producer state
        :type p_cor_producer_state: pipeline.PipelineState
        :param row_max: The row max
        :type row_max: cutlass.Float32
        :param row_sum: The row sum
        :type row_sum: cutlass.Float32
        :param correction_factor: The correction factor
        :type correction_factor: cutlass.Float32

        :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32]
        """
        if k_index == k_tile_total - 1:
            (
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            ) = self.softmax(
                common_params,
                softmax_params,
                k_index,
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
                True,
            )
        else:
            (
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            ) = self.softmax(
                common_params,
                softmax_params,
                k_index,
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
                False,
            )
        return (
            mma_s_consumer_state,
            p_mma_producer_state,
            p_cor_producer_state,
            row_max,
            row_sum,
            correction_factor,
        )

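    # softmax below is specialized at compile time: the dispatcher above passes
    # is_last_tile as a Python bool literal, so cutlass.const_expr(is_last_tile)
    # folds away and only the last k-tile pays for the out-of-bounds masking.
    # All exponentials run in the exp2 domain (softmax_scale_log2 is the
    # attention scale expressed in the log2 domain), and fma_packed_f32x2
    # applies the scale-and-shift to two accumulator lanes per instruction.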
    @cute.jit
    def softmax(
        self,
        common_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        k_index: cutlass.Int32,
        mma_s_consumer_state: pipeline.PipelineState,
        p_mma_producer_state: pipeline.PipelineState,
        p_cor_producer_state: pipeline.PipelineState,
        row_max: cutlass.Float32,
        row_sum: cutlass.Float32,
        correction_factor: cutlass.Float32,
        is_last_tile: bool,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Float32,
    ]:
        """Softmax for one k-tile. Updates the related pipeline states and returns the computed results.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param softmax_params: The softmax parameters
        :type softmax_params: SimpleNamespace
        :param k_index: The index of the k-tile
        :type k_index: cutlass.Int32
        :param mma_s_consumer_state: The MMA s consumer state
        :type mma_s_consumer_state: pipeline.PipelineState
        :param p_mma_producer_state: The P MMA producer state
        :type p_mma_producer_state: pipeline.PipelineState
        :param p_cor_producer_state: The P correction producer state
        :type p_cor_producer_state: pipeline.PipelineState
        :param row_max: The row max
        :type row_max: cutlass.Float32
        :param row_sum: The row sum
        :type row_sum: cutlass.Float32
        :param correction_factor: The correction factor
        :type correction_factor: cutlass.Float32
        :param is_last_tile: Whether this is the last k-tile
        :type is_last_tile: bool

        :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32]
        """

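        # A sketch of the online-softmax recurrence this method implements: given the
        # running statistics (m, s) for a row and a new tile of scores x,
        #   m' = max(m, max(x))
        #   c  = exp2((m - m') * softmax_scale_log2)          # correction factor
        #   s' = s * c + sum_i exp2((x_i - m') * softmax_scale_log2)
        # P = exp2((x - m') * softmax_scale_log2) feeds the PV MMA, while c is handed
        # to the correction warp to rescale the partial O accumulator.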
        softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state)
        softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state)

        # load S from tmem
        tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C(
            cute.select(self.mma_qk_tiler, mode=[0, 1])
        )
        tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C(
            cute.append(tStS_shape, self.mma_s_stage)
        )
        tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout)
        tStS = tStS_staged[None, None, None, mma_s_consumer_state.index]

        tAcc = tStS[(None, None), 0, 0]
        cta_qk_tiler = (
            self.mma_qk_tiler[0] // self.cluster_shape_mnk[0],
            self.mma_qk_tiler[1],
            self.mma_qk_tiler[2],
        )
        cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1]))

        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype
        )
        tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)

        tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp)

        tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)
        tTR_tAcc = tmem_thr_copy.partition_S(tAcc)
        tTR_tS = tmem_thr_copy.partition_D(cS)

        tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype)

        cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc)

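        # On the last k-tile, columns past the true KV length must contribute
        # exp(-inf) = 0 to the row sum, so out-of-bounds scores are masked to -inf
        # before the running row max is updated below.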
        row_max_new = row_max
        for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
            if cutlass.const_expr(is_last_tile):
                tTR_rAcc[i] = (
                    tTR_rAcc[i]
                    if cute.elem_less(
                        tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index,
                        common_params.K,
                    )
                    else -self.acc_dtype.inf
                )
            # update row_max
            row_max_new = cute.arch.fmax(row_max_new, tTR_rAcc[i])

        # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3)
        if cutlass.const_expr(self.warps_in_n == 2):
            common_params.smem_exchange[tidx] = row_max_new
            self.softmax_exchange_sync_bar.wait()
            row_max_new = cute.arch.fmax(
                row_max_new,
                common_params.smem_exchange[
                    (tidx + 64) % (self.num_compute_warps * self.threads_per_warp)
                ],
            )

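        # All exponentials below are in base 2: softmax_scale_log2 = softmax_scale * log2(e),
        # so exp(x * softmax_scale) == exp2(x * softmax_scale_log2), and exp2 maps
        # directly onto the hardware EX2 instruction.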
        # find correction factor
        correction_factor = cute.math.exp2(
            (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True
        )
        no_correction = cutlass.Int32(row_max == row_max_new)
        # softmax
        fma_b = (softmax_params.softmax_scale_log2, softmax_params.softmax_scale_log2)
        fma_c = (
            (0.0 - row_max_new) * softmax_params.softmax_scale_log2,
            (0.0 - row_max_new) * softmax_params.softmax_scale_log2,
        )

        for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
            tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.fma_packed_f32x2(
                (tTR_rAcc[i], tTR_rAcc[i + 1]), fma_b, fma_c
            )
            tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True)
            tTR_rAcc[i + 1] = cute.math.exp2(tTR_rAcc[i + 1], fastmath=True)

        tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype)

        # quantize
        tTR_rS.store(tTR_rAcc.load().to(self.q_dtype))

        # create sP
        sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)]
        sP_mk_view = cute.make_tensor(
            sP.iterator,
            cute.make_layout(
                (
                    (sP.shape[0][0], sP.shape[1]),
                    (sP.shape[0][1], sP.shape[2], sP.shape[3]),
                ),
                stride=(
                    (sP.stride[0][0], sP.stride[1]),
                    (sP.stride[0][1], sP.stride[2], sP.stride[3]),
                ),
            ),
        )
        # change to a position-independent swizzle layout (PISL)
        sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None)
        swizzle_bits = (
            int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1
        )
        swizzle_base = 3 if self.q_dtype.width == 16 else 4
        sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3)
        sP_mk_view = cute.make_tensor(
            sP_wo_swizzle_iter,
            cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout),
        )
        universal_copy_bits = 128
        smem_copy_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self.q_dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy)
        smem_thr_copy = smem_tiled_copy.get_slice(tidx)
        rP_copy_view = smem_thr_copy.retile(tTR_rS)
        sP_copy_view = smem_thr_copy.partition_D(sP_mk_view)
        cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view)

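        # Worked example for swizzle_bits above: with a 16-bit P dtype and, e.g.,
        # mma_pv_tiler[2] == 64, the term is 64 * 16 // 8 // 32 = 4, so
        # swizzle_bits = log2(4) + 1 = 3.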
        # row_sum, using `add_packed_f32x2` to reduce the number of instructions
        row_sum = row_sum * correction_factor
        row_sum_vec = (0.0, 0.0)
        for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
            row_sum_vec = cute.arch.add_packed_f32x2(
                row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1])
            )
        row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum

        # fence between tmem load and mma s
        cute.arch.fence_view_async_tmem_load()
        # fence between smem store and mma o
        cute.arch.fence_view_async_shared()

        softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state)
        softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state)
        mma_s_consumer_state.advance()
        p_mma_producer_state.advance()

        # store correction factor/row_sum/row_max to tmem for correction warp
        common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state)
        # pad for 4x32b
        corr_layout = cute.make_layout(
            (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage),
            stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4),
        )
        tCor = cute.make_tensor(
            common_params.tmem_ptr + self.correction_factor_offset,
            corr_layout,
        )
        cCor = cute.make_identity_tensor(tCor.shape)
        corr_tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype
        )
        corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor)
        corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx)
        cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor)
        tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor)
        rCor = cute.make_fragment_like(
            cCor_for_copy[None, None, None, 0], self.acc_dtype
        )
        rCor_int = cute.make_tensor(
            cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout
        )
        rCor[0] = row_sum
        rCor[1] = row_max_new
        rCor[2] = correction_factor
        rCor_int[3] = no_correction

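        # Each compute thread publishes four 32-bit values per row to tmem:
        # lane 0 = row_sum, lane 1 = row_max, lane 2 = correction_factor, and
        # lane 3 = the integer no_correction flag (hence the pad-to-4x32b layout above).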
        cute.copy(
            corr_tmem_store_tiled_copy,
            rCor,
            tCor_for_copy[None, None, None, p_cor_producer_state.index],
        )
        # fence between tmem store and correction warp
        cute.arch.fence_view_async_tmem_store()
        common_params.p_cor_pipeline.producer_commit(p_cor_producer_state)
        p_cor_producer_state.advance()

        return (
            mma_s_consumer_state,
            p_mma_producer_state,
            p_cor_producer_state,
            row_max_new,
            row_sum,
            correction_factor,
        )

    @cute.jit
    def _tmem_load_partition(
        self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int
    ) -> tuple[
        cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor
    ]:
        """Tensor memory load partition for rescale and epilogue.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param tiled_mma_pv: The tiled mma pv
        :type tiled_mma_pv: cute.TiledMma
        :param iter_n: The iteration index along N
        :type iter_n: int

        :return: The tmem load tiled copy, the accumulator tensor, and the partitioned source, destination, coordinate, and register tensors
        :rtype: tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]
        """

        tOtO_shape = tiled_mma_pv.partition_shape_C(
            cute.select(self.mma_pv_tiler, mode=[0, 1])
        )
        tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
        tOtO_layout = cute.append(
            tOtO.layout,
            cute.make_layout(
                common_params.L // self.mma_pv_tiler[1],
                stride=self.mma_pv_tiler[1] // self.warps_in_n,
            ),
        )
        tOtO = cute.make_tensor(
            common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout
        )
        tOtO = tOtO[None, None, None, iter_n]

        tAcc = tOtO[(None, None), 0, 0]

        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype
        )
        tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)
        tmem_load_thr_copy = tmem_load_tiled_copy.get_slice(
            common_params.tidx % (self.num_compute_warps * self.threads_per_warp)
        )

        cta_pv_tiler = (
            self.mma_pv_tiler[0] // self.cluster_shape_mnk[0],
            self.mma_pv_tiler[1],
            self.mma_pv_tiler[2],
        )
        # Flatten divide and partition global tensors for O
        cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1])

        gO = None
        if cutlass.const_expr(common_params.mAccO is not None):
            gO = cute.local_tile(
                common_params.mAccO[None, common_params.blk_coord[3], None, None],
                cta_pv_tiler_mn,
                (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
            )
            cO = cute.local_tile(
                cute.make_identity_tensor(
                    common_params.mAccO[
                        None, common_params.blk_coord[3], None, None
                    ].shape
                ),
                cta_pv_tiler_mn,
                (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
            )
        else:
            gO = cute.local_tile(
                common_params.mO,
                cta_pv_tiler_mn,
                (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
            )
            cO = cute.local_tile(
                cute.make_identity_tensor(common_params.mO.shape),
                cta_pv_tiler_mn,
                (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
            )
        tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc)
        tTR_gO = tmem_load_thr_copy.partition_D(gO)
        tTR_cO = tmem_load_thr_copy.partition_D(cO)
        tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype)
        return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc

    def get_correction_factor(
        self,
        common_params: SimpleNamespace,
        p_cor_consumer_state: pipeline.PipelineState,
    ) -> tuple[
        pipeline.PipelineState,
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Int32,
    ]:
        """Get the correction factor from the P correction consumer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param p_cor_consumer_state: The P correction consumer state
        :type p_cor_consumer_state: pipeline.PipelineState

        :return: The P correction consumer state, the row_sum, the row_max, the correction factor, and the no_correction flag
        :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32]
        """
        common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state)
        tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp)
        # load correction factor
        _, tAcc, _, _, _, _ = self._tmem_load_partition(
            common_params, common_params.tiled_mma_pv, 0
        )
        corr_layout = cute.make_layout(
            (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage),
            stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4),
        )
        tCor = cute.make_tensor(
            common_params.tmem_ptr + self.correction_factor_offset, corr_layout
        )
        cCor = cute.make_identity_tensor(tCor.shape)
        corr_tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype
        )
        corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor)
        corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx)
        tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor)
        cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor)
        rCor = cute.make_fragment_like(
            cCor_for_copy[None, None, None, 0], self.acc_dtype
        )
        rCor_int = cute.make_tensor(
            cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout
        )
        cute.copy(
            corr_tmem_load_tiled_copy,
            tCor_for_copy[None, None, None, p_cor_consumer_state.index],
            rCor,
        )
        row_sum = rCor[0]
        row_max = rCor[1]
        correction_factor = rCor[2]
        no_correction = rCor_int[3]

        common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state)
        p_cor_consumer_state.advance()
        return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction

    @cute.jit
    def rescale(
        self,
        common_params: SimpleNamespace,
        mma_o_consumer_state: pipeline.PipelineState,
        correction_factor: cutlass.Float32,
        no_correction: cutlass.Int32,
    ) -> pipeline.PipelineState:
        """Rescale for one k-tile. Updates the related pipeline state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param mma_o_consumer_state: The mma o consumer state
        :type mma_o_consumer_state: pipeline.PipelineState
        :param correction_factor: The correction factor
        :type correction_factor: cutlass.Float32
        :param no_correction: Flag indicating that no correction is needed (the row max did not change)
        :type no_correction: cutlass.Int32

        :return: The MMA o consumer state
        :rtype: pipeline.PipelineState
        """

        common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state)
        # skip the round trip through tmem when the whole warp agrees no correction is needed
        skip_correction = cute.arch.vote_all_sync(no_correction == 1)
        if not skip_correction:
            for iter_n in cutlass.range_constexpr(self.iterations_pv_n):
                # tmem load tiled copy and partition results.
                tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = (
                    self._tmem_load_partition(
                        common_params, common_params.tiled_mma_pv, iter_n
                    )
                )

                # tmem store tiled copy
                tmem_store_atom = cute.make_copy_atom(
                    tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype
                )
                tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc)

                # load o
                cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc)
                # rescale, using `mul_packed_f32x2` to reduce the number of instructions
                for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
                    tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2(
                        (
                            tTR_rAcc[i],
                            tTR_rAcc[i + 1],
                        ),
                        (correction_factor, correction_factor),
                    )

                # store o to tensor memory for next k tile
                cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc)

        cute.arch.fence_view_async_tmem_store()
        common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state)
        mma_o_consumer_state.advance()

        return mma_o_consumer_state

    @cute.jit
    def epilogue(
        self,
        common_params: SimpleNamespace,
        epilogue_params: SimpleNamespace,
        mma_o_consumer_state: pipeline.PipelineState,
        row_sum: cutlass.Float32,
        row_max: cutlass.Float32,
    ) -> pipeline.PipelineState:
        """Epilogue for one output tile. Updates the related pipeline state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param epilogue_params: The epilogue parameters
        :type epilogue_params: SimpleNamespace
        :param mma_o_consumer_state: The mma o consumer state
        :type mma_o_consumer_state: pipeline.PipelineState
        :param row_sum: The row sum
        :type row_sum: cutlass.Float32
        :param row_max: The row max
        :type row_max: cutlass.Float32

        :return: The MMA o consumer state
        :rtype: pipeline.PipelineState
        """
        # mma_o pipeline consumer wait
        common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state)

        tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp)

        # exchange row_sum between warps (0, 1) and (2, 3)
        if cutlass.const_expr(self.warps_in_n == 2):
            common_params.smem_exchange[tidx] = row_sum
            self.epilogue_exchange_sync_bar.wait()
            # (64, 2)
            row_sum = (
                row_sum
                + common_params.smem_exchange[
                    (tidx + 64) % (self.num_compute_warps * self.threads_per_warp)
                ]
            )
        for iter_n in cutlass.range_constexpr(self.iterations_pv_n):
            # tmem load tiled copy and partition results.
            tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = (
                self._tmem_load_partition(
                    common_params, common_params.tiled_mma_pv, iter_n
                )
            )

            # load o
            cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc)

            # apply output scale and normalize by row_sum
            for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
                tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2(
                    (tTR_rAcc[i], tTR_rAcc[i + 1]),
                    (
                        epilogue_params.output_scale * cute.arch.rcp_approx(row_sum),
                        epilogue_params.output_scale * cute.arch.rcp_approx(row_sum),
                    ),
                )

            # store o to global memory
            tR2G_rO_src = None
            tR2G_rO_dst = tTR_gO
            if cutlass.const_expr(common_params.mAccO is None):
                tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype)
                # using final output dtype for o
                tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype))
            else:
                # using accumulate dtype for o
                tR2G_rO_src = tTR_rAcc

            if cute.elem_less(tTR_cO[0][0], common_params.H):
                cute.autovec_copy(tR2G_rO_src, tR2G_rO_dst)

        # store the lse to global memory
        cta_pv_tiler = (
            self.mma_pv_tiler[0] // self.cluster_shape_mnk[0],
            self.mma_pv_tiler[1],
            self.mma_pv_tiler[2],
        )
        gLSE = None
        cLSE = None
        if cutlass.const_expr(epilogue_params.mAccLSE is None):
            gLSE = cute.local_tile(
                epilogue_params.mLSE,
                (cta_pv_tiler[0], 1, 1),
                (
                    common_params.blk_coord[0],
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
                (1, None, 1),
            )
            cLSE = cute.local_tile(
                cute.make_identity_tensor(epilogue_params.mLSE.shape),
                (cta_pv_tiler[0], 1, 1),
                (
                    common_params.blk_coord[0],
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
                (1, None, 1),
            )
        else:
            gLSE = cute.local_tile(
                epilogue_params.mAccLSE[None, common_params.blk_coord[3], None],
                (cta_pv_tiler[0], 1, 1),
                (
                    common_params.blk_coord[0],
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
                (1, None, 1),
            )
            cLSE = cute.local_tile(
                cute.make_identity_tensor(
                    epilogue_params.mAccLSE[
                        None, common_params.blk_coord[3], None
                    ].shape
                ),
                (cta_pv_tiler[0], 1, 1),
                (
                    common_params.blk_coord[0],
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
                (1, None, 1),
            )
        lse = (
            cute.math.log2(row_sum, fastmath=True)
            + epilogue_params.softmax_scale_log2 * row_max
        )
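        # Note: lse is in base-2 units, matching the reference check:
        # lse = log2(row_sum) + softmax_scale_log2 * row_max
        #     = log2(sum_j exp(s_j * softmax_scale)).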
        if cutlass.const_expr(self.warps_in_n == 2):
            if cute.elem_less(cLSE[tidx][0], common_params.H):
                gLSE[tidx] = lse

        cute.arch.fence_view_async_tmem_load()
        common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state)
        mma_o_consumer_state.advance()

        return mma_o_consumer_state

    def make_and_init_load_qkv_pipeline(
        self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count, is_cpasync
    ) -> pipeline.PipelineTmaUmma | pipeline.PipelineAsyncUmma:
        """Create and initialize the load qkv pipeline.

        :param load_qkv_mbar_ptr: The load qkv mbar pointer
        :type load_qkv_mbar_ptr: cute.Tensor
        :param cta_layout_vmnk: The cta layout vmnk
        :type cta_layout_vmnk: tuple[int, int, int]
        :param load_stages: The load stages
        :type load_stages: list[int]
        :param tx_count: The tx count
        :type tx_count: int
        :param is_cpasync: Whether to use cpasync
        :type is_cpasync: bool

        :return: The load qkv pipeline: a ``PipelineAsyncUmma`` when ``is_cpasync`` is True, otherwise a ``PipelineTmaUmma``
        :rtype: pipeline.PipelineTmaUmma | pipeline.PipelineAsyncUmma
        """
        if is_cpasync:
            load_qkv_producer_group = pipeline.CooperativeGroup(
                pipeline.Agent.Thread,
                len(self.load_cp_async_warp_ids)
                * self.threads_per_warp
                * self.cluster_shape_mnk[0],
            )
            load_qkv_consumer_group = pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len([self.mma_warp_id])
            )
            return pipeline.PipelineAsyncUmma.create(
                barrier_storage=load_qkv_mbar_ptr,
                num_stages=load_stages,
                producer_group=load_qkv_producer_group,
                consumer_group=load_qkv_consumer_group,
                cta_layout_vmnk=cta_layout_vmnk,
            )
        else:
            load_qkv_producer_group = pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len([self.load_tma_warp_id])
            )
            load_qkv_consumer_group = pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len([self.mma_warp_id])
            )
            return pipeline.PipelineTmaUmma.create(
                barrier_storage=load_qkv_mbar_ptr,
                num_stages=load_stages,
                producer_group=load_qkv_producer_group,
                consumer_group=load_qkv_consumer_group,
                tx_count=tx_count,
                cta_layout_vmnk=cta_layout_vmnk,
            )

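    # The CooperativeGroup sizes in the factories below reflect how many threads
    # arrive at each pipeline barrier: a single warp's worth for the TMA/MMA agents
    # versus all compute warps (across the cluster where applicable) for the
    # softmax and correction agents.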
    def make_and_init_mma_s_pipeline(
        self, mma_s_mbar_ptr, cta_layout_vmnk
    ) -> pipeline.PipelineUmmaAsync:
        """Create and initialize the mma s pipeline.

        :param mma_s_mbar_ptr: The mma s mbar pointer
        :type mma_s_mbar_ptr: cute.Tensor
        :param cta_layout_vmnk: The cta layout vmnk
        :type cta_layout_vmnk: tuple[int, int, int]

        :return: The mma s pipeline
        :rtype: pipeline.PipelineUmmaAsync
        """

        mma_s_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_warp_id])
        )
        consumer_thread_size = (
            self.threads_per_warp
            * len(self.compute_warp_ids)
            * self.cluster_shape_mnk[0]
        )
        mma_s_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            consumer_thread_size,
        )
        return pipeline.PipelineUmmaAsync.create(
            barrier_storage=mma_s_mbar_ptr,
            num_stages=self.mma_s_stage,
            producer_group=mma_s_producer_group,
            consumer_group=mma_s_consumer_group,
            cta_layout_vmnk=cta_layout_vmnk,
        )

    def make_and_init_p_mma_pipeline(
        self, p_mma_mbar_ptr, cta_layout_vmnk
    ) -> pipeline.PipelineAsyncUmma:
        """Create and initialize the p mma pipeline.

        :param p_mma_mbar_ptr: The p mma mbar pointer
        :type p_mma_mbar_ptr: cute.Tensor
        :param cta_layout_vmnk: The cta layout vmnk
        :type cta_layout_vmnk: tuple[int, int, int]

        :return: The p mma pipeline
        :rtype: pipeline.PipelineAsyncUmma
        """

        producer_thread_size = (
            self.threads_per_warp
            * len(self.compute_warp_ids)
            * self.cluster_shape_mnk[0]
        )
        p_mma_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            producer_thread_size,
        )
        p_mma_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_warp_id])
        )
        return pipeline.PipelineAsyncUmma.create(
            barrier_storage=p_mma_mbar_ptr,
            num_stages=self.p_mma_stage,
            producer_group=p_mma_producer_group,
            consumer_group=p_mma_consumer_group,
            cta_layout_vmnk=cta_layout_vmnk,
        )

    def make_and_init_p_cor_pipeline(
        self, p_cor_mbar_ptr
    ) -> pipeline.PipelineAsync:
        """Create and initialize the p correction pipeline.

        :param p_cor_mbar_ptr: The p correction mbar pointer
        :type p_cor_mbar_ptr: cute.Tensor

        :return: The p correction pipeline
        :rtype: pipeline.PipelineAsync
        """

        producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids)
        p_cor_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            producer_thread_size,
        )
        p_cor_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            producer_thread_size,
        )
        return pipeline.PipelineAsync.create(
            barrier_storage=p_cor_mbar_ptr,
            num_stages=self.p_cor_stage,
            producer_group=p_cor_producer_group,
            consumer_group=p_cor_consumer_group,
        )

    def make_and_init_mma_o_pipeline(
        self, mma_o_mbar_ptr, cta_layout_vmnk
    ) -> pipeline.PipelineUmmaAsync:
        """Create and initialize the mma o pipeline.

        :param mma_o_mbar_ptr: The mma o mbar pointer
        :type mma_o_mbar_ptr: cute.Tensor
        :param cta_layout_vmnk: The cta layout vmnk
        :type cta_layout_vmnk: tuple[int, int, int]

        :return: The mma o pipeline
        :rtype: pipeline.PipelineUmmaAsync
        """

        mma_o_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_warp_id])
        )
        consumer_thread_size = (
            self.threads_per_warp
            * len(self.compute_warp_ids)
            * self.cluster_shape_mnk[0]
        )
        mma_o_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            consumer_thread_size,
        )
        return pipeline.PipelineUmmaAsync.create(
            barrier_storage=mma_o_mbar_ptr,
            num_stages=self.mma_o_stage,
            producer_group=mma_o_producer_group,
            consumer_group=mma_o_consumer_group,
            cta_layout_vmnk=cta_layout_vmnk,
        )

    def make_and_init_load_pt_pipeline(self, load_pt_mbar_ptr) -> pipeline.PipelineCpAsync:
        """Create and initialize the load page table pipeline.

        :param load_pt_mbar_ptr: The load page table mbar pointer
        :type load_pt_mbar_ptr: cute.Tensor

        :return: The load page table pipeline
        :rtype: pipeline.PipelineCpAsync
        """
        load_pt_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            self.threads_per_warp * len([self.load_pt_warp_id]),
        )
        load_pt_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            self.threads_per_warp * len(self.load_cp_async_warp_ids),
        )
        return pipeline.PipelineCpAsync.create(
            barrier_storage=load_pt_mbar_ptr,
            num_stages=self.load_pt_stage,
            producer_group=load_pt_producer_group,
            consumer_group=load_pt_consumer_group,
        )

    @staticmethod
    def _compute_grid(
        o: cute.Tensor,
        split_kv: cutlass.Int32,
        cluster_shape_mnk: Tuple[int, int, int],
        max_active_clusters: int,
        is_persistent: bool,
    ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]:
        """Compute the grid shape for the output tensor O.

        :param o: The output tensor O
        :type o: cute.Tensor
        :param split_kv: The number of KV splits
        :type split_kv: cutlass.Int32
        :param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
        :type cluster_shape_mnk: tuple[int, int, int]
        :param max_active_clusters: Maximum number of clusters that can be active at once.
        :type max_active_clusters: int
        :param is_persistent: Whether to use the persistent kernel optimization
        :type is_persistent: bool

        :return: Tile scheduler parameters and grid shape.
        :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]]
        """
        o_shape = o.shape
        tile_sched_params = create_mla_static_tile_scheduler_params(
            is_persistent,
            cute.size(o_shape[2]),
            cluster_shape_mnk,
            split_kv,
        )
        grid = MLAStaticTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )

        return tile_sched_params, grid

    @staticmethod
    def get_workspace_size(
        H: int,
        D: int,
        B: int,
        split_kv: int,
        acc_dtype: Type[cutlass.Numeric],
    ) -> int:
        """Get the extra workspace (device memory) size for the MLA kernel when split_kv is not 1.

        :param H: The number of query heads
        :type H: int
        :param D: The latent dimension of the output
        :type D: int
        :param B: The batch size
        :type B: int
        :param split_kv: The number of KV splits
        :type split_kv: int
        :param acc_dtype: The accumulator data type
        :type acc_dtype: Type[cutlass.Numeric]

        :return: The workspace size in bytes for the MLA kernel
        :rtype: int
        """
        if split_kv == 1:
            return 0
        return B * H * split_kv * (D + 1) * acc_dtype.width // 8

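    # Worked example of the formula above: B=1, H=128, D=512, split_kv=2 with
    # Float32 accumulation needs 1 * 128 * 2 * (512 + 1) * 4 = 525,312 bytes
    # for the per-split partial O and LSE.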
    @cute.jit
    def initialize_workspace(
        self,
        H: cutlass.Int32,
        D: cutlass.Int32,
        B: cutlass.Int32,
        split_kv: cutlass.Int32,
        acc_dtype: Type[cutlass.Numeric],
        workspace: cute.Tensor,
    ) -> tuple[cute.Tensor, cute.Tensor]:
        """Initialize the workspace for the MLA kernel. Construct the intermediate tensors
        acc_o and acc_lse.

        :param H: The number of query heads
        :type H: cutlass.Int32
        :param D: The latent dimension of the output
        :type D: cutlass.Int32
        :param B: The batch size
        :type B: cutlass.Int32
        :param split_kv: The number of KV splits
        :type split_kv: cutlass.Int32
        :param acc_dtype: The accumulator data type
        :type acc_dtype: Type[cutlass.Numeric]
        :param workspace: The workspace tensor
        :type workspace: cute.Tensor

        :return: The intermediate tensors acc_o and acc_lse
        :rtype: tuple[cute.Tensor, cute.Tensor]
        """
        acc_o, acc_lse = None, None
        if cutlass.const_expr(workspace is not None):
            align = 128 // self.q_dtype.width
            acc_o_layout = cute.make_layout(
                (H, split_kv, D, B),
                stride=(
                    cute.assume(split_kv * D, align),
                    cute.assume(D, align),
                    1,
                    cute.assume(H * split_kv * D, align),
                ),
            )
            acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype)
            acc_o = cute.make_tensor(acc_o_iter, acc_o_layout)
            acc_lse_layout = cute.make_layout(
                (H, split_kv, B), stride=(split_kv, 1, H * split_kv)
            )
            acc_lse_iter = cute.recast_ptr(
                workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8,
                dtype=acc_dtype,
            )
            acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout)
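            # Workspace memory map (acc_dtype elements, back to back):
            #   [acc_o: (H, split_kv, D, B)] then [acc_lse: (H, split_kv, B)],
            # with acc_lse starting at byte offset
            # cosize(acc_o_layout) * acc_dtype.width // 8.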
        return acc_o, acc_lse

    @staticmethod
    def can_implement(
        B: int,
        K: int,
        H: int,
        L: int,
        R: int,
        in_dtype: Type[cutlass.Numeric],
        out_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        lse_dtype: Type[cutlass.Numeric],
        mma_qk_tiler_mn: Tuple[int, int],
        mma_pv_tiler_mn: Tuple[int, int],
        split_kv: int,
        is_persistent: bool,
        is_cpasync: bool,
        is_var_seq: bool,
        is_var_split_kv: bool,
        use_page_table: bool,
        page_size: int,
    ) -> bool:
        """Check if the MLA kernel can be implemented for the given configuration.

        :param B: The batch size
        :type B: int
        :param K: The sequence length of the KV cache
        :type K: int
        :param H: The number of query heads
        :type H: int
        :param L: The latent dimension
        :type L: int
        :param R: The rope dimension
        :type R: int
        :param in_dtype: The data type of the input tensors
        :type in_dtype: Type[cutlass.Numeric]
        :param out_dtype: The data type of the output tensor
        :type out_dtype: Type[cutlass.Numeric]
        :param acc_dtype: The data type of the accumulator
        :type acc_dtype: Type[cutlass.Numeric]
        :param lse_dtype: The data type of the log-sum-exp
        :type lse_dtype: Type[cutlass.Numeric]
        :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication
        :type mma_qk_tiler_mn: Tuple[int, int]
        :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication
        :type mma_pv_tiler_mn: Tuple[int, int]
        :param split_kv: The number of KV splits
        :type split_kv: int
        :param is_persistent: Whether to use persistent kernel optimization
        :type is_persistent: bool
        :param is_cpasync: Whether to use cpasync
        :type is_cpasync: bool
        :param is_var_seq: Whether to use variable sequence length
        :type is_var_seq: bool
        :param is_var_split_kv: Whether to use variable split_kv
        :type is_var_split_kv: bool
        :param use_page_table: Whether to use page table
        :type use_page_table: bool
        :param page_size: The page size of the page table
        :type page_size: int

        :return: Whether the MLA kernel can be implemented
        :rtype: bool
        """
        if L != 512 or R != 64:
            return False
        if in_dtype not in [cutlass.Float8E4M3FN, cutlass.Float16]:
            return False
        if out_dtype not in [cutlass.Float8E4M3FN, cutlass.Float16]:
            return False
        if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32:
            return False
        if is_cpasync:
            if not use_page_table:
                return False
            # the cp.async path requires a power-of-two page size
            if page_size & (page_size - 1) != 0:
                return False
            if page_size > mma_qk_tiler_mn[1]:
                return False
        else:
            if use_page_table and page_size != mma_qk_tiler_mn[1]:
                return False
        if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128:
            return False
        if is_var_split_kv and (not use_page_table or not is_var_seq):
            return False
        if is_var_seq and not use_page_table:
            return False
        if not is_cpasync and (H > 128 or (H < 128 and split_kv != 1)):
            return False
        if is_cpasync and H != 128:
            return False
        if K <= 0:
            return False
        return True


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def run(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    latent_dim: int,
    rope_dim: int,
    in_dtype: Type[cutlass.Numeric],
    out_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    lse_dtype: Type[cutlass.Numeric],
    mma_qk_tiler_mn: Tuple[int, int],
    mma_pv_tiler_mn: Tuple[int, int],
    split_kv: int,
    is_persistent: bool,
    is_cpasync: bool,
    is_var_seq: bool,
    is_var_split_kv: bool,
    use_page_table: bool,
    page_size: int,
    softmax_scale: float,
    output_scale: float,
    tolerance: float,
    warmup_iterations: int,
    iterations: int,
    skip_ref_check: bool,
    use_cold_l2: bool,
    **kwargs,
):
    """Execute Multi-Head Latent Attention (MLA) on the Blackwell architecture and validate results.

    This function creates random input tensors for query latent/rope, compressed latent/rope, and value,
    then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters,
    page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference
    implementation or run multiple times for performance measurement.

    :param batch_size: Batch size
    :type batch_size: int
    :param seq_len: Sequence length
    :type seq_len: int
    :param num_heads: Number of heads
    :type num_heads: int
    :param latent_dim: Dimension of query/compressed latent
    :type latent_dim: int
    :param rope_dim: Dimension of query/compressed rope
    :type rope_dim: int
    :param in_dtype: Input data type for query/compressed latent/rope tensors
    :type in_dtype: Type[cutlass.Numeric]
    :param out_dtype: Output data type for attention output
    :type out_dtype: Type[cutlass.Numeric]
    :param acc_dtype: Accumulator data type for query-key matrix multiplication
    :type acc_dtype: Type[cutlass.Numeric]
    :param lse_dtype: Accumulator data type for log-sum-exp
    :type lse_dtype: Type[cutlass.Numeric]
    :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication
    :type mma_qk_tiler_mn: Tuple[int, int]
    :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication
    :type mma_pv_tiler_mn: Tuple[int, int]
    :param split_kv: Split key-value
    :type split_kv: int
    :param is_persistent: Whether to use persistent kernel optimization
    :type is_persistent: bool
    :param is_cpasync: Whether to use cpasync
    :type is_cpasync: bool
    :param is_var_seq: Whether to use variable sequence length
    :type is_var_seq: bool
    :param is_var_split_kv: Whether to use variable split_kv
    :type is_var_split_kv: bool
    :param use_page_table: Whether to use page table
    :type use_page_table: bool
    :param page_size: Page size of the page table
    :type page_size: int
    :param softmax_scale: Attention score scaling factor
    :type softmax_scale: float
    :param output_scale: Output scaling factor
    :type output_scale: float
    :param tolerance: Maximum acceptable error for validation
    :type tolerance: float
    :param warmup_iterations: Number of warmup iterations
    :type warmup_iterations: int
    :param iterations: Number of iterations to run for performance testing
    :type iterations: int
    :param skip_ref_check: Skip validation against reference implementation
    :type skip_ref_check: bool
    :param use_cold_l2: Whether to use cold L2 cache
    :type use_cold_l2: bool

    :raises ValueError: If input shapes are incompatible or head dimension is unsupported
    :raises RuntimeError: If GPU is unavailable for computation
    """

print("Running Blackwell MLA test with:")
|
|
print(f" batch_size: {batch_size}")
|
|
print(f" seq_len: {seq_len}")
|
|
print(f" num_heads: {num_heads}")
|
|
print(f" latent_dim: {latent_dim}")
|
|
print(f" rope_dim: {rope_dim}")
|
|
print(f" in_dtype: {in_dtype}")
|
|
print(f" out_dtype: {out_dtype}")
|
|
print(f" acc_dtype: {acc_dtype}")
|
|
print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}")
|
|
print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}")
|
|
print(f" split_kv: {split_kv}")
|
|
print(f" is_persistent: {is_persistent}")
|
|
print(f" is_cpasync: {is_cpasync}")
|
|
print(f" is_var_seq: {is_var_seq}")
|
|
print(f" is_var_split_kv: {is_var_split_kv}")
|
|
print(f" use_page_table: {use_page_table}")
|
|
print(f" page_size: {page_size}")
|
|
print(f" softmax_scale: {softmax_scale}")
|
|
print(f" output_scale: {output_scale}")
|
|
print(f" tolerance: {tolerance}")
|
|
print(f" warmup_iterations: {warmup_iterations}")
|
|
print(f" iterations: {iterations}")
|
|
print(f" skip_ref_check: {skip_ref_check}")
|
|
print(f" use_cold_l2: {use_cold_l2}")
|
|
|
|
    # Validate the requested configuration, then prepare input tensors
    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    if not BlackwellMultiHeadLatentAttentionForward.can_implement(
        batch_size,
        seq_len,
        num_heads,
        latent_dim,
        rope_dim,
        in_dtype,
        out_dtype,
        acc_dtype,
        lse_dtype,
        mma_qk_tiler_mn,
        mma_pv_tiler_mn,
        split_kv,
        is_persistent,
        is_cpasync,
        is_var_seq,
        is_var_split_kv,
        use_page_table,
        page_size,
    ):
        raise TypeError(
            f"Unsupported testcase {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_cpasync}, {is_var_seq}, {is_var_split_kv}, {use_page_table}, {page_size}"
        )

    torch.manual_seed(1111)

    def create_data_tensor(
        B,
        HK,
        D,
        dtype,
        is_dynamic_layout=True,
        page_table=None,
        cache_seqs=None,
        is_lse=False,
    ):
        shape = (B, HK, D)
        if page_table is not None:
            if cache_seqs is not None:
                max_seq_len = torch.max(cache_seqs)
                shape = (B * ceil_div(max_seq_len, page_size), page_size, D)
            else:
                shape = (B * ceil_div(HK, page_size), page_size, D)

        permute_order = (1, 2, 0)
        stride_order = (2, 0, 1)
        leading_dim = 1
        if is_lse:
            shape = (B, HK)
            permute_order = (1, 0)
            stride_order = (1, 0)
            leading_dim = 0

        init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2)

        torch_dtype = (
            cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8
        )

        # Create dtype torch tensor (cpu)
        torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
            shape,
            torch_dtype,
            permute_order=permute_order,
            init_type=cutlass.torch.TensorInitType.RANDOM,
            init_config=init_config,
        )

        # Create dtype torch tensor (gpu)
        torch_tensor_gpu = torch_tensor_cpu.cuda()

        # Create f32 torch tensor (cpu)
        f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)

        # Create dtype cute tensor (gpu)
        # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
        cute_tensor = from_dlpack(
            torch_tensor_gpu, assumed_align=16, use_32bit_stride=True
        )
        cute_tensor.element_type = dtype
        if is_dynamic_layout:
            cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
            if not is_lse:
                cute_tensor = cute_tensor.mark_compact_shape_dynamic(
                    mode=leading_dim,
                    stride_order=stride_order,
                    divisibility=(128 // dtype.width),
                )

        cute_tensor = cutlass_torch.convert_cute_tensor(
            f32_torch_tensor,
            cute_tensor,
            dtype,
            is_dynamic_layout=is_dynamic_layout,
        )

        return f32_torch_tensor, cute_tensor, torch_tensor_gpu

    def create_cache_seqs(batch_size, seq_len, is_var_seq):
        cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len
        cache_seqs_gpu = cache_seqs_ref.cuda()
        # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
        cache_seqs = from_dlpack(
            cache_seqs_gpu, assumed_align=16, use_32bit_stride=True
        ).mark_layout_dynamic()
        if is_var_seq:
            max_seq_len = seq_len
            min_seq_len = int(seq_len * 0.8)
            cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor(
                (batch_size,),
                torch.int32,
                init_type=cutlass.torch.TensorInitType.RANDOM,
                init_config=cutlass.torch.RandomInitConfig(
                    min_val=min_seq_len, max_val=max_seq_len + 1
                ),
            )
            cache_seqs_gpu = cache_seqs_ref.cuda()
            # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
            cache_seqs = from_dlpack(
                cache_seqs_gpu,
                assumed_align=16,
                use_32bit_stride=True,
            ).mark_layout_dynamic()
        return cache_seqs_ref, cache_seqs, cache_seqs_gpu

    def create_page_table(batch_size, seq_len, is_var_seq, use_page_table, page_size):
        page_table_ref, page_table, page_table_gpu = None, None, None
        if use_page_table:
            max_seq_len = seq_len if not is_var_seq else torch.max(cache_seqs_ref)
            page_count = ceil_div(max_seq_len, page_size)
            page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32)
            # use a transposed index for the page table so every value stays in bounds
            # of `batch_size * seq_len_block`; in practice the values could be any
            # valid page indices. This setting is only for testing purposes.
            for b in range(batch_size):
                for j in range(page_count):
                    page_table_ref[b, j] = b + j * batch_size
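            # For example, with batch_size=2 and page_count=3 the rows are
            # [0, 2, 4] and [1, 3, 5]: batch b owns every batch_size-th physical page.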
            page_table_gpu = page_table_ref.permute(1, 0).cuda()
            # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
            page_table = from_dlpack(
                page_table_gpu, assumed_align=16, use_32bit_stride=True
            ).mark_layout_dynamic(leading_dim=0)
        return page_table_ref, page_table, page_table_gpu

    def create_block_split_kvs(
        batch_size,
        split_kv,
        cache_seqs_ref,
        is_var_split_kv,
        mma_qk_tiler_mn,
        cluster_shape_mnk,
        max_active_clusters,
    ):
        block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None
        # If split_kv is not set explicitly, derive it automatically from the cache sequence lengths.
        if is_var_split_kv:
            block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32)
            for b in range(batch_size):
                block_split_kvs_ref[b] = (
                    BlackwellMultiHeadLatentAttentionForward.get_split_kv(
                        batch_size,
                        cache_seqs_ref[b].item(),
                        mma_qk_tiler_mn,
                        max_active_clusters * cluster_shape_mnk[0],
                    )
                )
            split_kv = torch.max(block_split_kvs_ref).item()
            block_split_kvs_gpu = block_split_kvs_ref.cuda()
            # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
            block_split_kvs = from_dlpack(
                block_split_kvs_gpu, assumed_align=16, use_32bit_stride=True
            ).mark_layout_dynamic()
        elif split_kv <= 0:
            split_kv = BlackwellMultiHeadLatentAttentionForward.get_split_kv(
                batch_size,
                cache_seqs_ref[0].item(),
                mma_qk_tiler_mn,
                max_active_clusters * cluster_shape_mnk[0],
            )
        return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu

    def create_workspace(num_heads, latent_dim, batch_size, split_kv, acc_dtype):
        workspace_size = BlackwellMultiHeadLatentAttentionForward.get_workspace_size(
            num_heads,
            latent_dim,
            batch_size,
            split_kv,
            acc_dtype,
        )

        workspace, workspace_torch = None, None
        if workspace_size > 0:
            workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda()
            # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
            workspace = from_dlpack(
                workspace_torch, assumed_align=16, use_32bit_stride=True
            )
        return workspace, workspace_torch

    cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs(
        batch_size, seq_len, is_var_seq
    )
    page_table_ref, page_table, page_table_torch = create_page_table(
        batch_size, seq_len, is_var_seq, use_page_table, page_size
    )
    cluster_shape_mnk = (2, 1, 1)
    hardware_info = utils.HardwareInfo()
    max_active_clusters = hardware_info.get_max_active_clusters(
        cluster_shape_mnk[0] * cluster_shape_mnk[1]
    )
    split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = (
        create_block_split_kvs(
            batch_size,
            split_kv,
            cache_seqs_ref,
            is_var_split_kv,
            mma_qk_tiler_mn,
            cluster_shape_mnk,
            max_active_clusters,
        )
    )

    q_latent_ref, q_latent, q_latent_torch = create_data_tensor(
        batch_size,
        num_heads,
        latent_dim,
        in_dtype,
        is_dynamic_layout=True,
    )
    q_rope_ref, q_rope, q_rope_torch = create_data_tensor(
        batch_size,
        num_heads,
        rope_dim,
        in_dtype,
        is_dynamic_layout=True,
    )

    c_latent_ref, c_latent, c_latent_torch = create_data_tensor(
        batch_size,
        seq_len,
        latent_dim,
        in_dtype,
        is_dynamic_layout=True,
        page_table=page_table,
        cache_seqs=cache_seqs_ref,
    )
    c_rope_ref, c_rope, c_rope_torch = create_data_tensor(
        batch_size,
        seq_len,
        rope_dim,
        in_dtype,
        is_dynamic_layout=True,
        page_table=page_table,
        cache_seqs=cache_seqs_ref,
    )
    o_ref, o, o_torch = create_data_tensor(
        batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True
    )
    lse_ref, lse, lse_torch = create_data_tensor(
        batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True
    )
    workspace, workspace_torch = create_workspace(
        num_heads, latent_dim, batch_size, split_kv, acc_dtype
    )

    mla = BlackwellMultiHeadLatentAttentionForward(
        acc_dtype,
        lse_dtype,
        mma_qk_tiler_mn,
        mma_pv_tiler_mn,
        max_active_clusters,
        is_persistent,
        is_cpasync,
        use_page_table,
        is_var_seq,
        is_var_split_kv,
    )

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    stream = cuda.CUstream(torch_stream.cuda_stream)

    # compile mla kernel
    compiled_mla = cute.compile(
        mla,
        q_latent,
        q_rope,
        c_latent,
        c_rope,
        page_table,
        o,
        lse,
        workspace,
        split_kv,
        cache_seqs,
        block_split_kvs,
        softmax_scale,
        output_scale,
        stream,
    )
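    # cute.compile specializes the kernel for the static shapes/dtypes/layouts of
    # the arguments above; the returned callable is then re-invoked with tensors
    # that share the same static description (as in the reference check and the
    # benchmark loop below).
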
    def torch_reference_mla(
        q_latent,
        q_rope,
        c_latent,
        c_rope,
        page_table,
        cache_seqs,
        softmax_scale=1.0,
        output_scale=1.0,
    ):
        # concat q_latent and q_rope and add a singleton sequence dimension for q
        q_ref = torch.cat([q_latent, q_rope], dim=1).permute(2, 0, 1).unsqueeze(2)
        # concat c_latent and c_rope and add a singleton head dimension for k and v
        if use_page_table:
            page_count = page_table_ref.shape[1]
            k_ref_paged = (
                torch.cat([c_latent, c_rope], dim=1)
                .permute(2, 0, 1)
                .reshape(batch_size * page_count, page_size, latent_dim + rope_dim)
            )
            v_ref_paged = c_latent.permute(2, 0, 1).reshape(
                batch_size * page_count, page_size, latent_dim
            )

            if is_var_seq:
                max_seq_len = torch.max(cache_seqs_ref)
            else:
                max_seq_len = seq_len

            # gather the paged K/V back into contiguous per-batch tensors
            k_ref = torch.index_select(
                k_ref_paged, 0, torch.flatten(page_table_ref)
            ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :]
            v_ref = torch.index_select(
                v_ref_paged, 0, torch.flatten(page_table_ref)
            ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :]
            # zero out entries past each batch's true cache length
            for b in range(batch_size):
                k_ref[b, :, cache_seqs_ref[b] :, :] = 0
                v_ref[b, :, cache_seqs_ref[b] :, :] = 0
        else:
            k_ref = torch.cat([c_latent, c_rope], dim=1).permute(2, 0, 1).unsqueeze(1)
            v_ref = c_latent.permute(2, 0, 1).unsqueeze(1)

        o_ref = F.scaled_dot_product_attention(
            q_ref,
            k_ref,
            v_ref,
            attn_mask=None,
            dropout_p=0.0,
            scale=softmax_scale,
            is_causal=False,
        )
        s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref)
        s_ref_max = torch.max(s_ref, dim=-1, keepdim=True).values
        softmax_scale_log2 = LOG2_E * softmax_scale
        s_ref_sum = torch.sum(
            torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True
        )
        lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum)
        lse_ref = lse_ref.squeeze(3).squeeze(2).permute(1, 0)
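        # lse_ref is kept in base-2 units to match the kernel's LSE output:
        #   m * scale * log2(e) + log2(sum_j exp2((s_j - m) * scale * log2(e)))
        # == log2(sum_j exp(s_j * scale)).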
        o_ref = o_ref * output_scale
        o_ref = o_ref.squeeze(2).permute(1, 2, 0)

        return o_ref, lse_ref

    if not skip_ref_check:
        # Execute kernel once for reference checking
        compiled_mla(
            q_latent,
            q_rope,
            c_latent,
            c_rope,
            page_table,
            o,
            lse,
            workspace,
            split_kv,
            cache_seqs,
            block_split_kvs,
            softmax_scale,
            output_scale,
            stream,
        )
        torch.cuda.synchronize()
        print("Verifying results...")
        # FP8 inputs need a looser tolerance
        if in_dtype == cutlass.Float8E4M3FN:
            tolerance = 0.13
        o_ref, lse_ref = torch_reference_mla(
            q_latent_ref,
            q_rope_ref,
            c_latent_ref,
            c_rope_ref,
            page_table,
            cache_seqs,
            softmax_scale,
            output_scale,
        )

        if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]:
            # convert o back to f32 for comparison
            o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like(
                torch.empty(*o_torch.shape, dtype=torch.float32),
                cutlass.Float32,
                is_dynamic_layout=True,
                assumed_align=16,
            )
            cute.testing.convert(o, o_fp32)
            o = o_fp32_torch.cpu()
            ref_fp8, _ = cutlass_torch.cute_tensor_like(
                torch.empty(*o_ref.permute(2, 0, 1).shape, dtype=torch.uint8).permute(
                    1, 2, 0
                ),
                out_dtype,
                is_dynamic_layout=True,
                assumed_align=16,
            )
            o_ref_gpu = o_ref.cuda()
            # Set use_32bit_stride to True for small problem sizes (cosize(layout) <= Int32_max) for better performance.
            o_ref_f32 = from_dlpack(
                o_ref_gpu, use_32bit_stride=True
            ).mark_layout_dynamic(leading_dim=1)

            # round-trip the reference through fp8 (f32 -> fp8 -> f32) so both sides share quantization error
            cute.testing.convert(o_ref_f32, ref_fp8)
            cute.testing.convert(ref_fp8, o_ref_f32)

            o_ref = o_ref_gpu.cpu()
        else:
            o = o_torch.cpu().to(torch.float32)
        lse = lse_torch.cpu()
        lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype))
        # Assert close results
        torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05)
        torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05)
        print("Results verified successfully!")

    def generate_tensors():
        _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len, is_var_seq)
        _, page_table, _ = create_page_table(
            batch_size, seq_len, is_var_seq, use_page_table, page_size
        )
        _split_kv, _, block_split_kvs, _ = create_block_split_kvs(
            batch_size,
            split_kv,
            cache_seqs_ref,
            is_var_split_kv,
            mma_qk_tiler_mn,
            cluster_shape_mnk,
            max_active_clusters,
        )

        _, q_latent, _ = create_data_tensor(
            batch_size,
            num_heads,
            latent_dim,
            in_dtype,
            is_dynamic_layout=True,
        )
        _, q_rope, _ = create_data_tensor(
            batch_size,
            num_heads,
            rope_dim,
            in_dtype,
            is_dynamic_layout=True,
        )

        _, c_latent, _ = create_data_tensor(
            batch_size,
            seq_len,
            latent_dim,
            in_dtype,
            is_dynamic_layout=True,
            page_table=page_table,
            cache_seqs=cache_seqs_ref,
        )
        _, c_rope, _ = create_data_tensor(
            batch_size,
            seq_len,
            rope_dim,
            in_dtype,
            is_dynamic_layout=True,
            page_table=page_table,
            cache_seqs=cache_seqs_ref,
        )
        _, o, _ = create_data_tensor(
            batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True
        )
        _, lse, _ = create_data_tensor(
            batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True
        )
        workspace, workspace_torch = create_workspace(
            num_heads, latent_dim, batch_size, _split_kv, acc_dtype
        )
        return testing.JitArguments(
            q_latent,
            q_rope,
            c_latent,
            c_rope,
            page_table,
            o,
            lse,
            workspace,
            _split_kv,
            cache_seqs,
            block_split_kvs,
            softmax_scale,
            output_scale,
            stream,
        )

    workspace_count = 1
    if use_cold_l2:
        one_workspace_bytes = (
            q_latent_torch.numel() * q_latent_torch.element_size()
            + q_rope_torch.numel() * q_rope_torch.element_size()
            + c_latent_torch.numel() * c_latent_torch.element_size()
            + c_rope_torch.numel() * c_rope_torch.element_size()
            + o_torch.numel() * o_torch.element_size()
            + lse_torch.numel() * lse_torch.element_size()
            + cache_seqs_torch.numel() * cache_seqs_torch.element_size()
        )
        if use_page_table:
            one_workspace_bytes += (
                page_table_torch.numel() * page_table_torch.element_size()
            )
        if is_var_split_kv:
            one_workspace_bytes += (
                block_split_kvs_torch.numel() * block_split_kvs_torch.element_size()
            )
        if workspace_torch is not None:
            one_workspace_bytes += (
                workspace_torch.numel() * workspace_torch.element_size()
            )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )
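    # With cold L2, the benchmark rotates through `workspace_count` independent
    # copies of every buffer so that each timed iteration reads memory that the
    # previous iteration has not left resident in L2.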

    avg_time_us = testing.benchmark(
        compiled_mla,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    return avg_time_us  # Return execution time in microseconds


if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
        try:
            return tuple(int(x.strip()) for x in s.split(","))
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    def parse_mma_tiler(s: str) -> Tuple[int, int]:
        ret = parse_comma_separated_ints(s)
        if len(ret) != 2:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected 2 comma-separated integers."
            )
        return (ret[0], ret[1])

    parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.")

    parser.add_argument(
        "--in_dtype",
        type=cutlass.dtype,
        default=cutlass.Float8E4M3FN,
        help="Input data type",
    )
    parser.add_argument(
        "--out_dtype",
        type=cutlass.dtype,
        default=cutlass.Float16,
        help="Output data type",
    )
    parser.add_argument(
        "--acc_dtype",
        type=cutlass.dtype,
        default=cutlass.Float32,
        help="Accumulator data type",
    )
    parser.add_argument(
        "--lse_dtype",
        type=cutlass.dtype,
        default=cutlass.Float32,
        help="LSE data type",
    )
    parser.add_argument(
        "--mma_qk_tiler_mn",
        type=parse_mma_tiler,
        default=(128, 128),
        help="MMA tile shape (H, K)",
    )
    parser.add_argument(
        "--mma_pv_tiler_mn",
        type=parse_mma_tiler,
        default=(128, 256),
        help="MMA tile shape (H, D)",
    )
    parser.add_argument(
        "--is_persistent",
        action="store_true",
        help="Use the persistent kernel or not",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size",
    )
    parser.add_argument(
        "--seq_len",
        type=int,
        default=128,
        help="Sequence length of K/V",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=128,
        help="Number of heads of Q",
    )
    parser.add_argument(
        "--latent_dim",
        type=int,
        default=512,
        help="Latent dimension of Q/C",
    )
    parser.add_argument(
        "--rope_dim",
        type=int,
        default=64,
        help="Rope dimension of Q/C",
    )
    parser.add_argument(
        "--is_cpasync",
        action="store_true",
        help="Use cpasync for load or not",
    )
    parser.add_argument(
        "--is_var_seq",
        action="store_true",
        help="Use variable sequence lengths or not",
    )
    parser.add_argument(
        "--is_var_split_kv",
        action="store_true",
        help="Use variable split_kv or not",
    )
    parser.add_argument(
        "--use_page_table",
        action="store_true",
        help="Use page table or not; must be True when is_cpasync is True",
    )
    parser.add_argument(
        "--page_size",
        type=int,
        default=128,
        help="Page size of page table",
    )
    parser.add_argument(
        "--split_kv",
        type=int,
        default=-1,
        help="Split KV setting; non-positive values select it automatically",
    )
    parser.add_argument(
        "--softmax_scale",
        type=float,
        default=1.0,
        help="Scaling factor applied to the softmax input",
    )
    parser.add_argument(
        "--output_scale",
        type=float,
        default=1.0,
        help="Scaling factor applied to the output",
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-02, help="Tolerance for validation"
    )
    parser.add_argument(
        "--warmup_iterations",
        type=int,
        default=0,
        help="Number of iterations for warmup",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations after warmup",
    )
    parser.add_argument(
        "--skip_ref_check",
        action="store_true",
        help="Skip reference check",
    )
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        help="Use cold L2 cache",
    )

    args = parser.parse_args()

    run(
        args.batch_size,
        args.seq_len,
        args.num_heads,
        args.latent_dim,
        args.rope_dim,
        args.in_dtype,
        args.out_dtype,
        args.acc_dtype,
        args.lse_dtype,
        args.mma_qk_tiler_mn,
        args.mma_pv_tiler_mn,
        args.split_kv,
        args.is_persistent,
        args.is_cpasync,
        args.is_var_seq,
        args.is_var_split_kv,
        args.use_page_table,
        args.page_size,
        args.softmax_scale,
        args.output_scale,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )

    print("PASS")