# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
import argparse
|
|
from typing import List, Type, Tuple, Optional
|
|
import cuda.bindings.driver as cuda
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.cute.testing as testing
|
|
import cutlass.utils as utils
|
|
import cutlass.pipeline as pipeline
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
import cutlass.torch as cutlass_torch
|
|
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
from cutlass.cute.runtime import from_dlpack
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parent))
|
|
from mamba2_ssd_reference import (
|
|
ssd_reference_fp32_all,
|
|
ssd_reference_lowprecision_intermediates,
|
|
analyze_relative_diffs,
|
|
)
|
|
from mamba2_ssd_tile_scheduler import (
|
|
Mamba2SSDTileSchedulerParams,
|
|
Mamba2SSDTileScheduler,
|
|
)
|
|
|
|
|
|
class SSDKernel:
|
|
def __init__(
    self,
    io_dtype: Type[cutlass.Numeric],
    cumsum_delta_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    L: int,
    D: int,
    N: int,
    has_d: bool,
    d_has_hdim: bool,
):
    """Configure the SSD kernel: dtypes, tile shapes, warp roles, barriers.

    Args:
        io_dtype: element type of the global-memory operands (fp16/bf16 only).
        cumsum_delta_dtype: element type of the cumulative-sum-of-delta
            input (fp32 only).
        acc_dtype: MMA accumulator type (fp32 only).
        L: first tile extent (chunk length dimension of the tile).
        D: second tile extent.
        N: third tile extent.
        has_d: when True, the epilog warps fuse ``Y += X * D``.
        d_has_hdim: when True, D has shape (D, EH) and is loaded by TMA;
            when False, D has shape (1, EH) and is read straight into
            registers.
    """
    self.io_dtype: Type[cutlass.Numeric] = io_dtype
    self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
    self.cumsum_delta_dtype: Type[cutlass.Numeric] = cumsum_delta_dtype
    # Epilog warps perform the Y += X*D fusion only when requested.
    self.has_d: bool = has_d
    # True  -> D is (D, EH) shaped and staged through smem via TMA.
    # False -> D is (1, EH) shaped and loaded directly to registers.
    self.d_has_hdim: bool = d_has_hdim
    self.tile_shape = (L, D, N)

    # Validate the supported dtype combinations up front.
    assert io_dtype in {
        cutlass.Float16,
        cutlass.BFloat16,
    }, "Do not support other I/O types."
    assert acc_dtype in {cutlass.Float32}, "Do not support other ACC types."
    assert cumsum_delta_dtype in {
        cutlass.Float32
    }, "Do not support other cumsum types."
    assert not (not has_d and d_has_hdim), "D cannot have Hdim if has_d is False"

    # Hardcoded defaults: single-CTA MMA, 1x1x1 cluster, 128x32 epilogue tile.
    self.use_2cta_instrs = False
    self.cluster_shape_mnk = (1, 1, 1)
    self.epi_tile = (128, 32)

    # MMA tile shapes (m, n, k) for the four GEMMs of the SSD algorithm.
    self.tile_shape_mnk_intra1 = (L, L, N)
    self.tile_shape_mnk_intra2 = (L, D, L)
    self.tile_shape_mnk_inter1 = (N, D, L)
    self.tile_shape_mnk_inter2 = (L, D, N)

    if self.use_2cta_instrs:
        self.cta_group = tcgen05.CtaGroup.TWO
    else:
        self.cta_group = tcgen05.CtaGroup.ONE

    # Launch config: 16 warps per CTA, specialized by role.
    self.occupancy = 1
    self.mma_inter_warp_id = 0
    self.mma_intra_warp_id = 1
    self.tma_b_c_warp_id = 2
    self.tma_deltas_x_d_warp_id = 3
    self.pre_inter_warp_id = list(range(4, 8))
    self.pre_intra_warp_id = list(range(8, 12))
    self.epilog_warp_id = list(range(12, 16))
    # 4 single-warp roles plus the three 4-warp groups above.
    num_warps = 4 + (
        len(self.pre_inter_warp_id)
        + len(self.pre_intra_warp_id)
        + len(self.epilog_warp_id)
    )
    self.threads_per_cta = 32 * num_warps
    self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")

    # Named barrier ids used for intra-role synchronization.
    self.pre_inter_sync_bar_id = 1
    self.epilog_sync_bar_id = 2
    self.pre_intra_sync_bar_id = 3
    self.tmem_dealloc_sync_bar_id = 4

    # Per-role register counts after warpgroup register reallocation.
    self.num_regs_uniform_warps = 24
    self.num_regs_pre_inter_warps = 168
    self.num_regs_pre_intra_warps = 208
    self.num_regs_epilogue_warps = 112

    # Shared storage struct; populated later in __call__.
    self.shared_storage = None

    # TMEM buffer offsets; real values are computed in _setup_attributes.
    self.tmem_intra1_acc_offset = 0
    self.tmem_intra2_q_offset = 0
    self.tmem_intra2_acc_offset = 0
    self.tmem_inter1_acc_offset = 0
    self.tmem_inter2_acc_offset = 0
    self.num_tmem_cols_total = 0
|
|
|
|
def _setup_attributes(self) -> None:
    """Derive pipeline stage counts, smem/tmem layouts, and TMA byte counts.

    Must run before the SharedStorage struct is built in ``__call__``:
    every layout computed here determines both the smem struct fields and
    the per-stage TMA transaction byte counts used to arm the mbarriers.
    """
    # Tiled MMAs are rebuilt here (and again in __call__/kernel) from the
    # same constructor arguments, so the layouts derived below agree with
    # the MMAs the device code will use.
    (
        tiled_mma_intra1,
        tiled_mma_intra2,
        tiled_mma_inter1,
        tiled_mma_inter2,
    ) = self.make_tiled_mmas(
        self.io_dtype,
        self.acc_dtype,
        self.cta_group,
        self.tile_shape_mnk_intra1,
        self.tile_shape_mnk_intra2,
        self.tile_shape_mnk_inter1,
        self.tile_shape_mnk_inter2,
    )

    # Cluster layout in (V, M, N, K) form, V being the MMA's thr_id shape.
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout(self.cluster_shape_mnk),
        (tiled_mma_intra1.thr_id.shape,),
    )

    # Setup stages (input / output / internal / intra1-acc pipeline depths),
    # chosen by _compute_stages to fit within the smem budget.
    (
        self.input_stages,
        self.output_stages,
        self.internal_stages,
        self.intra1_acc_stages,
    ) = self._compute_stages(
        self.smem_capacity,
    )

    # Setup smem layouts
    # X is B operand (from smem) of INTRA2_MMA and INTER1_MMA
    self.x_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma_intra2,
        self.tile_shape_mnk_intra2,
        self.io_dtype,
        self.input_stages,
    )
    # Bytes per stage (slice off the stage mode) — used for TMA mbarrier arming.
    self.num_x_load_bytes = cute.size_in_bytes(
        self.io_dtype, cute.slice_(self.x_smem_layout, (None, None, None, 0))
    )

    # XT is same shape as ACC operand of INTER2_MMA, before postprocessing by EPILOG
    self.xt_smem_layout = sm100_utils.make_smem_layout_epi(
        self.io_dtype,
        utils.LayoutEnum.COL_MAJOR,
        self.tile_shape_mnk_intra2[:2],
        self.input_stages,
    )

    # B is B operand (from smem) of INTRA1_MMA
    self.b_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma_intra1,
        self.tile_shape_mnk_intra1,
        self.io_dtype,
        self.input_stages,
    )
    self.num_b_load_bytes = cute.size_in_bytes(
        self.io_dtype, cute.slice_(self.b_smem_layout, (None, None, None, 0))
    )

    # B_INTERNAL is also A operand (from smem) of INTER1_MMA, after preprocessed by PRE_INTER
    self.bt_internal_smem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma_inter1,
        self.tile_shape_mnk_inter1,
        self.io_dtype,
        self.internal_stages,
    )

    # B needs to be preprocessed to be used as A operand of INTER1_MMA;
    # coalesce to a (1, 1, 1)-profiled view of the epilogue layout.
    self.bt_smem_layout = cute.coalesce(
        sm100_utils.make_smem_layout_epi(
            self.io_dtype,
            utils.LayoutEnum.ROW_MAJOR,
            (self.tile_shape_mnk_inter1[0], self.tile_shape_mnk_inter1[2]),
            self.input_stages,
        ),
        target_profile=(1, 1, 1),
    )

    # C is A operand (from smem) of INTRA1_MMA and INTER2_MMA
    self.c_smem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma_intra1,
        self.tile_shape_mnk_intra1,
        self.io_dtype,
        self.input_stages,
    )
    self.num_c_load_bytes = cute.size_in_bytes(
        self.io_dtype, cute.slice_(self.c_smem_layout, (None, None, None, 0))
    )

    # P is B operand (from smem) of INTER2_MMA, after preprocessed by PRE_INTER
    self.p_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma_inter2,
        self.tile_shape_mnk_inter2,
        self.io_dtype,
        self.internal_stages,
    )

    # PT is ACC operand (from tmem) of INTER1_MMA, after postprocessed by PRE_INTER
    self.pt_smem_layout = sm100_utils.make_smem_layout_epi(
        self.io_dtype,
        utils.LayoutEnum.COL_MAJOR,
        self.tile_shape_mnk_inter1[:2],
        self.internal_stages,
    )

    # Q is A operand (from tmem) of INTRA2_MMA, after preprocessed by PRE_INTRA
    self.q_tmem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma_intra2,
        self.tile_shape_mnk_intra2,
        self.io_dtype,
        self.internal_stages,
    )

    # P is ACC operand (from tmem) of INTER1_MMA, to be TMA stored by PRE_INTER
    self.p_smem_layout_store = sm100_utils.make_smem_layout_epi(
        self.io_dtype,
        utils.LayoutEnum.ROW_MAJOR,
        self.tile_shape_mnk_inter2[1:],
        self.internal_stages,
    )

    # Y is ACC operand (from smem) of INTER2_MMA and INTRA2_MMA, after postprocessed and TMA stored by EPILOG
    self.y_smem_layout = sm100_utils.make_smem_layout_epi(
        self.io_dtype,
        utils.LayoutEnum.COL_MAJOR,
        self.epi_tile,
        self.output_stages,
    )

    # Delta is a linear (non-swizzled) smem layout for pre/post processing:
    # (K extent of INTER1 tile, input stages).
    self.delta_linear_smem_layout = cute.make_layout(
        (self.tile_shape_mnk_inter1[2], self.input_stages)
    )
    self.num_delta_load_bytes = cute.size_in_bytes(
        self.io_dtype, cute.slice_(self.delta_linear_smem_layout, (None, 0))
    )

    # Cumsum delta mirrors delta's linear layout but with its own dtype.
    self.cumsum_delta_linear_smem_layout = cute.make_layout(
        (self.tile_shape_mnk_inter1[2], self.input_stages)
    )
    self.num_cumsum_delta_load_bytes = cute.size_in_bytes(
        self.cumsum_delta_dtype,
        cute.slice_(self.cumsum_delta_linear_smem_layout, (None, 0)),
    )

    # D only needs a linear smem layout when d_has_hdim is True
    # (otherwise it is loaded straight into registers, no smem staging).
    self.d_linear_smem_layout = (
        cute.make_layout((self.tile_shape_mnk_inter2[1], self.input_stages))
        if self.d_has_hdim
        else None
    )
    self.num_d_load_bytes = (
        cute.size_in_bytes(
            self.io_dtype,
            cute.slice_(self.d_linear_smem_layout, (None, 0)),
        )
        if self.d_has_hdim
        else 0
    )

    # Setup tmem offsets (column offsets of the five tmem buffers plus the
    # total tmem column count to allocate).
    # NOTE(review): self.internal_stages is passed twice below (once before
    # q_tmem_layout and once after io_dtype) — verify against the
    # _plan_tmem_offsets signature that both slots really take the same
    # stage count.
    (
        self.tmem_intra1_acc_offset,
        self.tmem_intra2_q_offset,
        self.tmem_intra2_acc_offset,
        self.tmem_inter1_acc_offset,
        self.tmem_inter2_acc_offset,
        self.num_tmem_cols_total,
    ) = self._plan_tmem_offsets(
        tiled_mma_intra1,
        self.tile_shape_mnk_intra1,
        tiled_mma_intra2,
        self.tile_shape_mnk_intra2,
        tiled_mma_inter1,
        self.tile_shape_mnk_inter1,
        tiled_mma_inter2,
        self.tile_shape_mnk_inter2,
        self.internal_stages,
        self.q_tmem_layout,
        self.io_dtype,
        self.internal_stages,
        self.intra1_acc_stages,
    )

    return
|
|
|
|
@cute.jit
def __call__(
    self,
    x: cute.Tensor,
    cumsum_delta: cute.Tensor,
    delta: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
    y: cute.Tensor,
    fstate: cute.Tensor,
    d: cute.Tensor,
    max_active_clusters: cutlass.Constexpr,
    stream: cuda.CUstream,
):
    """Host-side JIT entry point: build TMA atoms, smem struct, and launch.

    Computes all layouts via ``_setup_attributes``, constructs the TMA
    load atoms for X/B/C/delta/cumsum_delta (and D when ``d_has_hdim``),
    the TMA store atoms for Y and fstate, defines the SharedStorage
    struct, then launches ``self.kernel`` on ``stream``.

    The device kernel reads x as (D, L, C, EH, B) and b as having N at
    mode 1 and G at mode 3 (see the ``cute.size(..., mode=[i])`` calls in
    ``kernel``); other tensors are assumed laid out to match — TODO
    confirm against the host-side tensor construction.

    Raises:
        ValueError: if the planned SharedStorage exceeds smem capacity.
    """
    self._setup_attributes()
    # Rebuild the tiled MMAs with the same parameters used during setup.
    (
        tiled_mma_intra1,
        tiled_mma_intra2,
        tiled_mma_inter1,
        tiled_mma_inter2,
    ) = self.make_tiled_mmas(
        self.io_dtype,
        self.acc_dtype,
        self.cta_group,
        self.tile_shape_mnk_intra1,
        self.tile_shape_mnk_intra2,
        self.tile_shape_mnk_inter1,
        self.tile_shape_mnk_inter2,
    )

    # Setup TMA atoms and convert TMA tensors
    # TMA load for X (consumed as the B operand of INTRA2_MMA)
    x_op = sm100_utils.cluster_shape_to_tma_atom_B(
        self.cluster_shape_mnk, tiled_mma_intra2.thr_id
    )
    tma_atom_x, tma_tensor_x = cute.nvgpu.make_tiled_tma_atom_B(
        x_op,
        x,
        cute.slice_(self.x_smem_layout, (None, None, None, 0)),
        self.tile_shape_mnk_intra2,
        tiled_mma_intra2,
        self.cluster_layout_vmnk.shape,
        # fp32 inputs would be loaded as tf32; fp16/bf16 pass through.
        internal_type=(
            cutlass.TFloat32 if x.element_type is cutlass.Float32 else None
        ),
    )

    # TMA load for B (B operand of INTRA1_MMA)
    b_op = sm100_utils.cluster_shape_to_tma_atom_B(
        self.cluster_shape_mnk, tiled_mma_intra1.thr_id
    )
    tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
        b_op,
        b,
        cute.slice_(self.b_smem_layout, (None, None, None, 0)),
        self.tile_shape_mnk_intra1,
        tiled_mma_intra1,
        self.cluster_layout_vmnk.shape,
        internal_type=(
            cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
        ),
    )

    # TMA load for C (A operand of INTRA1_MMA)
    c_op = sm100_utils.cluster_shape_to_tma_atom_A(
        self.cluster_shape_mnk, tiled_mma_intra1.thr_id
    )
    tma_atom_c, tma_tensor_c = cute.nvgpu.make_tiled_tma_atom_A(
        c_op,
        c,
        cute.slice_(self.c_smem_layout, (None, None, None, 0)),
        self.tile_shape_mnk_intra1,
        tiled_mma_intra1,
        self.cluster_layout_vmnk.shape,
        internal_type=(
            cutlass.TFloat32 if c.element_type is cutlass.Float32 else None
        ),
    )

    # TMA load for delta
    # TODO: use bulkcp instead of tma
    delta_cta_v_layout = cute.slice_(
        cute.make_identity_layout(delta.shape), (None, 0, 0, 0)
    )
    # Single-stage slice of the staged linear layout.
    delta_linear_smem_layout = cute.slice_(self.delta_linear_smem_layout, (None, 0))
    tma_atom_delta, tma_tensor_delta = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(),
        delta,
        delta_linear_smem_layout,
        delta_cta_v_layout,
    )

    # TMA load for cumsum_delta (same structure as delta, different dtype)
    cumsum_delta_cta_v_layout = cute.slice_(
        cute.make_identity_layout(cumsum_delta.shape), (None, 0, 0, 0)
    )
    cumsum_delta_linear_smem_layout = cute.slice_(
        self.cumsum_delta_linear_smem_layout, (None, 0)
    )
    (
        tma_atom_cumsum_delta,
        tma_tensor_cumsum_delta,
    ) = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(),
        cumsum_delta,
        cumsum_delta_linear_smem_layout,
        cumsum_delta_cta_v_layout,
    )

    # Without the hdim, D bypasses TMA and is passed through unchanged.
    tma_atom_d = None
    tma_tensor_d = d
    # TMA load for D
    if cutlass.const_expr(self.d_has_hdim):
        d_cta_v_layout = cute.slice_(cute.make_identity_layout(d.shape), (None, 0))
        d_linear_smem_layout = cute.slice_(self.d_linear_smem_layout, (None, 0))
        (
            tma_atom_d,
            tma_tensor_d,
        ) = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            d,
            d_linear_smem_layout,
            d_cta_v_layout,
        )

    # TMA store for y (smem -> gmem, tiled by epi_tile)
    y_cta_v_layout = cute.composition(
        cute.make_identity_layout(y.shape), self.epi_tile
    )
    y_smem_layout = cute.slice_(self.y_smem_layout, (None, None, 0))
    tma_atom_y, tma_tensor_y = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileS2GOp(),
        y,
        y_smem_layout,
        y_cta_v_layout,
    )

    # TMA store for fstate(p)
    p_cta_v_layout = cute.slice_(
        cute.make_identity_layout(fstate.shape), (None, None, 0, 0)
    )
    p_smem_layout_store = cute.slice_(self.p_smem_layout_store, (None, None, 0))
    tma_atom_p, tma_tensor_p = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileS2GOp(),
        fstate,
        p_smem_layout_store,
        p_cta_v_layout,
    )

    # Compute grid size
    tile_sched_params, grid = self._compute_grid(y, b, max_active_clusters)

    # Plan shared memory storage.
    # Swizzled (MMA operand) buffers need 1024B alignment; linear buffers 128B.
    swizzle_buffer_align_bytes = 1024
    nonswizzle_buffer_align_bytes = 128

    @cute.struct
    class SharedStorage:
        # Input stage barriers (one full/empty mbarrier pair per stage)
        x_full: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        x_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        b_full: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        b_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        c_full: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        c_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        deltas_full: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        deltas_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        d_full: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        d_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages]  # type: ignore
        # Intra1 acc stage barriers
        intra1_acc_full: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages]  # type: ignore
        intra1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages]  # type: ignore
        # Internal stage barriers
        intra2_q_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        intra2_q_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        intra2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        intra2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter1_b_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter1_b_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter1_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter2_p_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter2_p_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        inter2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages]  # type: ignore
        # Tmem holding buffer (stores the tmem base address for all warps)
        tmem_holding_buf: cutlass.Int32
        # Smem tensors
        smem_x: cute.struct.Align[
            cute.struct.MemRange[self.io_dtype, cute.cosize(self.x_smem_layout)],
            swizzle_buffer_align_bytes,
        ]
        smem_b: cute.struct.Align[
            cute.struct.MemRange[self.io_dtype, cute.cosize(self.b_smem_layout)],
            swizzle_buffer_align_bytes,
        ]
        smem_bt_internal: cute.struct.Align[
            cute.struct.MemRange[
                self.io_dtype, cute.cosize(self.bt_internal_smem_layout)
            ],
            swizzle_buffer_align_bytes,
        ]
        smem_c: cute.struct.Align[
            cute.struct.MemRange[self.io_dtype, cute.cosize(self.c_smem_layout)],
            swizzle_buffer_align_bytes,
        ]
        smem_p: cute.struct.Align[
            cute.struct.MemRange[self.io_dtype, cute.cosize(self.p_smem_layout)],
            swizzle_buffer_align_bytes,
        ]
        smem_y: cute.struct.Align[
            cute.struct.MemRange[self.io_dtype, cute.cosize(self.y_smem_layout)],
            swizzle_buffer_align_bytes,
        ]
        smem_cumsum_delta: cute.struct.Align[
            cute.struct.MemRange[
                self.cumsum_delta_dtype,
                cute.cosize(self.cumsum_delta_linear_smem_layout),
            ],
            nonswizzle_buffer_align_bytes,
        ]
        smem_delta: cute.struct.Align[
            cute.struct.MemRange[
                self.io_dtype, cute.cosize(self.delta_linear_smem_layout)
            ],
            nonswizzle_buffer_align_bytes,
        ]
        # Zero-sized when D is not staged through smem.
        smem_d: cute.struct.Align[
            cute.struct.MemRange[
                self.io_dtype,
                cute.cosize(self.d_linear_smem_layout) if self.d_has_hdim else 0,
            ],
            nonswizzle_buffer_align_bytes,
        ]

    self.shared_storage = SharedStorage
    # Fail fast at compile time if the planned storage does not fit.
    if cutlass.const_expr(self.shared_storage.size_in_bytes() > self.smem_capacity):
        raise ValueError(
            f"SharedStorage size {self.shared_storage.size_in_bytes()} exceeds smem_capacity {self.smem_capacity}"
        )

    # Launch the kernel synchronously
    self.kernel(
        tma_atom_x,
        tma_tensor_x,
        tma_atom_b,
        tma_tensor_b,
        tma_atom_c,
        tma_tensor_c,
        tma_atom_p,
        tma_tensor_p,
        tma_atom_y,
        tma_tensor_y,
        tma_atom_delta,
        tma_tensor_delta,
        tma_atom_cumsum_delta,
        tma_tensor_cumsum_delta,
        tma_atom_d,
        tma_tensor_d,
        self.cluster_layout_vmnk,
        self.x_smem_layout,
        self.xt_smem_layout,
        self.b_smem_layout,
        self.bt_smem_layout,
        self.bt_internal_smem_layout,
        self.c_smem_layout,
        self.pt_smem_layout,
        self.p_smem_layout,
        self.q_tmem_layout,
        self.p_smem_layout_store,
        self.y_smem_layout,
        self.delta_linear_smem_layout,
        self.cumsum_delta_linear_smem_layout,
        self.d_linear_smem_layout,
        self.epi_tile,
        tile_sched_params,
    ).launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=self.cluster_shape_mnk,
        min_blocks_per_mp=1,
        smem=self.shared_storage.size_in_bytes(),
        stream=stream,
    )
|
|
|
|
# GPU device kernel
|
|
@cute.kernel
|
|
def kernel(
|
|
self,
|
|
tma_atom_x: cute.CopyAtom,
|
|
tma_tensor_x: cute.Tensor,
|
|
tma_atom_b: cute.CopyAtom,
|
|
tma_tensor_b: cute.Tensor,
|
|
tma_atom_c: cute.CopyAtom,
|
|
tma_tensor_c: cute.Tensor,
|
|
tma_atom_p: cute.CopyAtom,
|
|
tma_tensor_p: cute.Tensor,
|
|
tma_atom_y: cute.CopyAtom,
|
|
tma_tensor_y: cute.Tensor,
|
|
tma_atom_delta: cute.CopyAtom,
|
|
tma_tensor_delta: cute.Tensor,
|
|
tma_atom_cumsum_delta: cute.CopyAtom,
|
|
tma_tensor_cumsum_delta: cute.Tensor,
|
|
tma_atom_d: Optional[cute.CopyAtom],
|
|
tma_tensor_d: cute.Tensor,
|
|
cluster_layout_vmnk: cute.Layout,
|
|
x_smem_layout: cute.ComposedLayout,
|
|
xt_smem_layout: cute.ComposedLayout,
|
|
b_smem_layout: cute.ComposedLayout,
|
|
bt_smem_layout: cute.ComposedLayout,
|
|
bt_internal_smem_layout: cute.ComposedLayout,
|
|
c_smem_layout: cute.ComposedLayout,
|
|
pt_smem_layout: cute.ComposedLayout,
|
|
p_smem_layout: cute.ComposedLayout,
|
|
q_tmem_layout: cute.ComposedLayout,
|
|
p_smem_layout_store: cute.ComposedLayout,
|
|
y_smem_layout: cute.ComposedLayout,
|
|
delta_linear_smem_layout: cute.Layout,
|
|
cumsum_delta_linear_smem_layout: cute.Layout,
|
|
d_linear_smem_layout: Optional[cute.Layout],
|
|
epi_tile: cute.Tile,
|
|
tile_sched_params: Mamba2SSDTileSchedulerParams,
|
|
):
|
|
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
|
|
# Prefetch tma descriptor
|
|
if warp_idx == 0:
|
|
tma_atoms = [
|
|
tma_atom_x,
|
|
tma_atom_b,
|
|
tma_atom_c,
|
|
tma_atom_p,
|
|
tma_atom_y,
|
|
tma_atom_delta,
|
|
tma_atom_cumsum_delta,
|
|
]
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
tma_atoms.append(tma_atom_d)
|
|
for tma_atom in tma_atoms:
|
|
cpasync.prefetch_descriptor(tma_atom)
|
|
|
|
# Static consts
|
|
D = cute.size(tma_tensor_x, mode=[0])
|
|
L = cute.size(tma_tensor_x, mode=[1])
|
|
N = cute.size(tma_tensor_b, mode=[1])
|
|
# Dynamic values
|
|
C = cute.size(tma_tensor_x, mode=[2])
|
|
EH = cute.size(tma_tensor_x, mode=[3])
|
|
B = cute.size(tma_tensor_x, mode=[4])
|
|
G = cute.size(tma_tensor_b, mode=[3])
|
|
NGROUP_RATIO = EH // G
|
|
|
|
# Make tiledMma
|
|
(
|
|
tiled_mma_intra1,
|
|
tiled_mma_intra2,
|
|
tiled_mma_inter1,
|
|
tiled_mma_inter2,
|
|
) = self.make_tiled_mmas(
|
|
self.io_dtype,
|
|
self.acc_dtype,
|
|
self.cta_group,
|
|
self.tile_shape_mnk_intra1,
|
|
self.tile_shape_mnk_intra2,
|
|
self.tile_shape_mnk_inter1,
|
|
self.tile_shape_mnk_inter2,
|
|
)
|
|
|
|
# Setup cta/thread coordinates
|
|
# Block coord
|
|
bidx, bidy, bidz = cute.arch.block_idx()
|
|
mma_tile_coord_v = bidx % cute.size(tiled_mma_intra1.thr_id.shape)
|
|
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
|
cute.arch.block_idx_in_cluster()
|
|
)
|
|
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
|
|
cta_rank_in_cluster
|
|
)
|
|
# Workload coord
|
|
tile_sched = Mamba2SSDTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
# Thread/warp coord
|
|
tidx, _, _ = cute.arch.thread_idx()
|
|
# Thread coord inside specialized warps
|
|
local_tidx = tidx % 128
|
|
local_warp_idx = cute.arch.make_warp_uniform(local_tidx // 32)
|
|
|
|
# Alloc and init smem tensors and pipelines
|
|
smem = utils.SmemAllocator()
|
|
smem_storage = smem.allocate(self.shared_storage)
|
|
|
|
# Setup smem tensors
|
|
smem_x = smem_storage.smem_x.get_tensor(
|
|
x_smem_layout.outer, swizzle=x_smem_layout.inner
|
|
)
|
|
smem_xt = smem_storage.smem_x.get_tensor(
|
|
xt_smem_layout.outer, swizzle=xt_smem_layout.inner
|
|
)
|
|
smem_b = smem_storage.smem_b.get_tensor(
|
|
b_smem_layout.outer, swizzle=b_smem_layout.inner
|
|
)
|
|
smem_bt = smem_storage.smem_b.get_tensor(
|
|
bt_smem_layout.outer, swizzle=bt_smem_layout.inner
|
|
)
|
|
smem_bt_internal = smem_storage.smem_bt_internal.get_tensor(
|
|
bt_internal_smem_layout.outer, swizzle=bt_internal_smem_layout.inner
|
|
)
|
|
smem_c = smem_storage.smem_c.get_tensor(
|
|
c_smem_layout.outer, swizzle=c_smem_layout.inner
|
|
)
|
|
smem_p = smem_storage.smem_p.get_tensor(
|
|
p_smem_layout.outer, swizzle=p_smem_layout.inner
|
|
)
|
|
smem_pt = smem_storage.smem_p.get_tensor(
|
|
pt_smem_layout.outer, swizzle=pt_smem_layout.inner
|
|
)
|
|
smem_p_store = smem_storage.smem_p.get_tensor(
|
|
p_smem_layout_store.outer, swizzle=p_smem_layout_store.inner
|
|
)
|
|
smem_y = smem_storage.smem_y.get_tensor(
|
|
y_smem_layout.outer, swizzle=y_smem_layout.inner
|
|
)
|
|
smem_cumsum_delta = smem_storage.smem_cumsum_delta.get_tensor(
|
|
cumsum_delta_linear_smem_layout
|
|
)
|
|
smem_delta = smem_storage.smem_delta.get_tensor(delta_linear_smem_layout)
|
|
smem_d = None
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
smem_d = smem_storage.smem_d.get_tensor(d_linear_smem_layout)
|
|
|
|
# Init mbarrier for pipeline
|
|
x_pipeline = self.make_and_init_x_pipeline(smem_storage.x_full.data_ptr())
|
|
b_pipeline = self.make_and_init_b_pipeline(smem_storage.b_full.data_ptr())
|
|
c_pipeline = self.make_and_init_c_pipeline(smem_storage.c_full.data_ptr())
|
|
deltas_pipeline = self.make_and_init_deltas_pipeline(
|
|
smem_storage.deltas_full.data_ptr()
|
|
)
|
|
d_pipeline = self.make_and_init_d_pipeline(smem_storage.d_full.data_ptr())
|
|
intra1_acc_pipeline = self.make_and_init_intra1_acc_pipeline(
|
|
smem_storage.intra1_acc_full.data_ptr()
|
|
)
|
|
intra2_q_pipeline = self.make_and_init_intra2_q_pipeline(
|
|
smem_storage.intra2_q_full.data_ptr()
|
|
)
|
|
intra2_acc_pipeline = self.make_and_init_intra2_acc_pipeline(
|
|
smem_storage.intra2_acc_full.data_ptr()
|
|
)
|
|
inter1_b_pipeline = self.make_and_init_inter1_b_pipeline(
|
|
smem_storage.inter1_b_full.data_ptr()
|
|
)
|
|
inter1_acc_pipeline = self.make_and_init_inter1_acc_pipeline(
|
|
smem_storage.inter1_acc_full.data_ptr()
|
|
)
|
|
inter2_p_pipeline = self.make_and_init_inter2_p_pipeline(
|
|
smem_storage.inter2_p_full.data_ptr()
|
|
)
|
|
inter2_acc_pipeline = self.make_and_init_inter2_acc_pipeline(
|
|
smem_storage.inter2_acc_full.data_ptr()
|
|
)
|
|
|
|
# Cluster arrive after barrier init
|
|
if cute.size(self.cluster_shape_mnk) > 1:
|
|
cute.arch.cluster_arrive_relaxed()
|
|
|
|
# Cluster wait before tmem alloc
|
|
if cute.size(self.cluster_shape_mnk) > 1:
|
|
cute.arch.cluster_wait()
|
|
|
|
# Alloc tmem buffer
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
cute.arch.alloc_tmem(
|
|
self.num_tmem_cols_total,
|
|
smem_storage.tmem_holding_buf,
|
|
is_two_cta=self.use_2cta_instrs,
|
|
)
|
|
|
|
# Bar sync before retrieving tmem ptr from shared mem
|
|
cute.arch.barrier()
|
|
|
|
# Retrieve tmem ptr
|
|
tmem_ptr_base = cute.arch.retrieve_tmem_ptr(
|
|
self.acc_dtype,
|
|
alignment=16,
|
|
ptr_to_buffer_holding_addr=smem_storage.tmem_holding_buf,
|
|
)
|
|
|
|
# Specialized TMA load Delta/CumsumDelta/X warp
|
|
if warp_idx == self.tma_deltas_x_d_warp_id:
|
|
# Dealloc regs for pre-inter/pre-intra warps
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), 1, 1, C, EH, B)
|
|
tXsX, tXgX_pre_slice = self.tma_partition_for_mma_b_operand(
|
|
tma_atom_x,
|
|
tma_tensor_x,
|
|
smem_x,
|
|
tiled_mma_intra2,
|
|
cluster_layout_vmnk,
|
|
mma_tile_coord_v,
|
|
block_in_cluster_coord_vmnk,
|
|
)
|
|
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), 1, C, EH, B)
|
|
tDeltasDelta, tDeltagDelta_pre_slice = self.tma_partition_with_shape(
|
|
tma_atom_delta,
|
|
tma_tensor_delta,
|
|
smem_delta,
|
|
(self.tile_shape_mnk_inter1[2],),
|
|
)
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), 1, C, EH, B)
|
|
(
|
|
tDeltasCumsumDelta,
|
|
tDeltagCumsumDelta_pre_slice,
|
|
) = self.tma_partition_with_shape(
|
|
tma_atom_cumsum_delta,
|
|
tma_tensor_cumsum_delta,
|
|
smem_cumsum_delta,
|
|
(self.tile_shape_mnk_inter1[2],),
|
|
)
|
|
|
|
tDsD = None
|
|
tDgD_pre_slice = None
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
# Partition global/shared tensor for D
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), 1, EH)
|
|
tDsD, tDgD_pre_slice = self.tma_partition_with_shape(
|
|
tma_atom_d, tma_tensor_d, smem_d, (self.tile_shape_mnk_inter2[1],)
|
|
)
|
|
|
|
# Pipeline X/Delta/CumsumDelta/D producer state
|
|
x_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.input_stages
|
|
)
|
|
deltas_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.input_stages
|
|
)
|
|
d_producer_state = None
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
# D is loaded by TMA only when d_has_hdim is True
|
|
d_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.input_stages
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
b_idx, eh_idx, g_idx = work_tile.tile_idx
|
|
|
|
# Slice global tensor to current tile idx
|
|
# ((ATOM_V, REST_V), C)
|
|
tXgX = tXgX_pre_slice[None, 0, 0, None, eh_idx, b_idx]
|
|
tDeltagDelta = tDeltagDelta_pre_slice[None, 0, None, eh_idx, b_idx]
|
|
tDeltagCumsumDelta = tDeltagCumsumDelta_pre_slice[
|
|
None, 0, None, eh_idx, b_idx
|
|
]
|
|
tDgD = None
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
# ((ATOM_V, REST_V))
|
|
tDgD = tDgD_pre_slice[None, 0, eh_idx]
|
|
|
|
# Reset count for pipeline state
|
|
x_producer_state.reset_count()
|
|
deltas_producer_state.reset_count()
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
d_producer_state.reset_count()
|
|
|
|
# Peek (try_wait) X/deltas buffer empty status
|
|
peek_x_empty_status = self.conditional_producer_try_acquire(
|
|
x_producer_state, x_pipeline, C
|
|
)
|
|
peek_deltas_empty_status = self.conditional_producer_try_acquire(
|
|
deltas_producer_state, deltas_pipeline, C
|
|
)
|
|
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
# Wait for D buffer empty
|
|
d_pipeline.producer_acquire(d_producer_state)
|
|
# TMA load D
|
|
cute.copy(
|
|
tma_atom_d,
|
|
tDgD,
|
|
tDsD[None, d_producer_state.index],
|
|
tma_bar_ptr=d_pipeline.producer_get_barrier(d_producer_state),
|
|
)
|
|
# Advance D producer state
|
|
d_producer_state.advance()
|
|
|
|
# Batched load over C dimension
|
|
for chunk_idx in cutlass.range(C, unroll=1):
|
|
# Conditionally wait for X buffer empty
|
|
x_pipeline.producer_acquire(x_producer_state, peek_x_empty_status)
|
|
|
|
# TMA load X
|
|
cute.copy(
|
|
tma_atom_x,
|
|
tXgX[None, x_producer_state.count],
|
|
tXsX[None, x_producer_state.index],
|
|
tma_bar_ptr=x_pipeline.producer_get_barrier(x_producer_state),
|
|
)
|
|
|
|
# Conditionally wait for deltas buffer empty
|
|
deltas_pipeline.producer_acquire(
|
|
deltas_producer_state, peek_deltas_empty_status
|
|
)
|
|
|
|
# TMA load Delta/CumsumDelta
|
|
cute.copy(
|
|
tma_atom_delta,
|
|
tDeltagDelta[None, deltas_producer_state.count],
|
|
tDeltasDelta[None, deltas_producer_state.index],
|
|
tma_bar_ptr=deltas_pipeline.producer_get_barrier(
|
|
deltas_producer_state
|
|
),
|
|
)
|
|
cute.copy(
|
|
tma_atom_cumsum_delta,
|
|
tDeltagCumsumDelta[None, deltas_producer_state.count],
|
|
tDeltasCumsumDelta[None, deltas_producer_state.index],
|
|
tma_bar_ptr=deltas_pipeline.producer_get_barrier(
|
|
deltas_producer_state
|
|
),
|
|
)
|
|
|
|
# Advance X/deltas producer state
|
|
x_producer_state.advance()
|
|
deltas_producer_state.advance()
|
|
|
|
# Peek (try_wait) X/deltas buffer empty status
|
|
peek_x_empty_status = self.conditional_producer_try_acquire(
|
|
x_producer_state, x_pipeline, C
|
|
)
|
|
peek_deltas_empty_status = self.conditional_producer_try_acquire(
|
|
deltas_producer_state, deltas_pipeline, C
|
|
)
|
|
# END of for chunk_idx in cutlass.range(C, unroll=1)
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
# END of while work_tile.is_valid_tile
|
|
|
|
# Producer tail for X/Deltas/D
|
|
x_pipeline.producer_tail(x_producer_state)
|
|
deltas_pipeline.producer_tail(deltas_producer_state)
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
d_pipeline.producer_tail(d_producer_state)
|
|
# END of specialized tma load X/Deltas/D warp
|
|
|
|
# Specialized TMA load B/C warp
|
|
elif warp_idx == self.tma_b_c_warp_id:
|
|
# Dealloc regs for pre-inter/pre-intra warps
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), 1, 1, C, G, B)
|
|
tBsB, tBgB_pre_slice = self.tma_partition_for_mma_b_operand(
|
|
tma_atom_b,
|
|
tma_tensor_b,
|
|
smem_b,
|
|
tiled_mma_intra1,
|
|
cluster_layout_vmnk,
|
|
mma_tile_coord_v,
|
|
block_in_cluster_coord_vmnk,
|
|
)
|
|
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), 1, 1, C, G, B)
|
|
tCsC, tCgC_pre_slice = self.tma_partition_for_mma_a_operand(
|
|
tma_atom_c,
|
|
tma_tensor_c,
|
|
smem_c,
|
|
tiled_mma_intra1,
|
|
cluster_layout_vmnk,
|
|
mma_tile_coord_v,
|
|
block_in_cluster_coord_vmnk,
|
|
)
|
|
|
|
# Pipeline B/C producer state
|
|
b_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.input_stages
|
|
)
|
|
c_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.input_stages
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
b_idx, eh_idx, g_idx = work_tile.tile_idx
|
|
|
|
# Slice global tensor to current tile idx
|
|
# ((ATOM_V, REST_V), C)
|
|
tBgB = tBgB_pre_slice[None, 0, 0, None, g_idx, b_idx]
|
|
tCgC = tCgC_pre_slice[None, 0, 0, None, g_idx, b_idx]
|
|
|
|
# Reset count for pipeline state
|
|
b_producer_state.reset_count()
|
|
c_producer_state.reset_count()
|
|
|
|
# Peek (try_wait) B/C buffer empty status
|
|
peek_b_empty_status = self.conditional_producer_try_acquire(
|
|
b_producer_state, b_pipeline, C
|
|
)
|
|
peek_c_empty_status = self.conditional_producer_try_acquire(
|
|
c_producer_state, c_pipeline, C
|
|
)
|
|
|
|
# Batched load over C dimension
|
|
for chunk_idx in cutlass.range(C, unroll=1):
|
|
# Conditionally wait for B buffer empty
|
|
b_pipeline.producer_acquire(b_producer_state, peek_b_empty_status)
|
|
|
|
# TMA load B
|
|
cute.copy(
|
|
tma_atom_b,
|
|
tBgB[None, b_producer_state.count],
|
|
tBsB[None, b_producer_state.index],
|
|
tma_bar_ptr=b_pipeline.producer_get_barrier(b_producer_state),
|
|
)
|
|
|
|
# Conditionally wait for C buffer empty
|
|
c_pipeline.producer_acquire(c_producer_state, peek_c_empty_status)
|
|
|
|
# TMA load C
|
|
cute.copy(
|
|
tma_atom_c,
|
|
tCgC[None, c_producer_state.count],
|
|
tCsC[None, c_producer_state.index],
|
|
tma_bar_ptr=c_pipeline.producer_get_barrier(c_producer_state),
|
|
)
|
|
|
|
# Advance B/C producer state
|
|
b_producer_state.advance()
|
|
c_producer_state.advance()
|
|
|
|
# Peek (try_wait) B/C buffer empty status
|
|
peek_b_empty_status = self.conditional_producer_try_acquire(
|
|
b_producer_state, b_pipeline, C
|
|
)
|
|
peek_c_empty_status = self.conditional_producer_try_acquire(
|
|
c_producer_state, c_pipeline, C
|
|
)
|
|
# END of for chunk_idx in cutlass.range(C, unroll=1)
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
# END of while work_tile.is_valid_tile
|
|
|
|
# Producer tail for B/C
|
|
b_pipeline.producer_tail(b_producer_state)
|
|
c_pipeline.producer_tail(c_producer_state)
|
|
# END of specialized tma load B/C warp
|
|
|
|
# Specialized MMA Intra warp
|
|
elif warp_idx == self.mma_intra_warp_id:
|
|
# Dealloc regs for pre-inter/pre-intra warps
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
|
|
# Make shared/tmem fragments for INTRA_MMA1 B/C/ACC
|
|
# (MMA, MMA_N, MMA_K, INPUT_STAGE)
|
|
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
|
|
# (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE)
|
|
tCrC, tCrB, tCtAccIntra1 = self.mma_partition_ss(
|
|
tiled_mma_intra1,
|
|
self.tile_shape_mnk_intra1,
|
|
smem_c,
|
|
smem_b,
|
|
tmem_ptr_base + self.tmem_intra1_acc_offset,
|
|
self.intra1_acc_stages,
|
|
)
|
|
|
|
# Make shared/tmem fragments for INTRA_MMA2 X/Q/ACC
|
|
# (MMA, MMA_M, MMA_K, INTERNAL_STAGE)
|
|
# (MMA, MMA_N, MMA_K, INPUT_STAGE)
|
|
# (MMA, MMA_M, MMA_N, INTERNAL_STAGE)
|
|
tCrQ, tCrX, tCtAccIntra2 = self.mma_partition_ts(
|
|
tiled_mma_intra2,
|
|
self.tile_shape_mnk_intra2,
|
|
q_tmem_layout,
|
|
smem_x,
|
|
tmem_ptr_base + self.tmem_intra2_q_offset,
|
|
tmem_ptr_base + self.tmem_intra2_acc_offset,
|
|
self.internal_stages,
|
|
)
|
|
|
|
# Pipeline B/C/X/INTRA2_Q consumer state
|
|
b_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
c_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
x_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
intra2_q_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.internal_stages
|
|
)
|
|
|
|
# Pipeline INTRA1_ACC/INTRA2_ACC producer state
|
|
intra1_acc_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.intra1_acc_stages
|
|
)
|
|
intra2_acc_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.internal_stages
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# Reset count for pipeline state
|
|
b_consumer_state.reset_count()
|
|
c_consumer_state.reset_count()
|
|
intra1_acc_producer_state.reset_count()
|
|
x_consumer_state.reset_count()
|
|
intra2_q_consumer_state.reset_count()
|
|
intra2_acc_producer_state.reset_count()
|
|
|
|
# Peek (try_wait) B/C/X/INTRA1_ACC buffer full/full/full/empty status
|
|
peek_b_full_status = self.conditional_consumer_try_wait(
|
|
b_consumer_state, b_pipeline, C
|
|
)
|
|
peek_c_full_status = self.conditional_consumer_try_wait(
|
|
c_consumer_state, c_pipeline, C
|
|
)
|
|
peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire(
|
|
intra1_acc_producer_state, intra1_acc_pipeline, C
|
|
)
|
|
peek_x_full_status = self.conditional_consumer_try_wait(
|
|
x_consumer_state, x_pipeline, C
|
|
)
|
|
|
|
# Manual pipeline: unrolled INTRA_MMA1 chunk_idx = 0 loop
|
|
# Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty
|
|
b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status)
|
|
c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status)
|
|
intra1_acc_pipeline.producer_acquire(
|
|
intra1_acc_producer_state, peek_wr_intra1_acc_empty_status
|
|
)
|
|
|
|
# INTRA_MMA1
|
|
tiled_mma_intra1 = self.exec_mma(
|
|
tiled_mma_intra1,
|
|
tCtAccIntra1,
|
|
tCrC,
|
|
tCrB,
|
|
intra1_acc_producer_state,
|
|
c_consumer_state,
|
|
b_consumer_state,
|
|
)
|
|
|
|
# Async arrive B/C/INTRA1_ACC buffer empty/empty/full
|
|
b_pipeline.consumer_release(
|
|
b_consumer_state, pipeline.PipelineOp.TCGen05Mma
|
|
)
|
|
c_pipeline.consumer_release(c_consumer_state)
|
|
intra1_acc_pipeline.producer_commit(intra1_acc_producer_state)
|
|
|
|
# Advance B/C/INTRA1_ACC state
|
|
b_consumer_state.advance()
|
|
c_consumer_state.advance()
|
|
intra1_acc_producer_state.advance()
|
|
|
|
# Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1
|
|
peek_b_full_status = self.conditional_consumer_try_wait(
|
|
b_consumer_state, b_pipeline, C
|
|
)
|
|
peek_c_full_status = self.conditional_consumer_try_wait(
|
|
c_consumer_state, c_pipeline, C
|
|
)
|
|
peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire(
|
|
intra1_acc_producer_state, intra1_acc_pipeline, C
|
|
)
|
|
|
|
# Manual pipeline: batched gemm over C-1 dimension
|
|
for chunk_idx in cutlass.range(C - 1, unroll=1):
|
|
# Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty
|
|
b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status)
|
|
c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status)
|
|
intra1_acc_pipeline.producer_acquire(
|
|
intra1_acc_producer_state, peek_wr_intra1_acc_empty_status
|
|
)
|
|
|
|
# INTRA_MMA1
|
|
tiled_mma_intra1 = self.exec_mma(
|
|
tiled_mma_intra1,
|
|
tCtAccIntra1,
|
|
tCrC,
|
|
tCrB,
|
|
intra1_acc_producer_state,
|
|
c_consumer_state,
|
|
b_consumer_state,
|
|
)
|
|
|
|
# Async arrive B/C/INTRA1_ACC buffer empty/empty/full
|
|
b_pipeline.consumer_release(
|
|
b_consumer_state, pipeline.PipelineOp.TCGen05Mma
|
|
)
|
|
c_pipeline.consumer_release(c_consumer_state)
|
|
intra1_acc_pipeline.producer_commit(intra1_acc_producer_state)
|
|
|
|
# Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty
|
|
x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status)
|
|
intra2_q_pipeline.consumer_wait(intra2_q_consumer_state)
|
|
intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state)
|
|
|
|
# INTRA_MMA2
|
|
tiled_mma_intra2 = self.exec_mma(
|
|
tiled_mma_intra2,
|
|
tCtAccIntra2,
|
|
tCrQ,
|
|
tCrX,
|
|
intra2_acc_producer_state,
|
|
intra2_q_consumer_state,
|
|
x_consumer_state,
|
|
)
|
|
|
|
# Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full
|
|
if cutlass.const_expr(self.has_d):
|
|
x_pipeline.consumer_release(
|
|
x_consumer_state, pipeline.PipelineOp.TCGen05Mma
|
|
)
|
|
else:
|
|
x_pipeline.consumer_release(x_consumer_state)
|
|
intra2_q_pipeline.consumer_release(intra2_q_consumer_state)
|
|
intra2_acc_pipeline.producer_commit(intra2_acc_producer_state)
|
|
|
|
# Advance B/C/INTRA1_ACC cstate
|
|
b_consumer_state.advance()
|
|
c_consumer_state.advance()
|
|
intra1_acc_producer_state.advance()
|
|
|
|
# Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1
|
|
peek_b_full_status = self.conditional_consumer_try_wait(
|
|
b_consumer_state, b_pipeline, C
|
|
)
|
|
peek_c_full_status = self.conditional_consumer_try_wait(
|
|
c_consumer_state, c_pipeline, C
|
|
)
|
|
peek_wr_intra1_acc_empty_status = (
|
|
self.conditional_producer_try_acquire(
|
|
intra1_acc_producer_state, intra1_acc_pipeline, C
|
|
)
|
|
)
|
|
|
|
# Advance X/INTRA2_Q/INTRA2_ACC state
|
|
x_consumer_state.advance()
|
|
intra2_q_consumer_state.advance()
|
|
intra2_acc_producer_state.advance()
|
|
|
|
# Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1
|
|
peek_x_full_status = self.conditional_consumer_try_wait(
|
|
x_consumer_state, x_pipeline, C
|
|
)
|
|
# END of for chunk_idx in cutlass.range(C-1, unroll=1)
|
|
|
|
# Manual pipeline: unrolled INTRA_MMA2 chunk_idx = C-1 loop
|
|
# Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty
|
|
x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status)
|
|
intra2_q_pipeline.consumer_wait(intra2_q_consumer_state)
|
|
intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state)
|
|
|
|
# INTRA_MMA2
|
|
tiled_mma_intra2 = self.exec_mma(
|
|
tiled_mma_intra2,
|
|
tCtAccIntra2,
|
|
tCrQ,
|
|
tCrX,
|
|
intra2_acc_producer_state,
|
|
intra2_q_consumer_state,
|
|
x_consumer_state,
|
|
)
|
|
|
|
# Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full
|
|
if cutlass.const_expr(self.has_d):
|
|
x_pipeline.consumer_release(
|
|
x_consumer_state, pipeline.PipelineOp.TCGen05Mma
|
|
)
|
|
else:
|
|
x_pipeline.consumer_release(x_consumer_state)
|
|
intra2_q_pipeline.consumer_release(intra2_q_consumer_state)
|
|
intra2_acc_pipeline.producer_commit(intra2_acc_producer_state)
|
|
|
|
# Advance X/INTRA2_Q/INTRA2_ACC state
|
|
x_consumer_state.advance()
|
|
intra2_q_consumer_state.advance()
|
|
intra2_acc_producer_state.advance()
|
|
|
|
# Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1
|
|
peek_x_full_status = self.conditional_consumer_try_wait(
|
|
x_consumer_state, x_pipeline, C
|
|
)
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
# END of while work_tile.is_valid_tile
|
|
|
|
# Producer tail for INTRA1_ACC/INTRA2_ACC
|
|
intra1_acc_pipeline.producer_tail(intra1_acc_producer_state)
|
|
intra2_acc_pipeline.producer_tail(intra2_acc_producer_state)
|
|
# END of specialized mma-intra warp
|
|
|
|
# Specialized MMA Inter warp
|
|
elif warp_idx == self.mma_inter_warp_id:
|
|
# Dealloc regs for pre-inter/pre-intra warps
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
|
|
# Make shared/tmem fragments for INTER_MMA1 X/B/ACC
|
|
# (MMA, MMA_N, MMA_K, INPUT_STAGE)
|
|
# (MMA, MMA_M, MMA_K, INTERNAL_STAGE)
|
|
# (MMA, MMA_M, MMA_N, INTERNAL_STAGE)
|
|
tCrB, tCrX, tCtAccInter1 = self.mma_partition_ss(
|
|
tiled_mma_inter1,
|
|
self.tile_shape_mnk_inter1,
|
|
smem_bt_internal,
|
|
smem_x,
|
|
tmem_ptr_base + self.tmem_inter1_acc_offset,
|
|
self.internal_stages,
|
|
)
|
|
|
|
# Make shared/tmem fragments for INTER_MMA2 C/P/ACC
|
|
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
|
|
# (MMA, MMA_N, MMA_K, INTERNAL_STAGE)
|
|
# (MMA, MMA_M, MMA_N, INTERNAL_STAGE)
|
|
tCrC, tCrP, tCtAccInter2 = self.mma_partition_ss(
|
|
tiled_mma_inter2,
|
|
self.tile_shape_mnk_inter2,
|
|
smem_c,
|
|
smem_p,
|
|
tmem_ptr_base + self.tmem_inter2_acc_offset,
|
|
self.internal_stages,
|
|
)
|
|
|
|
# Pipeline X/C/INTER1_B/INTER2_P consumer state
|
|
x_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
c_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
inter1_b_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.internal_stages
|
|
)
|
|
inter2_p_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.internal_stages
|
|
)
|
|
|
|
# Pipeline INTER1_ACC/INTER2_ACC producer state
|
|
inter1_acc_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.internal_stages
|
|
)
|
|
inter2_acc_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.internal_stages
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# Reset count for pipeline state
|
|
x_consumer_state.reset_count()
|
|
c_consumer_state.reset_count()
|
|
inter1_acc_producer_state.reset_count()
|
|
inter1_b_consumer_state.reset_count()
|
|
inter2_p_consumer_state.reset_count()
|
|
inter2_acc_producer_state.reset_count()
|
|
|
|
# Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty status
|
|
peek_c_full_status = self.conditional_consumer_try_wait(
|
|
c_consumer_state, c_pipeline, C
|
|
)
|
|
peek_inter2_p_full_status = self.conditional_consumer_try_wait(
|
|
inter2_p_consumer_state, inter2_p_pipeline, C
|
|
)
|
|
peek_inter2_acc_empty_status = self.conditional_producer_try_acquire(
|
|
inter2_acc_producer_state, inter2_acc_pipeline, C
|
|
)
|
|
|
|
# Batched gemm over C dimension
|
|
for chunk_idx in cutlass.range(C, unroll=1):
|
|
# Conditionally wait for C/INTER2_P/INTER2_ACC buffer full/full/empty
|
|
c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status)
|
|
inter2_p_pipeline.consumer_wait(
|
|
inter2_p_consumer_state, peek_inter2_p_full_status
|
|
)
|
|
inter2_acc_pipeline.producer_acquire(
|
|
inter2_acc_producer_state, peek_inter2_acc_empty_status
|
|
)
|
|
|
|
# INTER MMA2
|
|
tiled_mma_inter2 = self.exec_mma(
|
|
tiled_mma_inter2,
|
|
tCtAccInter2,
|
|
tCrC,
|
|
tCrP,
|
|
inter2_acc_producer_state,
|
|
c_consumer_state,
|
|
inter2_p_consumer_state,
|
|
)
|
|
|
|
# Async arrive C/INTER2_P/INTER2_ACC buffer empty/empty/full
|
|
c_pipeline.consumer_release(c_consumer_state)
|
|
inter2_p_pipeline.consumer_release(inter2_p_consumer_state)
|
|
inter2_acc_pipeline.producer_commit(inter2_acc_producer_state)
|
|
|
|
# Wait for X/INTER1_B/INTER1_ACC buffer full/full/empty
|
|
x_pipeline.consumer_wait(x_consumer_state)
|
|
inter1_b_pipeline.consumer_wait(inter1_b_consumer_state)
|
|
inter1_acc_pipeline.producer_acquire(inter1_acc_producer_state)
|
|
|
|
# INTER MMA1
|
|
tiled_mma_inter1 = self.exec_mma(
|
|
tiled_mma_inter1,
|
|
tCtAccInter1,
|
|
tCrB,
|
|
tCrX,
|
|
inter1_acc_producer_state,
|
|
inter1_b_consumer_state,
|
|
x_consumer_state,
|
|
)
|
|
|
|
# Async arrive X/INTER1_B/INTER1_ACC buffer empty/empty/full
|
|
if cutlass.const_expr(self.has_d):
|
|
x_pipeline.consumer_release(
|
|
x_consumer_state, pipeline.PipelineOp.TCGen05Mma
|
|
)
|
|
else:
|
|
x_pipeline.consumer_release(x_consumer_state)
|
|
inter1_b_pipeline.consumer_release(inter1_b_consumer_state)
|
|
inter1_acc_pipeline.producer_commit(inter1_acc_producer_state)
|
|
|
|
# Advance X/C/INTER1_B/INTER1_ACC/INTER2_P/INTER2_ACC state
|
|
x_consumer_state.advance()
|
|
c_consumer_state.advance()
|
|
inter1_b_consumer_state.advance()
|
|
inter1_acc_producer_state.advance()
|
|
inter2_p_consumer_state.advance()
|
|
inter2_acc_producer_state.advance()
|
|
|
|
# Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1
|
|
peek_c_full_status = self.conditional_consumer_try_wait(
|
|
c_consumer_state, c_pipeline, C
|
|
)
|
|
peek_inter2_p_full_status = self.conditional_consumer_try_wait(
|
|
inter2_p_consumer_state, inter2_p_pipeline, C
|
|
)
|
|
peek_inter2_acc_empty_status = (
|
|
self.conditional_producer_try_acquire(
|
|
inter2_acc_producer_state, inter2_acc_pipeline, C
|
|
)
|
|
)
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
# Producer tail for INTER1_ACC/INTER2_ACC
|
|
inter1_acc_pipeline.producer_tail(inter1_acc_producer_state)
|
|
inter2_acc_pipeline.producer_tail(inter2_acc_producer_state)
|
|
|
|
# Specialized Pre-Inter warp
|
|
elif (
|
|
warp_idx == self.pre_inter_warp_id[0]
|
|
or warp_idx == self.pre_inter_warp_id[1]
|
|
or warp_idx == self.pre_inter_warp_id[2]
|
|
or warp_idx == self.pre_inter_warp_id[3]
|
|
):
|
|
# Alloc regs in pre_inter warps
|
|
cute.arch.warpgroup_reg_alloc(self.num_regs_pre_inter_warps)
|
|
|
|
# Make tiledCopy and partition smem/register tensor for smem load Bt
|
|
# ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE)
|
|
# ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N)
|
|
tiled_s2r_b, tBsB_s2r, tBrB_s2r = self.pre_inter_smem_load_and_partition_b(
|
|
local_tidx, smem_bt
|
|
)
|
|
|
|
# Partition shared tensor for smem store Bt
|
|
smem_bt_internal_ = cute.make_tensor(
|
|
smem_bt_internal.iterator, smem_bt.layout
|
|
)
|
|
# Make tiledCopy and partition register/smem tensor for smem store Bt
|
|
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
|
|
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
|
|
tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b(
|
|
local_tidx,
|
|
smem_bt_internal_,
|
|
tiled_s2r_b,
|
|
tBrB_s2r,
|
|
)
|
|
|
|
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
|
|
sDelta = self.pre_inter_make_delta(smem_delta, smem_bt.layout)
|
|
sDeltaA = self.pre_inter_make_delta(smem_cumsum_delta, smem_bt.layout)
|
|
|
|
# Make copy_atom and partition register/smem tensor for smem load/store of Delta/DeltaA
|
|
# ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE)
|
|
# ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N)
|
|
(
|
|
s2r_atom_delta,
|
|
tBsDelta_s2r,
|
|
tBrDelta_s2r,
|
|
) = self.smem_load_and_partition_delta_d(
|
|
tiled_s2r_b, local_tidx, sDelta, (None, None, None, 0)
|
|
)
|
|
(
|
|
s2r_atom_cumsum,
|
|
tBsDeltaA_s2r,
|
|
tBrDeltaA_s2r,
|
|
) = self.smem_load_and_partition_delta_d(
|
|
tiled_s2r_b, local_tidx, sDeltaA, (None, None, None, 0)
|
|
)
|
|
|
|
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
|
|
thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)
|
|
tBrDelta_r2s = thr_r2s_b.retile(tBrDelta_s2r)
|
|
tBrDeltaA_r2s = thr_r2s_b.retile(tBrDeltaA_s2r)
|
|
|
|
# Make tmem fragment for INTER1_ACC
|
|
# (MMA, MMA_M, MMA_N, INTERNAL_STAGE)
|
|
tCtAccInter1 = self.mma_partition_c(
|
|
tiled_mma_inter1,
|
|
self.tile_shape_mnk_inter1,
|
|
tmem_ptr_base + self.tmem_inter1_acc_offset,
|
|
self.internal_stages,
|
|
)
|
|
# (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE)
|
|
tInter1 = tCtAccInter1[((None, None), 0, 0, None)]
|
|
|
|
# Make tiledCopy and partition tmem/register tensor for tmem load INTER1_ACC
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE)
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
|
|
(
|
|
tiled_t2r_inter1,
|
|
tTR_tP,
|
|
tTR_rP,
|
|
) = self.pre_inter_tmem_load_and_partition_p(local_tidx, tInter1, smem_pt)
|
|
|
|
# Make fragment for register to hold P after post-processing (in acc dtype)
|
|
tState = cute.make_fragment(tTR_rP.shape, self.acc_dtype)
|
|
|
|
# Make tiledCopy and partition smem/register tensor for smem store INTER2_P
|
|
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
|
|
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
|
|
tiled_r2s_p, tRS_rP, tRS_sP = self.smem_store_and_partition_p_y(
|
|
local_tidx, smem_pt, tiled_t2r_inter1
|
|
)
|
|
|
|
# Partition global/shared tensor for P (State)
|
|
# ((ATOM_V, REST_V), INTERNAL_STAGE)
|
|
# ((ATOM_V, REST_V), 1, 1, EH, B)
|
|
bSG_sP, bSG_gP_pre_slice = self.tma_partition_with_shape(
|
|
tma_atom_p,
|
|
tma_tensor_p,
|
|
smem_p_store,
|
|
self.tile_shape_mnk_inter2[1:],
|
|
)
|
|
|
|
# Pipeline B/Delta/INTER1_ACC consumer state
|
|
b_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
deltas_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
inter1_acc_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.internal_stages
|
|
)
|
|
|
|
# Pipeline INTER1_B/INTER2_P producer state
|
|
inter1_b_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.internal_stages
|
|
)
|
|
inter2_p_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.internal_stages
|
|
)
|
|
|
|
# Pipeline TMA store P
|
|
tma_p_pipeline = pipeline.PipelineTmaStore.create(
|
|
num_stages=self.internal_stages,
|
|
producer_group=pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128
|
|
),
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
b_idx, eh_idx, g_idx = work_tile.tile_idx
|
|
|
|
# Slice global tensor to current tile idx
|
|
# ((ATOM_V, REST_V))
|
|
bSG_gP = bSG_gP_pre_slice[(None, 0, 0, eh_idx, b_idx)]
|
|
|
|
# Reset count for pipeline state
|
|
b_consumer_state.reset_count()
|
|
deltas_consumer_state.reset_count()
|
|
inter1_b_producer_state.reset_count()
|
|
inter1_acc_consumer_state.reset_count()
|
|
inter2_p_producer_state.reset_count()
|
|
|
|
# State (P) init
|
|
tState.fill(0.0)
|
|
|
|
# Peek (try_wait) B/Delta/INTER1_B buffer full/full/empty status
|
|
peek_b_full_status = self.conditional_consumer_try_wait(
|
|
b_consumer_state, b_pipeline, C
|
|
)
|
|
peek_deltas_full_status = self.conditional_consumer_try_wait(
|
|
deltas_consumer_state, deltas_pipeline, C
|
|
)
|
|
peek_wr_inter1_b_empty_status = self.conditional_producer_try_acquire(
|
|
inter1_b_producer_state, inter1_b_pipeline, C
|
|
)
|
|
|
|
# Prefill INTER2_P with 0
|
|
# Wait for INTER2_P buffer empty
|
|
inter2_p_pipeline.producer_acquire(inter2_p_producer_state)
|
|
|
|
tRS_rP.fill(0.0)
|
|
# Copy INTER2_P from register to smem
|
|
inter2_p_coord = (None, None, None, inter2_p_producer_state.index)
|
|
cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord])
|
|
|
|
# Fence for shared memory
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
# Async arrive INTER2_P buffer full
|
|
inter2_p_pipeline.producer_commit(inter2_p_producer_state)
|
|
# Advance INTER2_P producer state
|
|
inter2_p_producer_state.advance()
|
|
|
|
# Batched processing over C dimension
|
|
for chunk_idx in cutlass.range(C, unroll=1):
|
|
# Conditionally wait for B/Delta/B_TMEM buffer full/full/empty
|
|
b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status)
|
|
deltas_pipeline.consumer_wait(
|
|
deltas_consumer_state, peek_deltas_full_status
|
|
)
|
|
inter1_b_pipeline.producer_acquire(
|
|
inter1_b_producer_state, peek_wr_inter1_b_empty_status
|
|
)
|
|
|
|
# Load B/Delta/DeltaA/last_column
|
|
b_coord = (None, None, None, b_consumer_state.index)
|
|
delta_coord = (None, None, None, deltas_consumer_state.index)
|
|
cute.copy(tiled_s2r_b, tBsB_s2r[b_coord], tBrB_s2r)
|
|
cute.copy(s2r_atom_delta, tBsDelta_s2r[delta_coord], tBrDelta_s2r)
|
|
cute.copy(
|
|
s2r_atom_cumsum, tBsDeltaA_s2r[delta_coord], tBrDeltaA_s2r
|
|
)
|
|
last_column = smem_cumsum_delta[
|
|
smem_cumsum_delta.shape[0] - 1, deltas_consumer_state.index
|
|
]
|
|
|
|
# Fence for shared memory
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
|
|
# Combine B/Delta/DeltaA/last_column
|
|
tScaledB = self.pre_inter_scale_bt_with_delta(
|
|
tBrB_s2r, tBrDelta_r2s, tBrDeltaA_r2s, last_column
|
|
)
|
|
|
|
# Store scaled B to tBrB_r2s
|
|
for reg_idx in range(cute.size(tBrB_r2s)):
|
|
tBrB_r2s[reg_idx] = tScaledB[reg_idx].to(self.io_dtype)
|
|
|
|
# Store tBrB_r2s to bt_smem_internal
|
|
inter1_b_coord = (None, None, None, inter1_b_producer_state.index)
|
|
cute.copy(tiled_r2s_b, tBrB_r2s, tBsB_r2s[inter1_b_coord])
|
|
|
|
# Fence for shared memory
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
|
|
# Async arrive B/Delta/B_TMEM buffer empty/empty/full
|
|
b_pipeline.consumer_release(
|
|
b_consumer_state, pipeline.PipelineOp.AsyncThread
|
|
)
|
|
deltas_pipeline.consumer_release(deltas_consumer_state)
|
|
inter1_b_pipeline.producer_commit(inter1_b_producer_state)
|
|
|
|
# Wait for INTER1_ACC/INTER2_P buffer full/empty
|
|
inter1_acc_pipeline.consumer_wait(inter1_acc_consumer_state)
|
|
inter2_p_pipeline.producer_acquire(inter2_p_producer_state)
|
|
|
|
# Load INTER1_ACC
|
|
inter1_acc_coord = (
|
|
None,
|
|
None,
|
|
None,
|
|
inter1_acc_consumer_state.index,
|
|
)
|
|
cute.copy(tiled_t2r_inter1, tTR_tP[inter1_acc_coord], tTR_rP)
|
|
|
|
# Fence for TMEM load
|
|
cute.arch.fence_view_async_tmem_load()
|
|
|
|
# Combine INTER1_ACC/last_column/State
|
|
exp_last_column = cute.arch.exp(last_column.ir_value())
|
|
for reg_idx in range(0, cute.size(tTR_rP), 2):
|
|
(
|
|
tTR_rP[reg_idx],
|
|
tTR_rP[reg_idx + 1],
|
|
) = cute.arch.fma_packed_f32x2(
|
|
(exp_last_column, exp_last_column),
|
|
(tState[reg_idx], tState[reg_idx + 1]),
|
|
(tTR_rP[reg_idx], tTR_rP[reg_idx + 1]),
|
|
)
|
|
|
|
# Store scaled P to tRS_rP
|
|
for reg_idx in range(cute.size(tTR_rP)):
|
|
tRS_rP[reg_idx] = tTR_rP[reg_idx].to(self.io_dtype)
|
|
|
|
# Update old state
|
|
tState.store(tTR_rP.load())
|
|
|
|
# Store INTER2_P
|
|
inter2_p_coord = (None, None, None, inter2_p_producer_state.index)
|
|
cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord])
|
|
|
|
# Fence for shared memory
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
|
|
# Async arrive INTER1_ACC/INTER2_P buffer empty/full
|
|
inter1_acc_pipeline.consumer_release(inter1_acc_consumer_state)
|
|
# Last iteration consumer is PRE_INTER warp itself, not MMA_INTER warp
|
|
if inter2_p_producer_state.count < C:
|
|
inter2_p_pipeline.producer_commit(inter2_p_producer_state)
|
|
|
|
# Advance B/Delta/INTER1_B/INTER1_ACC state
|
|
b_consumer_state.advance()
|
|
deltas_consumer_state.advance()
|
|
inter1_b_producer_state.advance()
|
|
inter1_acc_consumer_state.advance()
|
|
                    # Peek (try_wait) B/Delta/INTER1_B buffer full/full/empty for chunk_idx = chunk_idx + 1
|
|
peek_b_full_status = self.conditional_consumer_try_wait(
|
|
b_consumer_state, b_pipeline, C
|
|
)
|
|
peek_deltas_full_status = self.conditional_consumer_try_wait(
|
|
deltas_consumer_state, deltas_pipeline, C
|
|
)
|
|
peek_wr_inter1_b_empty_status = (
|
|
self.conditional_producer_try_acquire(
|
|
inter1_b_producer_state, inter1_b_pipeline, C
|
|
)
|
|
)
|
|
|
|
# Last iteration producer is PRE_INTER warp itself, not MMA_INTER warp
|
|
if inter2_p_producer_state.count < C:
|
|
# Advance INTER2_P producer state
|
|
inter2_p_producer_state.advance()
|
|
# END of for chunk_idx in cutlass.range(C, unroll=1)
|
|
|
|
# Store last INTER2_P (State) from smem to gmem
|
|
# Wait for all previous stores to smem to be done
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
cute.arch.barrier(
|
|
barrier_id=self.pre_inter_sync_bar_id,
|
|
number_of_threads=len(self.pre_inter_warp_id) * 32,
|
|
)
|
|
|
|
if local_warp_idx == 0:
|
|
# TMA store P
|
|
cute.copy(
|
|
tma_atom_p,
|
|
bSG_sP[(None, inter2_p_producer_state.index)],
|
|
bSG_gP,
|
|
)
|
|
# Wait for TMA store done
|
|
tma_p_pipeline.producer_commit()
|
|
tma_p_pipeline.producer_acquire()
|
|
|
|
cute.arch.barrier(
|
|
barrier_id=self.pre_inter_sync_bar_id,
|
|
number_of_threads=len(self.pre_inter_warp_id) * 32,
|
|
)
|
|
tma_p_pipeline.producer_tail()
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
# END of while work_tile.is_valid_tile
|
|
|
|
# Producer tail for INTER1_B/INTER2_P/TMA store P
|
|
inter1_b_pipeline.producer_tail(inter1_b_producer_state)
|
|
inter2_p_pipeline.producer_tail(inter2_p_producer_state)
|
|
# END of specialized pre-inter warp
|
|
|
|
# Specialized Pre-Intra warp
|
|
elif (
|
|
warp_idx == self.pre_intra_warp_id[0]
|
|
or warp_idx == self.pre_intra_warp_id[1]
|
|
or warp_idx == self.pre_intra_warp_id[2]
|
|
or warp_idx == self.pre_intra_warp_id[3]
|
|
):
|
|
            # Alloc regs in pre_intra warps
|
|
cute.arch.warpgroup_reg_alloc(self.num_regs_pre_intra_warps)
|
|
|
|
# Make tmem fragment for INTRA1_ACC
|
|
# (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE)
|
|
tCtAccIntra1 = self.mma_partition_c(
|
|
tiled_mma_intra1,
|
|
self.tile_shape_mnk_intra1,
|
|
tmem_ptr_base + self.tmem_intra1_acc_offset,
|
|
self.intra1_acc_stages,
|
|
)
|
|
# (M_PER_MMA, N_PER_MMA, INTRA1_ACC_STAGE)
|
|
tIntra1 = tCtAccIntra1[((None, None), 0, 0, None)]
|
|
|
|
# Make tiledCopy and partition tmem/register tensor for tensor memory load INTRA1_ACC
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE)
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
|
|
tiled_t2r_intra1, tTR_tQ, tTR_rQ = self.pre_intra_tmem_load_and_partition_q(
|
|
tIntra1, local_tidx
|
|
)
|
|
|
|
# Broadcast delta/delta_cumsum smem tensor from LxINPUT_STAGE to LxLxINPUT_STAGE
|
|
sDeltaA_Row = self.pre_intra_make_delta(smem_cumsum_delta, 0)
|
|
sDeltaA_Col = self.pre_intra_make_delta(smem_cumsum_delta, 1)
|
|
sDelta = self.pre_intra_make_delta(smem_delta, 0)
|
|
|
|
# Make tiledCopy and partition smem/register tensor for smem memory load delta/delta_cumsum
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INPUT_STAGE)
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
|
|
(
|
|
s2r_atom_cumsum,
|
|
tQsDeltaA_Row,
|
|
tQrDeltaA_Row,
|
|
) = self.smem_load_and_partition_delta_d(
|
|
tiled_t2r_intra1, local_tidx, sDeltaA_Row, (None, None, None, 0)
|
|
)
|
|
(
|
|
s2r_atom_cumsum,
|
|
tQsDeltaA_Col,
|
|
tQrDeltaA_Col,
|
|
) = self.smem_load_and_partition_delta_d(
|
|
tiled_t2r_intra1, local_tidx, sDeltaA_Col, (None, None, None, 0)
|
|
)
|
|
(
|
|
s2r_atom_delta,
|
|
tQsDelta,
|
|
tQrDelta,
|
|
) = self.smem_load_and_partition_delta_d(
|
|
tiled_t2r_intra1, local_tidx, sDelta, (None, None, None, 0)
|
|
)
|
|
|
|
# Make and partition coord tensor for delta_cumsum load
|
|
# (L, L)
|
|
coord_tensor = cute.make_identity_tensor(
|
|
cute.dice(self.tile_shape_mnk_intra1, (1, 1, None))
|
|
)
|
|
thr_t2r_intra1 = tiled_t2r_intra1.get_slice(local_tidx)
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
|
|
tCoord = thr_t2r_intra1.partition_D(coord_tensor)
|
|
|
|
# Make tmem tensor for INTRA2_Q
|
|
# (MMA, MMA_M, MMA_K, INTERNAL_STAGE)
|
|
tCrQ = self.mma_partition_a_tmem(
|
|
tiled_mma_intra2,
|
|
q_tmem_layout,
|
|
tmem_ptr_base + self.tmem_intra2_q_offset,
|
|
)
|
|
|
|
# Make tiledCopy and partition tmem/register tensor for tensor memory store INTRA2_Q
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...)
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE)
|
|
tiled_r2t_q, tRT_rQ, tRT_tQ = self.pre_intra_tmem_store_and_partition_q(
|
|
local_tidx, tCrQ
|
|
)
|
|
|
|
# Pipeline DELTA/INTRA1_ACC consumer state
|
|
deltas_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
intra1_acc_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.intra1_acc_stages
|
|
)
|
|
# Pipeline INTRA2_Q producer state
|
|
intra2_q_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.internal_stages
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# Reset count for pipeline state
|
|
deltas_consumer_state.reset_count()
|
|
intra1_acc_consumer_state.reset_count()
|
|
intra2_q_producer_state.reset_count()
|
|
|
|
# Peek (try_wait) DELTA/INTRA1_ACC buffer full
|
|
peek_deltas_full_status = self.conditional_consumer_try_wait(
|
|
deltas_consumer_state, deltas_pipeline, C
|
|
)
|
|
peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait(
|
|
intra1_acc_consumer_state, intra1_acc_pipeline, C
|
|
)
|
|
|
|
# Batched processing over C dimension
|
|
for chunk_idx in cutlass.range(C, unroll=1):
|
|
# Conditionally wait for Delta/INTRA1_ACC buffer full
|
|
deltas_pipeline.consumer_wait(
|
|
deltas_consumer_state, peek_deltas_full_status
|
|
)
|
|
intra1_acc_pipeline.consumer_wait(
|
|
intra1_acc_consumer_state, peek_rd_intra1_acc_full_status
|
|
)
|
|
|
|
# Load Q from tmem
|
|
intra1_coord = (None, None, None, intra1_acc_consumer_state.index)
|
|
cute.copy(tiled_t2r_intra1, tTR_tQ[intra1_coord], tTR_rQ)
|
|
cute.arch.fence_view_async_tmem_load()
|
|
|
|
# Load tQsDeltaA_Row/tQsDeltaA_Col/tQsDelta from smem
|
|
delta_coord = (None, None, None, deltas_consumer_state.index)
|
|
cute.copy(
|
|
s2r_atom_cumsum, tQsDeltaA_Row[delta_coord], tQrDeltaA_Row
|
|
)
|
|
cute.copy(
|
|
s2r_atom_cumsum, tQsDeltaA_Col[delta_coord], tQrDeltaA_Col
|
|
)
|
|
cute.copy(s2r_atom_delta, tQsDelta[delta_coord], tQrDelta)
|
|
|
|
# SegSum
|
|
tRT_rQ = self.pre_intra_segsum(
|
|
tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ
|
|
)
|
|
|
|
# Wait for INTRA2_Q buffer empty
|
|
# Delay producer_acquire to right before data store
|
|
intra2_q_pipeline.producer_acquire(intra2_q_producer_state)
|
|
|
|
# Store Q from reg to tmem
|
|
q_coord = (None, None, None, None, intra2_q_producer_state.index)
|
|
cute.copy(tiled_r2t_q, tRT_rQ, tRT_tQ[q_coord])
|
|
|
|
# Async arrive Delta/INTRA1_ACC buffer empty
|
|
intra1_acc_pipeline.consumer_release(intra1_acc_consumer_state)
|
|
deltas_pipeline.consumer_release(deltas_consumer_state)
|
|
|
|
cute.arch.fence_view_async_tmem_store()
|
|
|
|
# Async arrive INTRA2_Q buffer full
|
|
intra2_q_pipeline.producer_commit(intra2_q_producer_state)
|
|
|
|
# Advance deltas/intra1_acc/intra2_q states
|
|
deltas_consumer_state.advance()
|
|
intra1_acc_consumer_state.advance()
|
|
intra2_q_producer_state.advance()
|
|
|
|
# Peek (try_wait) Delta/INTRA1_ACC buffer full for chunk_idx = chunk_idx + 1
|
|
peek_deltas_full_status = self.conditional_consumer_try_wait(
|
|
deltas_consumer_state, deltas_pipeline, C
|
|
)
|
|
peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait(
|
|
intra1_acc_consumer_state, intra1_acc_pipeline, C
|
|
)
|
|
# END of for chunk_idx in cutlass.range(C, unroll=1)
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
# END of while work_tile.is_valid_tile
|
|
|
|
# Producer tail for INTRA2_Q
|
|
intra2_q_pipeline.producer_tail(intra2_q_producer_state)
|
|
# END of specialized pre-intra warp
|
|
|
|
# Specialized Epilogue warp
|
|
else:
|
|
# Dealloc regs for pre-inter/pre-intra warps
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_epilogue_warps)
|
|
|
|
# (L, D, INPUT_STAGE)
|
|
sDeltaA = self.epilog_make_delta(smem_cumsum_delta)
|
|
|
|
# Make tmem tensor for INTRA2_ACC/INTER2_ACC
|
|
# (MMA, MMA_M, MMA_K, INTERNAL_STAGE)
|
|
tCtAccIntra2 = self.mma_partition_c(
|
|
tiled_mma_intra2,
|
|
self.tile_shape_mnk_intra2,
|
|
tmem_ptr_base + self.tmem_intra2_acc_offset,
|
|
self.internal_stages,
|
|
)
|
|
# (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE)
|
|
tIntra2 = tCtAccIntra2[((None, None), 0, 0, None)]
|
|
# (MMA, MMA_M, MMA_K, INTERNAL_STAGE)
|
|
tCtAccInter2 = self.mma_partition_c(
|
|
tiled_mma_inter2,
|
|
self.tile_shape_mnk_inter2,
|
|
tmem_ptr_base + self.tmem_inter2_acc_offset,
|
|
self.internal_stages,
|
|
)
|
|
# (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE)
|
|
tInter2 = tCtAccInter2[((None, None), 0, 0, None)]
|
|
|
|
# Subtiling INTRA2_ACC/INTER2_ACC/Delta/Y
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INTERNAL_STAGE)
|
|
tIntra_epi = cute.flat_divide(tIntra2, epi_tile)
|
|
tInter_epi = cute.flat_divide(tInter2, epi_tile)
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE)
|
|
sDeltaA_epi = cute.flat_divide(sDeltaA, epi_tile)
|
|
|
|
# Make tiled copy and partition tmem/reg tensor w.r.t tensor memory load
|
|
# ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INTERNAL_STAGE)
|
|
# ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N)
|
|
(
|
|
tiled_t2r_intra2,
|
|
tTR_tIntra,
|
|
tTR_rIntra,
|
|
) = self.epilog_tmem_load_and_partition_acc(local_tidx, tIntra_epi, smem_y)
|
|
(
|
|
tiled_t2r_inter2,
|
|
tTR_tInter2,
|
|
tTR_rInter,
|
|
) = self.epilog_tmem_load_and_partition_acc(local_tidx, tInter_epi, smem_y)
|
|
|
|
# Make tiled copy and partition smem/reg tensor w.r.t smem load Delta
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, EPI_M, EPI_N, INPUT_STAGE)
|
|
# ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
|
|
(
|
|
s2r_atom_delta,
|
|
tTR_sDeltaA,
|
|
tTR_rDeltaA,
|
|
) = self.smem_load_and_partition_delta_d(
|
|
tiled_t2r_inter2, local_tidx, sDeltaA_epi, (None, None, None, 0, 0, 0)
|
|
)
|
|
|
|
# Make tiled copy and Partition smem/register tensor w.r.t smem store Y
|
|
# ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N, OUTPUT_STAGE)
|
|
# ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N)
|
|
tiled_r2s_y, tRS_rY, tRS_sY = self.smem_store_and_partition_p_y(
|
|
local_tidx, smem_y, tiled_t2r_inter2
|
|
)
|
|
|
|
tRS_rCompute = cute.make_fragment(tRS_rY.shape, self.acc_dtype)
|
|
|
|
tiled_s2r_x = None
|
|
tSR_sX = None
|
|
tSR_rX = None
|
|
if cutlass.const_expr(self.has_d):
|
|
# Make TiledCopy/smem/register tensor for smem load X
|
|
# (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES)
|
|
# (R2S_ATOM, R2S_M, R2S_N)
|
|
tiled_s2r_x, tSR_sX, tSR_rX = self.epilog_smem_load_and_partition_x(
|
|
tiled_t2r_inter2, local_tidx, smem_xt, epi_tile
|
|
)
|
|
|
|
tRS_sD = None
|
|
tRS_rD = None
|
|
s2r_atom_d = None
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
# (L, D, INPUT_STAGE)
|
|
sD = self.epilog_make_d(smem_d)
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE)
|
|
tD_sepi = cute.flat_divide(sD, epi_tile)
|
|
|
|
# Make tiled copy and partition smem/reg tensor w.r.t smem load D
|
|
# ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INPUT_STAGE)
|
|
# ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N)
|
|
s2r_atom_d, tRS_sD, tRS_rD = self.smem_load_and_partition_delta_d(
|
|
tiled_t2r_inter2, local_tidx, tD_sepi, (None, None, None, 0, 0, 0)
|
|
)
|
|
|
|
elif cutlass.const_expr(self.has_d):
|
|
tRS_rD = cutlass.Float32(0.0).to(self.io_dtype)
|
|
|
|
# Partition global/shared tensor for TMA store Y
|
|
# ((ATOM_V, REST_V), INPUT_STAGE)
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B)
|
|
bSG_sY, bSG_gY_pre_slice = self.epilog_tma_partition_y(
|
|
tma_tensor_y, tma_atom_y, smem_y, epi_tile
|
|
)
|
|
|
|
# Make TMA store pipeline Y
|
|
tma_y_pipeline = pipeline.PipelineTmaStore.create(
|
|
num_stages=self.output_stages,
|
|
producer_group=pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128
|
|
),
|
|
)
|
|
|
|
# Make consumer pipeline states for Delta/INTRA2_ACC/INTER2_ACC/X/D buffer
|
|
deltas_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
intra2_acc_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.internal_stages
|
|
)
|
|
inter2_acc_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.internal_stages
|
|
)
|
|
x_consumer_state = None
|
|
if cutlass.const_expr(self.has_d):
|
|
x_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
d_consumer_state = None
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
d_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.input_stages
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
b_idx, eh_idx, g_idx = work_tile.tile_idx
|
|
|
|
# Slice global tensor to current tile idx
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N, C)
|
|
bSG_gY = bSG_gY_pre_slice[(None, None, None, 0, 0, None, eh_idx, b_idx)]
|
|
if cutlass.const_expr(self.has_d and not self.d_has_hdim):
|
|
tRS_rD = tma_tensor_d[0, eh_idx]
|
|
|
|
# Reset count for pipeline state
|
|
deltas_consumer_state.reset_count()
|
|
intra2_acc_consumer_state.reset_count()
|
|
inter2_acc_consumer_state.reset_count()
|
|
if cutlass.const_expr(self.has_d):
|
|
x_consumer_state.reset_count()
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
d_consumer_state.reset_count()
|
|
|
|
# Peek Delta/INTRA2_ACC/INTER2_ACC buffer status
|
|
peek_deltas_full_status = self.conditional_consumer_try_wait(
|
|
deltas_consumer_state, deltas_pipeline, C
|
|
)
|
|
peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait(
|
|
intra2_acc_consumer_state, intra2_acc_pipeline, C
|
|
)
|
|
peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait(
|
|
inter2_acc_consumer_state, inter2_acc_pipeline, C
|
|
)
|
|
peek_rd_x_full_status = None
|
|
if cutlass.const_expr(self.has_d):
|
|
peek_rd_x_full_status = self.conditional_consumer_try_wait(
|
|
x_consumer_state, x_pipeline, C
|
|
)
|
|
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
d_pipeline.consumer_wait(d_consumer_state)
|
|
|
|
# Batched processing over C dimension
|
|
for chunk_idx in cutlass.range(C, unroll=1):
|
|
# Conditionally wait for Delta/INTRA2_ACC/INTER2_ACC/X buffer full
|
|
deltas_pipeline.consumer_wait(
|
|
deltas_consumer_state, peek_deltas_full_status
|
|
)
|
|
intra2_acc_pipeline.consumer_wait(
|
|
intra2_acc_consumer_state, peek_rd_intra2_acc_full_status
|
|
)
|
|
inter2_acc_pipeline.consumer_wait(
|
|
inter2_acc_consumer_state, peek_rd_inter2_acc_full_status
|
|
)
|
|
if cutlass.const_expr(self.has_d):
|
|
x_pipeline.consumer_wait(
|
|
x_consumer_state, peek_rd_x_full_status
|
|
)
|
|
# Loop over EPI_M and EPI_N subtiles
|
|
for epi_n in range(cute.size(tTR_tIntra, mode=[4])):
|
|
for epi_m in range(cute.size(tTR_tIntra, mode=[3])):
|
|
epi_iter_cnt = (
|
|
epi_n * cute.size(tTR_tIntra, mode=[3]) + epi_m
|
|
)
|
|
epi_buffer_idx = epi_iter_cnt % self.output_stages
|
|
|
|
# Load INTRA2_ACC/INTER2_ACC from tmem
|
|
subtile_coord = (
|
|
None,
|
|
None,
|
|
None,
|
|
epi_m,
|
|
epi_n,
|
|
)
|
|
intra2_coord = subtile_coord + (
|
|
intra2_acc_consumer_state.index,
|
|
)
|
|
cute.copy(
|
|
tiled_t2r_intra2,
|
|
tTR_tIntra[intra2_coord],
|
|
tTR_rIntra,
|
|
)
|
|
inter2_coord = subtile_coord + (
|
|
inter2_acc_consumer_state.index,
|
|
)
|
|
cute.copy(
|
|
tiled_t2r_inter2,
|
|
tTR_tInter2[inter2_coord],
|
|
tTR_rInter,
|
|
)
|
|
# Fence for T2R load
|
|
cute.arch.fence_view_async_tmem_load()
|
|
|
|
# Load Delta from smem
|
|
delta_coord = subtile_coord + (deltas_consumer_state.index,)
|
|
cute.copy(
|
|
s2r_atom_delta, tTR_sDeltaA[delta_coord], tTR_rDeltaA
|
|
)
|
|
|
|
# Load X from smem
|
|
if cutlass.const_expr(self.has_d):
|
|
x_coord = subtile_coord + (x_consumer_state.index,)
|
|
cute.copy(tiled_s2r_x, tSR_sX[x_coord], tSR_rX)
|
|
|
|
# Load D from smem
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
# Load vector D from smem (d_has_hdim = True)
|
|
d_coord = subtile_coord + (d_consumer_state.index,)
|
|
cute.copy(s2r_atom_d, tRS_sD[d_coord], tRS_rD)
|
|
|
|
# Combine INTRA2_ACC/INTER2_ACC/Delta/X/D
|
|
for reg_idx in range(0, cute.size(tRS_rCompute), 2):
|
|
(
|
|
tRS_rCompute[reg_idx],
|
|
tRS_rCompute[reg_idx + 1],
|
|
) = cute.arch.fma_packed_f32x2(
|
|
(tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]),
|
|
(
|
|
cute.arch.exp(tTR_rDeltaA[reg_idx].ir_value()),
|
|
cute.arch.exp(
|
|
tTR_rDeltaA[reg_idx + 1].ir_value()
|
|
),
|
|
),
|
|
(tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]),
|
|
)
|
|
# Fuse Y += X * D
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
(
|
|
tRS_rCompute[reg_idx],
|
|
tRS_rCompute[reg_idx + 1],
|
|
) = cute.arch.fma_packed_f32x2(
|
|
(
|
|
tRS_rD[reg_idx].to(self.acc_dtype),
|
|
tRS_rD[reg_idx + 1].to(self.acc_dtype),
|
|
),
|
|
(
|
|
tSR_rX[reg_idx].to(self.acc_dtype),
|
|
tSR_rX[reg_idx + 1].to(self.acc_dtype),
|
|
),
|
|
(
|
|
tRS_rCompute[reg_idx],
|
|
tRS_rCompute[reg_idx + 1],
|
|
),
|
|
)
|
|
elif cutlass.const_expr(self.has_d):
|
|
(
|
|
tRS_rCompute[reg_idx],
|
|
tRS_rCompute[reg_idx + 1],
|
|
) = cute.arch.fma_packed_f32x2(
|
|
(
|
|
tRS_rD.to(self.acc_dtype),
|
|
tRS_rD.to(self.acc_dtype),
|
|
),
|
|
(
|
|
tSR_rX[reg_idx].to(self.acc_dtype),
|
|
tSR_rX[reg_idx + 1].to(self.acc_dtype),
|
|
),
|
|
(
|
|
tRS_rCompute[reg_idx],
|
|
tRS_rCompute[reg_idx + 1],
|
|
),
|
|
)
|
|
|
|
tRS_rY.store(tRS_rCompute.load().to(self.io_dtype))
|
|
|
|
# Store Y to smem
|
|
cute.copy(
|
|
tiled_r2s_y,
|
|
tRS_rY,
|
|
tRS_sY[None, None, None, epi_buffer_idx],
|
|
)
|
|
|
|
# Fence for R2S store
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
# Sync before TMA store
|
|
cute.arch.barrier(
|
|
barrier_id=self.epilog_sync_bar_id,
|
|
number_of_threads=len(self.epilog_warp_id) * 32,
|
|
)
|
|
|
|
# Async arrive Delta/INTRA2_ACC/INTER2_ACC buffer empty
|
|
if (
|
|
epi_iter_cnt
|
|
== cute.size(tTR_tIntra, mode=[4])
|
|
* cute.size(tTR_tIntra, mode=[3])
|
|
- 1
|
|
):
|
|
deltas_pipeline.consumer_release(deltas_consumer_state)
|
|
intra2_acc_pipeline.consumer_release(
|
|
intra2_acc_consumer_state
|
|
)
|
|
inter2_acc_pipeline.consumer_release(
|
|
inter2_acc_consumer_state
|
|
)
|
|
if cutlass.const_expr(self.has_d):
|
|
x_pipeline.consumer_release(
|
|
x_consumer_state,
|
|
pipeline.PipelineOp.AsyncThread,
|
|
)
|
|
|
|
# TMA store Y to global memory
|
|
if local_warp_idx == 0:
|
|
cute.copy(
|
|
tma_atom_y,
|
|
bSG_sY[None, epi_buffer_idx],
|
|
bSG_gY[None, epi_m, epi_n, chunk_idx],
|
|
)
|
|
|
|
# Commit TMA store
|
|
tma_y_pipeline.producer_commit()
|
|
# Wait for TMA store
|
|
tma_y_pipeline.producer_acquire()
|
|
# Sync before smem store
|
|
cute.arch.barrier(
|
|
barrier_id=self.epilog_sync_bar_id,
|
|
number_of_threads=len(self.epilog_warp_id) * 32,
|
|
)
|
|
|
|
# Advance deltas/intra2_acc/inter2_acc consumer states
|
|
deltas_consumer_state.advance()
|
|
intra2_acc_consumer_state.advance()
|
|
inter2_acc_consumer_state.advance()
|
|
|
|
# Peek (try_wait) Delta/INTRA2_ACC/INTER2_ACC buffer full for chunk_idx = chunk_idx + 1
|
|
peek_deltas_full_status = self.conditional_consumer_try_wait(
|
|
deltas_consumer_state, deltas_pipeline, C
|
|
)
|
|
peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait(
|
|
intra2_acc_consumer_state, intra2_acc_pipeline, C
|
|
)
|
|
peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait(
|
|
inter2_acc_consumer_state, inter2_acc_pipeline, C
|
|
)
|
|
|
|
if cutlass.const_expr(self.has_d):
|
|
# Advance x consumer states
|
|
x_consumer_state.advance()
|
|
# Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1
|
|
peek_rd_x_full_status = self.conditional_consumer_try_wait(
|
|
x_consumer_state, x_pipeline, C
|
|
)
|
|
|
|
if cutlass.const_expr(self.d_has_hdim):
|
|
d_pipeline.consumer_release(d_consumer_state)
|
|
d_consumer_state.advance()
|
|
|
|
# Advance to next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
# Producer tail for TMA store Y
|
|
tma_y_pipeline.producer_tail()
|
|
|
|
# Dealloc tmem buffer
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
cute.arch.barrier(
|
|
barrier_id=self.tmem_dealloc_sync_bar_id,
|
|
number_of_threads=self.threads_per_cta,
|
|
)
|
|
cute.arch.dealloc_tmem(
|
|
tmem_ptr_base,
|
|
self.num_tmem_cols_total,
|
|
is_two_cta=self.use_2cta_instrs,
|
|
)
|
|
else:
|
|
cute.arch.barrier_arrive(
|
|
barrier_id=self.tmem_dealloc_sync_bar_id,
|
|
number_of_threads=self.threads_per_cta,
|
|
)
|
|
|
|
return
|
|
|
|
@staticmethod
def _compute_stages(smem_capacity):
    """Return the pipeline stage counts for this kernel.

    Returns a tuple ``(input_stages, output_stages, internal_stages,
    intra1_acc_stages)``. The counts are currently fixed; ``smem_capacity``
    is accepted for interface compatibility (capacity-aware tuning could
    use it later).
    """
    input_stages = 2
    output_stages = 2
    internal_stages = 1
    intra1_acc_stages = 2
    return input_stages, output_stages, internal_stages, intra1_acc_stages
|
|
|
|
@staticmethod
def _compute_grid(y, b, max_active_clusters):
    """Build tile-scheduler parameters and the launch grid shape.

    One block is scheduled per (batch, expanded-head) pair; the scheduler
    maps expanded heads back to groups via NGROUP_RATIO.

    Args:
        y: output tensor; mode 4 is batch (B), mode 3 is expanded heads (EH).
        b: B input tensor; mode 3 is the group count (G).
        max_active_clusters: occupancy cap forwarded to the scheduler.

    Returns:
        (tile_sched_params, grid) for kernel launch.
    """
    B = cute.size(y, mode=[4])
    EH = cute.size(y, mode=[3])
    G = cute.size(b, mode=[3])
    # How many expanded heads share one B/C group (EH is assumed to be a
    # multiple of G — TODO confirm with callers).
    NGROUP_RATIO = EH // G
    num_blocks = B * EH

    tile_sched_params = Mamba2SSDTileSchedulerParams(num_blocks, EH, NGROUP_RATIO)
    grid = Mamba2SSDTileScheduler.get_grid_shape(
        tile_sched_params, max_active_clusters
    )
    return tile_sched_params, grid
|
|
|
|
@staticmethod
def _plan_tmem_offsets(
    tiled_mma_intra1,
    tile_shape_mnk_intra1,
    tiled_mma_intra2,
    tile_shape_mnk_intra2,
    tiled_mma_inter1,
    tile_shape_mnk_inter1,
    tiled_mma_inter2,
    tile_shape_mnk_inter2,
    acc_stages,
    intra2_a_tmem_layout,
    a_dtype,
    internal_stages,
    intra1_acc_stages,
):
    """Lay out the five tensor-memory regions used by this kernel.

    Computes column offsets (in TMEM columns) for, in order:
    INTRA1_ACC, INTRA2 A-operand (Q), INTRA2_ACC, INTER1_ACC, INTER2_ACC,
    plus the total column count rounded up to a power of two (the
    allocation granularity). Each region's size is derived from a "fake"
    fragment so it matches what the MMAs will actually address; the
    asserts cross-check those sizes against the tile shapes.

    Returns:
        (intra1_acc_offset, intra2_q_offset, intra2_acc_offset,
         inter1_acc_offset, inter2_acc_offset, num_tmem_cols_total)
    """
    # SM100 TMEM: 512 columns of 32 bits each per CTA.
    SM100_TMEM_CAPACITY_COLUMNS = 512
    BITS_PER_TMEM_COL = 32
    # (MMA, MMA_M, MMA_N)
    acc_shape_intra1 = tiled_mma_intra1.partition_shape_C(tile_shape_mnk_intra1[:2])
    # (MMA, MMA_M, MMA_N)
    tCtAccIntra1_fake = tiled_mma_intra1.make_fragment_C(
        cute.append(acc_shape_intra1, intra1_acc_stages)
    )
    num_intra1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra1_fake)
    # Accumulator is f32: one column per N element, per stage.
    assert tile_shape_mnk_intra1[1] * intra1_acc_stages == num_intra1_acc_cols
    # (MMA, MMA_N, MMA_K, STAGE)
    tCrQ_fake = tiled_mma_intra2.make_fragment_A(intra2_a_tmem_layout.outer.shape)
    num_intra2_a_cols = tcgen05.find_tmem_tensor_col_offset(tCrQ_fake)
    # A operand is a_dtype-wide, so scale K columns by element width.
    assert (
        tile_shape_mnk_intra2[2]
        * internal_stages
        * a_dtype.width
        // BITS_PER_TMEM_COL
        == num_intra2_a_cols
    )
    # (MMA, MMA_M, MMA_N)
    acc_shape_intra2 = tiled_mma_intra2.partition_shape_C(tile_shape_mnk_intra2[:2])
    # (MMA, MMA_M, MMA_N)
    tCtAccIntra2_fake = tiled_mma_intra2.make_fragment_C(
        cute.append(acc_shape_intra2, acc_stages)
    )
    num_intra2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra2_fake)
    assert tile_shape_mnk_intra2[1] * acc_stages == num_intra2_acc_cols

    # (MMA, MMA_M, MMA_N)
    acc_shape_inter1 = tiled_mma_inter1.partition_shape_C(tile_shape_mnk_inter1[:2])
    # (MMA, MMA_M, MMA_N)
    tCtAccInter1_fake = tiled_mma_inter1.make_fragment_C(
        cute.append(acc_shape_inter1, acc_stages)
    )
    num_inter1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter1_fake)
    assert tile_shape_mnk_inter1[1] * acc_stages == num_inter1_acc_cols

    # (MMA, MMA_M, MMA_N)
    acc_shape_inter2 = tiled_mma_inter2.partition_shape_C(tile_shape_mnk_inter2[:2])
    # (MMA, MMA_M, MMA_N)
    tCtAccInter2_fake = tiled_mma_inter2.make_fragment_C(
        cute.append(acc_shape_inter2, acc_stages)
    )
    num_inter2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter2_fake)
    assert tile_shape_mnk_inter2[1] * acc_stages == num_inter2_acc_cols

    # Regions are packed back-to-back in the order below.
    tmem_intra1_acc_offset = 0
    tmem_intra2_q_offset = tmem_intra1_acc_offset + num_intra1_acc_cols
    tmem_intra2_acc_offset = tmem_intra2_q_offset + num_intra2_a_cols
    tmem_inter1_acc_offset = tmem_intra2_acc_offset + num_intra2_acc_cols
    tmem_inter2_acc_offset = tmem_inter1_acc_offset + num_inter1_acc_cols
    num_tmem_cols_total_tmp = tmem_inter2_acc_offset + num_inter2_acc_cols
    # Turn num_tmem_cols_total to the nearest power of 2
    num_tmem_cols_total = 1
    while num_tmem_cols_total < num_tmem_cols_total_tmp:
        num_tmem_cols_total *= 2
    assert num_tmem_cols_total <= SM100_TMEM_CAPACITY_COLUMNS

    return (
        tmem_intra1_acc_offset,
        tmem_intra2_q_offset,
        tmem_intra2_acc_offset,
        tmem_inter1_acc_offset,
        tmem_inter2_acc_offset,
        num_tmem_cols_total,
    )
|
|
|
|
@staticmethod
def make_tiled_mmas(
    io_dtype,
    acc_dtype,
    cta_group,
    tile_shape_mnk_intra1,
    tile_shape_mnk_intra2,
    tile_shape_mnk_inter1,
    tile_shape_mnk_inter2,
):
    """Construct the four TiledMMAs used by the SSD kernel.

    The four stages differ only in operand major modes, MN tile shape,
    and where the A operand lives (SMEM vs TMEM):
      * intra1: A/B both MN-major, A from SMEM
      * intra2: A/B both K-major, A from TMEM
      * inter1: A/B both K-major, A from SMEM
      * inter2: A MN-major, B K-major, A from SMEM

    Returns:
        (tiled_mma_intra1, tiled_mma_intra2, tiled_mma_inter1,
         tiled_mma_inter2)
    """

    def build(a_major, b_major, tile_mn, a_source):
        # One-line helper: all four MMAs share dtype/acc/cta-group config.
        return sm100_utils.make_trivial_tiled_mma(
            io_dtype,
            tcgen05.OperandMajorMode(a_major),
            tcgen05.OperandMajorMode(b_major),
            acc_dtype,
            cta_group,
            tile_mn,
            a_source,
        )

    tiled_mma_intra1 = build(
        "mn", "mn", tile_shape_mnk_intra1[:2], tcgen05.OperandSource.SMEM
    )
    tiled_mma_intra2 = build(
        "k", "k", tile_shape_mnk_intra2[:2], tcgen05.OperandSource.TMEM
    )
    tiled_mma_inter1 = build(
        "k", "k", tile_shape_mnk_inter1[:2], tcgen05.OperandSource.SMEM
    )
    tiled_mma_inter2 = build(
        "mn", "k", tile_shape_mnk_inter2[:2], tcgen05.OperandSource.SMEM
    )
    return tiled_mma_intra1, tiled_mma_intra2, tiled_mma_inter1, tiled_mma_inter2
|
|
|
|
def make_and_init_x_pipeline(self, x_full_mbar_ptr):
    """Create the pipeline guarding the X input staging buffers.

    Producer is the delta/X/D TMA warp. The two MMA warps always consume
    X via UMMA; when D is fused (``has_d``) the epilogue warps consume it
    as well, which requires the multi-consumer pipeline variant.
    """
    producer = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id])
    )
    umma_consumers = pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        len([self.mma_intra_warp_id, self.mma_inter_warp_id]),
    )
    if self.has_d:
        # Extra async consumer group: the epilogue warps (32 threads each).
        return pipeline.PipelineTmaMultiConsumersAsync.create(
            num_stages=self.input_stages,
            producer_group=producer,
            consumer_group_umma=umma_consumers,
            consumer_group_async=pipeline.CooperativeGroup(
                pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128
            ),
            tx_count=self.num_x_load_bytes,
            barrier_storage=x_full_mbar_ptr,
        )
    return pipeline.PipelineTmaUmma.create(
        num_stages=self.input_stages,
        producer_group=producer,
        consumer_group=umma_consumers,
        tx_count=self.num_x_load_bytes,
        barrier_storage=x_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_b_pipeline(self, b_full_mbar_ptr):
    """Create the TMA -> {UMMA, async} pipeline for the B operand.

    Producer: the B/C TMA load warp. UMMA consumer: the intra MMA warp.
    Async consumer: the pre-inter warps (32 threads per warp).
    """
    return pipeline.PipelineTmaMultiConsumersAsync.create(
        num_stages=self.input_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.tma_b_c_warp_id])
        ),
        consumer_group_umma=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_intra_warp_id])
        ),
        consumer_group_async=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128
        ),
        tx_count=self.num_b_load_bytes,
        barrier_storage=b_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_c_pipeline(self, c_full_mbar_ptr):
    """Create the TMA -> UMMA pipeline for the C operand.

    Producer: the B/C TMA load warp. Consumers: both MMA warps.
    """
    return pipeline.PipelineTmaUmma.create(
        num_stages=self.input_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.tma_b_c_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            len([self.mma_intra_warp_id, self.mma_inter_warp_id]),
        ),
        tx_count=self.num_c_load_bytes,
        barrier_storage=c_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_deltas_pipeline(self, deltas_full_mbar_ptr):
    """Create the TMA -> async pipeline for delta and cumsum-delta.

    One TMA transaction covers both tensors, so ``tx_count`` is the sum
    of their load sizes. Consumers are the pre-inter, pre-intra, and
    epilogue warps.
    """
    consumer_warps = [
        *self.pre_inter_warp_id,
        *self.pre_intra_warp_id,
        *self.epilog_warp_id,
    ]
    return pipeline.PipelineTmaAsync.create(
        num_stages=self.input_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len(consumer_warps), len(consumer_warps)
        ),
        tx_count=self.num_delta_load_bytes + self.num_cumsum_delta_load_bytes,
        barrier_storage=deltas_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_d_pipeline(self, d_full_mbar_ptr):
    """Create the TMA -> async pipeline for the D vector, if needed.

    Returns ``None`` when D has no head dimension (scalar D needs no
    staged loads). Otherwise the epilogue warps are the consumers.
    """
    # Guard clause: no pipeline required for scalar D.
    if not self.d_has_hdim:
        return None

    epilog_count = len(self.epilog_warp_id)
    return pipeline.PipelineTmaAsync.create(
        num_stages=self.input_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            epilog_count,
            epilog_count,
        ),
        tx_count=self.num_d_load_bytes,
        barrier_storage=d_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_intra1_acc_pipeline(self, intra1_acc_full_mbar_ptr):
    """Create the UMMA -> async pipeline for the INTRA1 accumulator.

    Producer: the intra MMA warp. Consumers: the pre-intra warps
    (32 threads per warp).
    """
    return pipeline.PipelineUmmaAsync.create(
        num_stages=self.intra1_acc_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_intra_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128
        ),
        barrier_storage=intra1_acc_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_intra2_q_pipeline(self, intra2_q_full_mbar_ptr):
    """Create the async -> UMMA pipeline for the INTRA2 Q operand.

    Producer: the pre-intra warps (which write Q into tensor memory).
    Consumer: the intra MMA warp.
    """
    return pipeline.PipelineAsyncUmma.create(
        num_stages=self.internal_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_intra_warp_id])
        ),
        barrier_storage=intra2_q_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_intra2_acc_pipeline(self, intra2_acc_full_mbar_ptr):
    """Create the UMMA -> async pipeline for the INTRA2 accumulator.

    Producer: the intra MMA warp. Consumers: the epilogue warps
    (32 threads per warp).
    """
    return pipeline.PipelineUmmaAsync.create(
        num_stages=self.internal_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_intra_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128
        ),
        barrier_storage=intra2_acc_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_inter1_b_pipeline(self, inter1_b_full_mbar_ptr):
    """Create the async -> UMMA pipeline for the INTER1 B operand.

    Producer: the pre-inter warps. Consumer: the inter MMA warp.
    """
    return pipeline.PipelineAsyncUmma.create(
        num_stages=self.internal_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_inter_warp_id])
        ),
        barrier_storage=inter1_b_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_inter1_acc_pipeline(self, inter1_acc_full_mbar_ptr):
    """Create the UMMA -> async pipeline for the INTER1 accumulator.

    Producer: the inter MMA warp. Consumers: the pre-inter warps
    (32 threads per warp).
    """
    return pipeline.PipelineUmmaAsync.create(
        num_stages=self.internal_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_inter_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128
        ),
        barrier_storage=inter1_acc_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_inter2_p_pipeline(self, inter2_p_full_mbar_ptr):
    """Create the async -> UMMA pipeline for the INTER2 P operand.

    Producer: the pre-inter warps. Consumer: the inter MMA warp.
    """
    return pipeline.PipelineAsyncUmma.create(
        num_stages=self.internal_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_inter_warp_id])
        ),
        barrier_storage=inter2_p_full_mbar_ptr,
    )
|
|
|
|
def make_and_init_inter2_acc_pipeline(self, inter2_acc_full_mbar_ptr):
    """Create the UMMA -> async pipeline for the INTER2 accumulator.

    Producer: the inter MMA warp. Consumers: the epilogue warps
    (32 threads per warp).
    """
    return pipeline.PipelineUmmaAsync.create(
        num_stages=self.internal_stages,
        producer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, len([self.mma_inter_warp_id])
        ),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128
        ),
        barrier_storage=inter2_acc_full_mbar_ptr,
    )
|
|
|
|
def tma_partition_for_mma_b_operand(
    self,
    tma_atom_x,
    tma_tensor_x,
    smem_x,
    tiled_mma_intra2,
    cluster_layout_vmnk,
    mma_tile_coord_v,
    block_in_cluster_coord_vmnk,
):
    """Partition global/shared X tensors for TMA load as the MMA B operand.

    Tiles the global X tensor by the intra2 (N, K) tile, partitions it
    with the intra2 TiledMMA as operand B, then builds the TMA
    src/dst pairing across the cluster's N dimension.

    Returns:
        (tXsX, tXgX_pre_slice): the SMEM-side partition
        ((ATOM_V, REST_V), INPUT_STAGE) and the not-yet-sliced
        global-side partition ((ATOM_V, REST_V), 1, 1, C, EH, B).
    """
    # Local_tile partition global tensors
    # (D, L, 1, 1, C, EH, B)
    gX = cute.local_tile(
        tma_tensor_x,
        self.tile_shape_mnk_intra2[1:],
        (None, None, None, None, None),
    )
    # Partition global tensor with regard to TiledMMA
    thr_mma_intra2 = tiled_mma_intra2.get_slice(mma_tile_coord_v)
    # (MMA, MMA_N, MMA_K, 1, 1, C, EH, B)
    tCgX = thr_mma_intra2.partition_B(gX)

    # Partition global/shared tensor for X
    # CTA layout over the cluster's N mode (B operand is multicast along N).
    x_cta_layout = cute.make_layout(
        cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
    )
    # ((ATOM_V, REST_V), INPUT_STAGE)
    # ((ATOM_V, REST_V), 1, 1, C, EH, B)
    tXsX, tXgX_pre_slice = cpasync.tma_partition(
        tma_atom_x,
        block_in_cluster_coord_vmnk[2],
        x_cta_layout,
        cute.group_modes(smem_x, 0, 3),
        cute.group_modes(tCgX, 0, 3),
    )
    return tXsX, tXgX_pre_slice
|
|
|
|
def tma_partition_for_mma_a_operand(
    self,
    tma_atom_c,
    tma_tensor_c,
    smem_c,
    tiled_mma_intra1,
    cluster_layout_vmnk,
    mma_tile_coord_v,
    block_in_cluster_coord_vmnk,
):
    """Partition global/shared C tensors for TMA load as the MMA A operand.

    Tiles the global C tensor by the intra1 (M, K) tile (N mode diced
    out), partitions it with the intra1 TiledMMA as operand A, then
    builds the TMA src/dst pairing across the cluster's M dimension.

    Returns:
        (tCsC, tCgC_pre_slice): the SMEM-side partition
        ((ATOM_V, REST_V), INPUT_STAGE) and the not-yet-sliced
        global-side partition ((ATOM_V, REST_V), 1, 1, C, G, B).
    """
    # Local_tile partition global tensors
    # (L, N, 1, 1, C, G, B)
    gC = cute.local_tile(
        tma_tensor_c,
        cute.slice_(self.tile_shape_mnk_intra1, (None, 0, None)),
        (None, None, None, None, None),
    )
    # Partition global tensor with regard to TiledMMA
    thr_mma_intra1 = tiled_mma_intra1.get_slice(mma_tile_coord_v)
    # (MMA, MMA_M/N, MMA_K, 1, 1, C, G, B)
    tCgC = thr_mma_intra1.partition_A(gC)

    # Partition global/shared tensor for TMA C
    # CTA layout over the cluster's M mode (A operand is multicast along M).
    c_cta_layout = cute.make_layout(
        cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
    )
    # ((ATOM_V, REST_V), INPUT_STAGE)
    # ((ATOM_V, REST_V), 1, 1, C, G, B)
    tCsC, tCgC_pre_slice = cpasync.tma_partition(
        tma_atom_c,
        block_in_cluster_coord_vmnk[1],
        c_cta_layout,
        cute.group_modes(smem_c, 0, 3),
        cute.group_modes(tCgC, 0, 3),
    )
    return tCsC, tCgC_pre_slice
|
|
|
|
def tma_partition_with_shape(
    self, tma_atom_delta, tma_tensor_delta, smem_delta, shape
):
    """Partition a global/shared tensor pair for a non-MMA TMA load.

    Rank-generic variant used for delta / cumsum-delta / D style inputs:
    the global tensor is tiled by ``shape`` with all tile coordinates
    left open, and no cluster multicast is involved (single-CTA layout).

    Returns:
        (tDeltasDelta, tDeltagDelta_pre_slice): the SMEM-side partition
        ((ATOM_V, REST_V), INPUT_STAGE) and the not-yet-sliced
        global-side partition ((ATOM_V, REST_V), 1, C, EH, B).
    """
    # Local_tile partition global tensors
    # (L, 1, C, EH, B)
    gDelta = cute.local_tile(
        tma_tensor_delta,
        shape,
        (None,) * cute.rank(tma_tensor_delta),
    )
    # Partition global/shared tensor for DELTA
    # ((ATOM_V, REST_V), INPUT_STAGE)
    # ((ATOM_V, REST_V), 1, C, EH, B)
    tDeltasDelta, tDeltagDelta_pre_slice = cpasync.tma_partition(
        tma_atom_delta,
        0,
        cute.make_layout(1),
        cute.group_modes(smem_delta, 0, cute.rank(shape)),
        cute.group_modes(gDelta, 0, cute.rank(shape)),
    )

    return tDeltasDelta, tDeltagDelta_pre_slice
|
|
|
|
def mma_partition_ss(
    self,
    tiled_mma,
    tile_shape_mnk,
    smem_a,
    smem_b,
    tmem_acc_ptr,
    acc_stages,
):
    """Build MMA fragments for an SMEM x SMEM MMA (both operands in shared memory).

    Returns:
        (tCrA, tCrB, tCtAcc): A/B smem fragments and the staged tmem accumulator.
    """
    # (MMA, MMA_M, MMA_K, INPUT_STAGE)
    tCrA = tiled_mma.make_fragment_A(smem_a)
    # (MMA, MMA_N, MMA_K, INPUT_STAGE)
    tCrB = tiled_mma.make_fragment_B(smem_b)
    # (MMA, MMA_M, MMA_N, ACC_STAGE)
    tCtAcc = self.mma_partition_c(
        tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages
    )
    return tCrA, tCrB, tCtAcc
|
|
|
|
def mma_partition_ts(
    self,
    tiled_mma,
    tile_shape_mnk,
    a_tmem_layout,
    smem_b,
    tmem_a_ptr,
    tmem_acc_ptr,
    acc_stages,
):
    """Build MMA fragments for a TMEM x SMEM MMA (A in tensor memory, B in smem).

    Returns:
        (tCrA, tCrB, tCtAcc): tmem A fragment, smem B fragment, and the staged
        tmem accumulator.
    """
    # (MMA, MMA_M, MMA_K, INTERNAL_STAGE)
    tCrA = self.mma_partition_a_tmem(tiled_mma, a_tmem_layout, tmem_a_ptr)
    # (MMA, MMA_N, MMA_K, INPUT_STAGE)
    tCrB = tiled_mma.make_fragment_B(smem_b)
    # (MMA, MMA_M, MMA_N, INTERNAL_STAGE)
    tCtAcc = self.mma_partition_c(
        tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages
    )
    return tCrA, tCrB, tCtAcc
|
|
|
|
def mma_partition_a_tmem(self, tiled_mma, a_tmem_layout, tmem_a_ptr):
    """Create the MMA A-operand tensor on top of a raw tensor-memory pointer.

    A throwaway fragment is built first solely to obtain the layout and
    element type that the tiled MMA expects for operand A; the real tensor
    then reuses that layout over the (recast) tmem pointer.
    """
    template = tiled_mma.make_fragment_A(a_tmem_layout.outer.shape)
    typed_ptr = cute.recast_ptr(tmem_a_ptr, dtype=template.element_type)
    return cute.make_tensor(typed_ptr, template.layout)
|
|
|
|
def mma_partition_c(self, tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages):
    """Create the staged accumulator tensor in tensor memory.

    The layout is taken from a template fragment shaped as the MMA's C
    partition shape with a trailing stage mode appended.
    """
    mn_shape = tiled_mma.partition_shape_C(tile_shape_mnk[:2])
    staged_shape = cute.append(mn_shape, acc_stages)
    template = tiled_mma.make_fragment_C(staged_shape)
    # (MMA, MMA_M, MMA_N, INTERNAL_STAGE)
    return cute.make_tensor(tmem_acc_ptr, template.layout)
|
|
|
|
@cute.jit
def exec_mma(
    self,
    tiled_mma,
    tCtAcc,
    tCrA,
    tCrB,
    acc_producer_state,
    a_consumer_state,
    b_consumer_state,
):
    """Issue the MMA across all k-phases of the current pipeline stage.

    The ACCUMULATE field is false on the first k-phase (accumulator is
    overwritten) and true on later k-phases (results accumulate).
    """
    for kphase_idx in cutlass.range(cute.size(tCrB, mode=[2]), unroll_full=True):
        # Accumulate on every k-phase except the first.
        tiled_mma.set(
            tcgen05.Field.ACCUMULATE,
            cutlass.Boolean(kphase_idx != 0),
        )
        cute.gemm(
            tiled_mma,
            tCtAcc[None, None, None, acc_producer_state.index],
            tCrA[None, None, kphase_idx, a_consumer_state.index],
            tCrB[None, None, kphase_idx, b_consumer_state.index],
            tCtAcc[None, None, None, acc_producer_state.index],
        )
    return tiled_mma
|
|
|
|
@cute.jit
def conditional_consumer_try_wait(self, b_consumer_state, b_pipeline, C):
    """Peek the pipeline's "full" barrier, but only while stages remain.

    Once the consumer has already processed C stages there is nothing left
    to wait for, so the status defaults to true.
    """
    status = cutlass.Boolean(1)
    if b_consumer_state.count < C:
        status = b_pipeline.consumer_try_wait(b_consumer_state)
    return status
|
|
|
|
@cute.jit
def conditional_producer_try_acquire(
    self, intra1_acc_producer_state, intra1_acc_pipeline, C
):
    """Peek the pipeline's "empty" barrier, but only while stages remain.

    Once the producer has already issued C stages there is nothing left to
    acquire, so the status defaults to true.
    """
    status = cutlass.Boolean(1)
    if intra1_acc_producer_state.count < C:
        status = intra1_acc_pipeline.producer_try_acquire(
            intra1_acc_producer_state
        )
    return status
|
|
|
|
def pre_intra_tmem_load_and_partition_q(self, tIntra1, local_tidx):
    """Build the tmem->register copy and partitions for the intra1 accumulator.

    A placeholder smem tensor of shape (L, L) (null pointer, never addressed)
    only provides the destination shape for sizing the register fragment.
    """
    copy_atom_t2r_intra1 = cute.make_copy_atom(
        tcgen05.Ld16x256bOp(tcgen05.Repetition(16), tcgen05.Pack.NONE),
        self.acc_dtype,
    )
    # (L, L)
    fake_sQ = cute.make_tensor(
        cute.make_ptr(self.io_dtype, 0, cute.AddressSpace.smem),
        cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)),
    )
    return self.make_tmem_load_and_partition(
        copy_atom_t2r_intra1, tIntra1, (None, None, 0), local_tidx, fake_sQ
    )
|
|
|
|
def pre_intra_make_delta(self, smem_delta, extend_on_row_or_col):
    """View the linear per-stage delta in smem as a broadcast LxL tensor.

    The L-element layout is extended with a stride-0 mode so the same data is
    read for every row (``extend_on_row_or_col == 0``) or every column
    (otherwise), without copying.
    """
    smem_iterator = smem_delta.iterator
    delta_linear_smem_layout = smem_delta.layout
    # extend L linear layout to LxL
    extend_layout = cute.make_layout(delta_linear_smem_layout.shape[0], stride=0)
    if extend_on_row_or_col == 0:
        # (L, L, INPUT_STAGE):(0, 1, L)
        sDelta = cute.make_tensor(
            smem_iterator,
            cute.prepend(
                delta_linear_smem_layout,
                extend_layout,
            ),
        )
    else:
        # (L, L, INPUT_STAGE):(1, 0, L)
        sDelta = cute.make_tensor(
            smem_iterator,
            cute.append(
                cute.append(
                    cute.get(delta_linear_smem_layout, mode=[0]),
                    extend_layout,
                ),
                cute.get(delta_linear_smem_layout, mode=[1]),
            ),
        )
    return sDelta
|
|
|
|
def pre_intra_tmem_store_and_partition_q(self, local_tidx, tCrQ):
    """Build the register->tmem copy and partitions for storing the Q tile.

    Returns:
        (tiled_r2t_q, tRT_rQ, tRT_tQ): the tiled copy, a single-stage register
        fragment, and the staged tmem destination view.
    """
    dtype = tCrQ.element_type
    # Make tiledCopy for tensor memory store INTRA2_Q
    copy_atom_r2t_q = cute.make_copy_atom(
        tcgen05.St16x128bOp(tcgen05.Repetition(16), tcgen05.Unpack.NONE),
        dtype,
    )
    tiled_r2t_q = tcgen05.make_tmem_copy(copy_atom_r2t_q, tCrQ)
    thr_r2t_q = tiled_r2t_q.get_slice(local_tidx)

    # Partition tmem/register tensor for tensor memory store INTRA2_Q
    # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...)
    tRT_rQ = cute.make_fragment(
        cute.slice_(thr_r2t_q.partition_S(tCrQ).shape, (None, None, None, None, 0)),
        dtype,
    )
    # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE)
    tRT_tQ = thr_r2t_q.partition_D(tCrQ)

    return tiled_r2t_q, tRT_rQ, tRT_tQ
|
|
|
|
@cute.jit
def pre_intra_segsum(
    self, tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ
):
    """Apply the segmented-sum (segsum) scaling to the intra-chunk tile.

    Per element at logical coordinate (m, n):
        out = exp(deltaA_col - deltaA_row) * delta * q   if m >= n
        out = 0 (via exp(-inf))                          if m <  n (causal mask)
    computed two lanes at a time with packed f32x2 ops in acc_dtype, then
    narrowed back into `tRT_rQ` as io_dtype. Assumes an even fragment size
    (the loops step by 2).
    """
    # Make tmp acc type fragments
    tCrDeltaA_Row = cute.make_fragment(tQrDeltaA_Row.shape, self.acc_dtype)
    tCrDeltaA_Col = cute.make_fragment(tQrDeltaA_Col.shape, self.acc_dtype)
    tCrDelta = cute.make_fragment(tQrDelta.shape, self.acc_dtype)
    tCompute = cute.make_fragment(tRT_rQ.shape, self.acc_dtype)

    # Combine tTR_rQ/tCrDeltaA_Row/tCrDeltaA_Col/tCrDelta
    # Widen all inputs to acc_dtype before the math.
    tCrDeltaA_Row.store(tQrDeltaA_Row.load().to(self.acc_dtype))
    tCrDeltaA_Col.store(tQrDeltaA_Col.load().to(self.acc_dtype))
    tCrDelta.store(tQrDelta.load().to(self.acc_dtype))

    # SegSum
    # fadd2 + fsel + fmul2/mufu + fmul2
    for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
        (
            tCompute[subtile_idx],
            tCompute[subtile_idx + 1],
        ) = cute.arch.add_packed_f32x2(
            (tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]),
            (-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]),
        )
    for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True):
        m, n = tCoord[subtile_idx]
        # Causal mask: the strictly-upper triangle becomes 0 via exp(-inf).
        if m < n:
            tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
    for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
        # TODO: use math.exp directly
        (
            tCompute[subtile_idx],
            tCompute[subtile_idx + 1],
        ) = cute.arch.mul_packed_f32x2(
            cute.arch.exp_packed_f32x2(
                (tCompute[subtile_idx], tCompute[subtile_idx + 1])
            ),
            (tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]),
        )
        (
            tCompute[subtile_idx],
            tCompute[subtile_idx + 1],
        ) = cute.arch.mul_packed_f32x2(
            (tCompute[subtile_idx], tCompute[subtile_idx + 1]),
            (tTR_rQ[subtile_idx], tTR_rQ[subtile_idx + 1]),
        )

    tRT_rQ.store(tCompute.load().to(self.io_dtype))
    return tRT_rQ
|
|
|
|
def pre_inter_smem_load_and_partition_b(self, local_tidx, smem_bt):
    """Build the smem->register tiled copy and partitions for the Bt tile.

    Uses 128-bit universal copies; 128 threads are arranged so that each row
    of threads covers the K extent of the inter1 tile.

    Returns:
        (tiled_s2r_b, tBsB_s2r, tBrB_s2r).
    """
    dtype = smem_bt.element_type
    copy_atom_s2r_b = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        dtype,
        num_bits_per_copy=128,
    )
    # Each thread moves one 128-bit vector per copy.
    num_elements_per_thread = 128 // dtype.width
    num_threads_per_row = self.tile_shape_mnk_inter1[2] // num_elements_per_thread
    num_threads_per_col = 128 // num_threads_per_row
    thread_layout = cute.make_layout(
        (num_threads_per_col, num_threads_per_row),
        stride=(num_threads_per_row, 1),
    )
    val_layout = cute.make_layout((1, num_elements_per_thread))
    tiled_s2r_b = cute.make_tiled_copy_tv(
        copy_atom_s2r_b,
        thread_layout,
        val_layout,
    )
    thr_s2r_b = tiled_s2r_b.get_slice(local_tidx)

    # Partition shared tensor for smem load Bt
    # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE)
    tBsB_s2r = thr_s2r_b.partition_S(smem_bt)

    # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N)
    tBrB_s2r = cute.make_fragment(
        cute.slice_(tBsB_s2r.shape, (None, None, None, 0)),
        dtype,
    )
    return tiled_s2r_b, tBsB_s2r, tBrB_s2r
|
|
|
|
def pre_inter_smem_store_and_partition_b(
    self, local_tidx, smem_bt_internal, tiled_s2r_b, tBrB_s2r
):
    """Build the register->smem tiled copy for writing the scaled Bt back.

    Derives the store's thread/value arrangement from the source arrangement
    of `tiled_s2r_b`, so registers loaded by the s2r copy can be stored
    without reshuffling across threads.

    Returns:
        (tiled_r2s_b, tBrB_r2s, tBsB_r2s).
    """
    dtype = smem_bt_internal.element_type
    # Make tiledCopy from register to smem store Bt
    copy_atom_r2s_b = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        dtype,
        num_bits_per_copy=128,
    )
    tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b)
    thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)

    # Partition shared tensor for smem store Bt
    # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
    tBsB_r2s = thr_r2s_b.partition_D(smem_bt_internal)

    # Make register fragments for smem load/store Bt
    # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N)
    tBrB_r2s = thr_r2s_b.retile(tBrB_s2r)
    return tiled_r2s_b, tBrB_r2s, tBsB_r2s
|
|
|
|
def smem_load_and_partition_delta_d(
    self, tiled_s2r_b, local_tidx, smem_delta, smem_tile_coord
):
    """Partition the delta/D smem tensor to match the Bt copy's thread layout.

    Returns:
        (s2r_atom_delta, tBsDelta_s2r, tBrDelta_s2r): a scalar universal copy
        atom, the partitioned smem view, and a one-tile register fragment
        sized from ``smem_tile_coord``.
    """
    dtype = smem_delta.element_type
    s2r_atom_delta = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), dtype)

    thr_s2r_b = tiled_s2r_b.get_slice(local_tidx)
    # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE)
    tBsDelta_s2r = thr_s2r_b.partition_D(smem_delta)

    # Make register fragments for smem load/store of Delta/DeltaA
    # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N)
    tBrDelta_s2r = cute.make_fragment(tBsDelta_s2r[smem_tile_coord].shape, dtype)
    return s2r_atom_delta, tBsDelta_s2r, tBrDelta_s2r
|
|
|
|
def pre_inter_tmem_load_and_partition_p(self, local_tidx, tInter1, smem_pt):
    """Build the tmem->register copy and partitions for the inter1 accumulator (P).

    The first stage slice of `smem_pt` only provides the destination shape
    used to size the register fragment.
    """
    copy_atom_t2r_inter1 = cute.make_copy_atom(
        tcgen05.Ld16x256bOp(tcgen05.Repetition(8), tcgen05.Pack.NONE),
        self.acc_dtype,
    )
    return self.make_tmem_load_and_partition(
        copy_atom_t2r_inter1,
        tInter1,
        (None, None, 0),
        local_tidx,
        smem_pt[None, None, 0],
    )
|
|
|
|
def make_tmem_load_and_partition(
    self, copy_atom_t2r, tmem_tensor, tmem_tile_coord, local_tidx, smem_tensor
):
    """Build a tmem->register tiled copy and matching source/fragment partitions.

    The tiled copy is shaped from a single tile of `tmem_tensor` (selected by
    `tmem_tile_coord`); `smem_tensor` only supplies the destination shape used
    to size the register fragment.

    Returns:
        (tiled_t2r, tTR_t, tTR_r).
    """
    dtype = tmem_tensor.element_type
    tiled_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tmem_tensor[tmem_tile_coord])
    thr_t2r = tiled_t2r.get_slice(local_tidx)
    # Partition tmem/shared tensor for tmem load INTER1_ACC
    # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
    tTR_t = thr_t2r.partition_S(tmem_tensor)
    tTR_s = thr_t2r.partition_D(smem_tensor)
    # Make register fragments for tmem load INTER1_ACC
    # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N)
    tTR_r = cute.make_fragment(
        tTR_s.shape,
        dtype,
    )
    return tiled_t2r, tTR_t, tTR_r
|
|
|
|
def smem_store_and_partition_p_y(self, local_tidx, smem_pt, tiled_t2r_inter1):
    """Build the register->smem (stmatrix) copy and partitions for P/Y tiles.

    The destination arrangement of `tiled_t2r_inter1` is reused so registers
    produced by the tmem load can be stored without reshuffling.

    Returns:
        (tiled_r2s_p, tRS_rP, tRS_sP).
    """
    dtype = smem_pt.element_type
    copy_atom_r2s_p = cute.make_copy_atom(
        cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=True, num_matrices=4),
        dtype,
    )
    tiled_r2s_p = cute.make_tiled_copy_D(copy_atom_r2s_p, tiled_t2r_inter1)
    thr_r2s_p = tiled_r2s_p.get_slice(local_tidx)
    # Partition smem/register tensor for smem store INTER2_P
    # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
    tRS_sP = thr_r2s_p.partition_D(smem_pt)
    # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
    tRS_rP = cute.make_fragment(
        cute.slice_(tRS_sP.shape, (None, None, None, 0)), self.io_dtype
    )
    return tiled_r2s_p, tRS_rP, tRS_sP
|
|
|
|
def pre_inter_make_delta(self, smem_delta, smem_bt_layout):
    """View the per-stage delta as a stride-0 broadcast over Bt's M dimension.

    Produces a zero-copy view whose shape matches `smem_bt_layout` but whose
    M stride is 0, so the same delta values are read for every M row.
    """
    # Broadcast Delta/DeltaA to Bt shape on M dimension
    # before: (128,(64,2),2):(64,(1,8192),16384)
    # after : (128,(64,2),2):(0,(1,64),128)
    # (MMA, MMA_M, MMA_K, INPUT_STAGE)
    sDeltaA = cute.make_tensor(
        smem_delta.iterator,
        cute.make_layout(
            smem_bt_layout.shape,
            stride=(
                0,
                (1, cute.get(smem_bt_layout.shape, mode=[1, 0])),
                smem_delta.layout.stride[1],
            ),
        ),
    )
    return sDeltaA
|
|
|
|
def pre_inter_scale_bt_with_delta(
    self, tBrB_s2r, tBrDelta_s2r, tBrDeltaA_s2r, last_column
):
    """Scale the Bt fragment: out = exp(last_column - deltaA) * delta * B.

    All inputs are widened to acc_dtype first; the math runs two lanes at a
    time with packed f32x2 ops. Assumes an even fragment size (loop steps
    by 2). Returns the acc_dtype result fragment.
    """
    tCompute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype)
    tBrB_Compute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype)
    tBrDelta_Compute = cute.make_fragment(tBrDelta_s2r.shape, self.acc_dtype)
    tBrDeltaA_Compute = cute.make_fragment(tBrDeltaA_s2r.shape, self.acc_dtype)

    tBrB_Compute.store(tBrB_s2r.load().to(self.acc_dtype))
    tBrDelta_Compute.store(tBrDelta_s2r.load().to(self.acc_dtype))
    tBrDeltaA_Compute.store(tBrDeltaA_s2r.load().to(self.acc_dtype))

    for reg_idx in range(0, cute.size(tBrB_Compute), 2):
        # exp(last_column - deltaA) * delta, two lanes at once
        tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2(
            (
                cute.arch.exp(
                    (last_column - tBrDeltaA_Compute[reg_idx]).ir_value()
                ),
                cute.arch.exp(
                    (last_column - tBrDeltaA_Compute[reg_idx + 1]).ir_value()
                ),
            ),
            (tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]),
        )
        # ... * B
        tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2(
            (tCompute[reg_idx], tCompute[reg_idx + 1]),
            (tBrB_Compute[reg_idx], tBrB_Compute[reg_idx + 1]),
        )
    return tCompute
|
|
|
|
def epilog_make_delta(self, smem_cumsum_delta):
    """View the cumsum delta as an (L, D, INPUT_STAGE) stride-0 broadcast."""
    # Broadcast cumsum delta from LxINPUT_STAGE to LxDxINPUT_STAGE
    # (stride 0 on the D mode — no data is copied)
    sDeltaA = cute.make_tensor(
        smem_cumsum_delta.iterator,
        cute.make_layout(
            (*self.tile_shape_mnk_inter2[:2], self.input_stages),
            stride=(1, 0, smem_cumsum_delta.layout.shape[0]),
        ),
    )
    return sDeltaA
|
|
|
|
def epilog_make_d(self, smem_d):
    """View the D vector as an (L, D, INPUT_STAGE) stride-0 broadcast."""
    # Broadcast d from DxINPUT_STAGE to LxDxINPUT_STAGE
    # (stride 0 on the L mode — no data is copied)
    sD = cute.make_tensor(
        smem_d.iterator,
        cute.make_layout(
            (*self.tile_shape_mnk_inter2[:2], self.input_stages),
            stride=(0, 1, smem_d.layout.shape[0]),
        ),
    )
    return sD
|
|
|
|
def epilog_tma_partition_y(self, tma_tensor_y, tma_atom_y, smem_y, epi_tile):
    """Partition the output Y tensor for TMA stores over epilogue subtiles.

    Returns:
        (bSG_sY, bSG_gY_pre_slice): shared/global TMA views (no multicast).
    """
    # Local_tile partition global tensors
    # (L, D, 1, 1, C, EH, B)
    gY = cute.local_tile(
        tma_tensor_y,
        cute.slice_(self.tile_shape_mnk_inter2, (None, None, 0)),
        (None, None, None, None, None),
    )
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, 1, 1, C, EH, B)
    gY_epi = cute.flat_divide(gY, epi_tile)
    # ((ATOM_V, REST_V), INPUT_STAGE)
    # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B)
    bSG_sY, bSG_gY_pre_slice = cpasync.tma_partition(
        tma_atom_y,
        0,
        cute.make_layout(1),
        cute.group_modes(smem_y, 0, 2),
        cute.group_modes(gY_epi, 0, 2),
    )
    return bSG_sY, bSG_gY_pre_slice
|
|
|
|
def epilog_smem_load_and_partition_x(
    self, tiled_t2r_inter2_intra2, local_tidx, smem_xt, epi_tile
):
    """Build the smem->register (ldmatrix) copy and partitions for the X tile.

    The destination arrangement of `tiled_t2r_inter2_intra2` is reused so the
    loaded X registers line up with the accumulator registers.

    Returns:
        (tiled_s2r_x, tSR_sX, tSR_rX).
    """
    dtype = smem_xt.element_type
    copy_atom_s2r_x = cute.make_copy_atom(
        cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
        dtype,
    )
    tiled_s2r_x = cute.make_tiled_copy_D(copy_atom_s2r_x, tiled_t2r_inter2_intra2)
    thr_s2r_x = tiled_s2r_x.get_slice(local_tidx)
    # Partition smem/register tensor for smem store INTER2_P
    # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES)
    tSR_sX = thr_s2r_x.partition_S(cute.flat_divide(smem_xt, epi_tile))
    # (R2S_ATOM, R2S_M, R2S_N)
    tSR_rX = cute.make_fragment(
        cute.slice_(tSR_sX.shape, (None, None, None, 0, 0, 0)), dtype
    )
    return tiled_s2r_x, tSR_sX, tSR_rX
|
|
|
|
def epilog_tmem_load_and_partition_acc(self, local_tidx, tIntra, smem_y):
    """Build the tmem->register copy and partitions for the epilogue accumulator.

    The first stage slice of `smem_y` only provides the destination shape
    used to size the register fragment.
    """
    copy_atom_t2r_inter2_intra2 = cute.make_copy_atom(
        tcgen05.Ld16x256bOp(tcgen05.Repetition(4), tcgen05.Pack.NONE),
        self.acc_dtype,
    )
    return self.make_tmem_load_and_partition(
        copy_atom_t2r_inter2_intra2,
        tIntra,
        (None, None, 0, 0, 0),
        local_tidx,
        smem_y[None, None, 0],
    )
|
|
|
|
|
|
def run(
    gbehcdln: Tuple[int, int, int, int, int, int, int, int],
    io_dtype: Type[cutlass.Numeric],
    cumsum_delta_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    fuse_scale_d: str,
    tolerance: float,
    print_rtol_stats: bool,
    ref_lower_precision: bool,
    warmup_iterations: int,
    iterations: int,
    skip_ref_check: bool,
    use_cold_l2: bool = False,
    **kwargs,
):
    """Set up tensors, reference-check, and benchmark the fused Mamba2 SSD kernel.

    Args:
        gbehcdln: Problem sizes (G, B, E, H, C, D, L, N).
        io_dtype: Input/output element type.
        cumsum_delta_dtype: Element type of the precomputed cumsum(delta * A).
        acc_dtype: Accumulator element type.
        fuse_scale_d: "none", "scalar", or "vector" — selects the Y += X*D fusion.
        tolerance: Absolute tolerance used by the reference check.
        print_rtol_stats: Print relative-difference statistics before asserting.
        ref_lower_precision: Run the reference with low-precision intermediates.
        warmup_iterations: Warmup iterations for the benchmark.
        iterations: Timed iterations for the benchmark.
        skip_ref_check: Skip the correctness check entirely.
        use_cold_l2: Rotate workspaces so each timed iteration sees a cold L2.

    Returns:
        Average kernel execution time in microseconds.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    has_d = fuse_scale_d != "none"
    d_has_hdim = fuse_scale_d == "vector"

    print("Running B100 Mamba2 SSD with:")
    print(f"GBEHCDLN: {gbehcdln}")
    print(
        f"Input/Output dtype: {io_dtype}, Intermediate delta dtype: {cumsum_delta_dtype}, Acc dtype: {acc_dtype}"
    )
    print(
        f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}"
    )
    print(f"Tolerance: {tolerance}")
    print(f"Warmup iterations: {warmup_iterations}")
    print(f"Iterations: {iterations}")
    print(f"Skip reference checking: {skip_ref_check}")
    print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")

    # Unpack parameters
    G, B, E, H, C, D, L, N = gbehcdln
    EH = E * H

    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    # Match same seed in ssd_reference.py for reference check
    torch.manual_seed(42)

    # Create and permute tensor A/B/C
    def create_and_permute_tensor(
        shape, permute_order, dtype, dt_or_a=0, dynamic_modes=None, ref_tensor=None
    ):
        # Build fp32 reference torch tensor (or reuse the one supplied).
        if ref_tensor is None:
            ref_tensor = (
                torch.empty(*shape, dtype=torch.float32)
                .normal_(0, 0.5)
                .permute(permute_order)
            )
        if dt_or_a == 1:  # dt: make rates positive via softplus
            ref_tensor = F.softplus(ref_tensor - 4)
        elif dt_or_a == 2:  # A: make strictly negative via -exp
            ref_tensor = -torch.exp(ref_tensor)

        # Build torch_dtype torch tensor
        torch_dtype = cutlass_torch.dtype(dtype)

        dst_tensor = ref_tensor.to(torch_dtype).cuda()
        cute_tensor = from_dlpack(dst_tensor, assumed_align=16)
        # Fix: guard against dynamic_modes=None (the declared default) —
        # the previous unconditional loop raised TypeError when it was used.
        for mode in dynamic_modes or ():
            cute_tensor = cute_tensor.mark_compact_shape_dynamic(
                mode=mode, stride_order=dst_tensor.dim_order()
            )

        return ref_tensor, cute_tensor, dst_tensor

    # INPUT tensors
    # x: (D, L, C, EH, B):(C*L, 1, L, D*C*L, EH*D*C*L)
    x_ref, x_tensor, x_torch = create_and_permute_tensor(
        [B, EH, D, C, L], [2, 4, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4]
    )
    # delta/delta_a/cumsum_delta: (L, C, EH, B):(1, L, C*L, EH*C*L)
    delta_ref, delta_tensor, delta_torch = create_and_permute_tensor(
        [B, EH, C, L], [3, 2, 1, 0], io_dtype, dt_or_a=1, dynamic_modes=[1, 2, 3]
    )
    # a: (EH):(1)
    a_ref, a_tensor, a_torch = create_and_permute_tensor(
        [EH], [0], io_dtype, dt_or_a=2, dynamic_modes=[0]
    )

    if has_d:
        # d: (D, EH):(1, D) or (1, EH):(0, 1)
        d_ref, d_tensor, d_torch = create_and_permute_tensor(
            [EH, D if d_has_hdim else 1], [1, 0], io_dtype, dynamic_modes=[1]
        )
    else:
        d_ref = None
        d_tensor = None

    # b/c: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
    b_ref, b_tensor, b_torch = create_and_permute_tensor(
        [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4]
    )
    c_ref, c_tensor, c_torch = create_and_permute_tensor(
        [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4]
    )

    # OUTPUT tensors
    # y: (L, D, C, EH, B):(1, C*L, L, D*C*L, EH*D*C*L)
    y_ref, y_tensor, y_torch = create_and_permute_tensor(
        [B, EH, D, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4]
    )
    # fstate: (D, N, EH, B):(N, 1, D*N, EH*D*N)
    fstate_ref, fstate_tensor, fstate_torch = create_and_permute_tensor(
        [B, EH, D, N], [2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3]
    )

    # Call pytorch reference on cpu
    if not ref_lower_precision:
        ssd_reference_fp32_all(
            x_ref,
            a_ref,
            delta_ref,
            b_ref,
            c_ref,
            y_ref,
            fstate_ref,
            d_ref,
            has_d,
            d_has_hdim,
        )
    else:
        ssd_reference_lowprecision_intermediates(
            x_ref,
            a_ref,
            delta_ref,
            b_ref,
            c_ref,
            y_ref,
            fstate_ref,
            cutlass_torch.dtype(io_dtype),
            d_ref,
            has_d,
            d_has_hdim,
        )

    # Compute cumsum with pytorch on cpu
    delta_a_ref = delta_ref * a_ref.view(1, 1, -1, 1)
    cumsum_delta_ref = torch.empty([B, EH, C, L], dtype=torch.float32).permute(
        [3, 2, 1, 0]
    )
    cumsum_delta_ref.copy_(torch.cumsum(delta_a_ref, dim=0).permute([0, 1, 2, 3]))
    # Copy cumsum_delta_ref to cumsum_delta_tensor
    (
        cumsum_delta_ref,
        cumsum_delta_tensor,
        cumsum_delta_torch,
    ) = create_and_permute_tensor(
        [B, EH, C, L],
        [3, 2, 1, 0],
        cumsum_delta_dtype,
        ref_tensor=cumsum_delta_ref,
        dynamic_modes=[1, 2, 3],
    )

    # Call fused ssd kernel
    ssd = SSDKernel(
        io_dtype,
        cumsum_delta_dtype,
        acc_dtype,
        L,
        D,
        N,
        has_d,
        d_has_hdim,
    )

    # Compute max active clusters on current device
    hardware_info = cutlass.utils.HardwareInfo()
    max_active_clusters = hardware_info.get_max_active_clusters(1)

    stream = cutlass.cuda.default_stream()
    # Compile ssd kernel
    compiled_ssd = cute.compile(
        ssd,
        x_tensor,
        cumsum_delta_tensor,
        delta_tensor,
        b_tensor,
        c_tensor,
        y_tensor,
        fstate_tensor,
        d_tensor,
        max_active_clusters,
        stream,
    )

    # Launch compiled ssd kernel for reference check
    if not skip_ref_check:
        compiled_ssd(
            x_tensor,
            cumsum_delta_tensor,
            delta_tensor,
            b_tensor,
            c_tensor,
            y_tensor,
            fstate_tensor,
            d_tensor,
            stream,
        )

        # Reference check
        if print_rtol_stats:
            print("\nY's Relative diffs:")
            analyze_relative_diffs(
                y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))
            )
            print("\nFstate's Relative diffs:")
            analyze_relative_diffs(
                fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
            )
        torch.testing.assert_close(
            y_torch.cpu(),
            y_ref.to(cutlass_torch.dtype(io_dtype)),
            atol=tolerance,
            rtol=1e-02,
        )
        torch.testing.assert_close(
            fstate_torch.cpu(),
            fstate_ref.to(cutlass_torch.dtype(io_dtype)),
            atol=tolerance,
            rtol=1e-05,
        )

    def generate_tensors():
        # Reuse existing CPU reference tensors and create new GPU tensors from them
        _, x_tensor_new, _ = create_and_permute_tensor(
            [B, EH, D, C, L],
            [2, 4, 3, 1, 0],
            io_dtype,
            ref_tensor=x_ref,
            dynamic_modes=[2, 3, 4],
        )
        _, cumsum_delta_tensor_new, _ = create_and_permute_tensor(
            [B, EH, C, L],
            [3, 2, 1, 0],
            cumsum_delta_dtype,
            ref_tensor=cumsum_delta_ref,
            dynamic_modes=[1, 2, 3],
        )
        _, delta_tensor_new, _ = create_and_permute_tensor(
            [B, EH, C, L],
            [3, 2, 1, 0],
            io_dtype,
            ref_tensor=delta_ref,
            dynamic_modes=[1, 2, 3],
        )
        _, b_tensor_new, _ = create_and_permute_tensor(
            [B, G, N, C, L],
            [4, 2, 3, 1, 0],
            io_dtype,
            ref_tensor=b_ref,
            dynamic_modes=[2, 3, 4],
        )
        _, c_tensor_new, _ = create_and_permute_tensor(
            [B, G, N, C, L],
            [4, 2, 3, 1, 0],
            io_dtype,
            ref_tensor=c_ref,
            dynamic_modes=[2, 3, 4],
        )
        _, y_tensor_new, _ = create_and_permute_tensor(
            [B, EH, D, C, L],
            [4, 2, 3, 1, 0],
            io_dtype,
            ref_tensor=y_ref,
            dynamic_modes=[2, 3, 4],
        )
        _, fstate_tensor_new, _ = create_and_permute_tensor(
            [B, EH, D, N],
            [2, 3, 1, 0],
            io_dtype,
            ref_tensor=fstate_ref,
            dynamic_modes=[2, 3],
        )

        if has_d:
            _, d_tensor_new, _ = create_and_permute_tensor(
                [EH, D if d_has_hdim else 1],
                [1, 0],
                io_dtype,
                ref_tensor=d_ref,
                dynamic_modes=[1],
            )
        else:
            d_tensor_new = d_tensor

        return testing.JitArguments(
            x_tensor_new,
            cumsum_delta_tensor_new,
            delta_tensor_new,
            b_tensor_new,
            c_tensor_new,
            y_tensor_new,
            fstate_tensor_new,
            d_tensor_new,
            stream,
        )

    workspace_count = 1
    if use_cold_l2:
        # Size one full argument set so enough rotating copies are allocated
        # to evict L2 between iterations.
        one_workspace_bytes = (
            x_torch.numel() * x_torch.element_size()
            + cumsum_delta_torch.numel() * cumsum_delta_torch.element_size()
            + delta_torch.numel() * delta_torch.element_size()
            + b_torch.numel() * b_torch.element_size()
            + c_torch.numel() * c_torch.element_size()
            + y_torch.numel() * y_torch.element_size()
            + fstate_torch.numel() * fstate_torch.element_size()
        )
        if has_d:
            one_workspace_bytes += d_torch.numel() * d_torch.element_size()

        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )

    exec_time = testing.benchmark(
        compiled_ssd,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    return exec_time  # Return execution time in microseconds
|
|
|
|
|
|
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> List[int]:
        """Parse a comma-separated int list, e.g. "1,2,3" -> [1, 2, 3]."""
        try:
            return [int(x.strip()) for x in s.split(",")]
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    # Fix: the description previously said "MxNxKxL GEMM", copied from a
    # different example; this file runs the fused Mamba2 SSD kernel.
    parser = argparse.ArgumentParser(
        description="Example of the fused Mamba2 SSD kernel on Blackwell."
    )

    parser.add_argument(
        "--gbehcdln",
        type=parse_comma_separated_ints,
        default=[2, 4, 2, 40, 32, 64, 128, 128],
        # default=[2, 3, 2, 2, 8, 64, 128, 128],
        # default=[1, 2, 1, 4, 8, 64, 128, 128],
        help="gbehcdln dimensions (comma-separated)",
    )
    parser.add_argument("--io_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument(
        "--cumsum_delta_dtype", type=cutlass.dtype, default=cutlass.Float32
    )
    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument(
        "--fuse_scale_d",
        type=str,
        choices=["none", "scalar", "vector"],
        default="vector",
        help="Fuse scale type: none (no Y+=X*D fusion), scalar (Y+=X*D fusion with D.shape=1xEH), or vector (Y+=X*D fusion with D.shape=DxEH)",
    )
    # Paired on/off flags share one dest; argparse applies the default of the
    # first-registered action, so only the enabling flag carries a default.
    # (The previous dead `default=` on the --no- flags was removed.)
    parser.add_argument(
        "--ref_lower_precision",
        action="store_true",
        default=True,
        help="Use lower precision for reference check",
    )
    parser.add_argument(
        "--no-ref_lower_precision",
        action="store_false",
        dest="ref_lower_precision",
        help="Disable lower precision for reference check",
    )
    parser.add_argument(
        "--tolerance", type=float, default=5e-02, help="Tolerance for validation"
    )
    parser.add_argument(
        "--print_rtol_stats",
        action="store_true",
        default=True,
        help="Enable print rtol stats",
    )
    parser.add_argument(
        "--no-print_rtol_stats",
        action="store_false",
        dest="print_rtol_stats",
        help="Disable print rtol stats",
    )
    parser.add_argument(
        "--warmup_iterations",
        type=int,
        default=0,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations",
    )
    parser.add_argument(
        "--skip_ref_check", action="store_true", help="Skip reference checking"
    )
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure L2 cold cache",
    )

    args = parser.parse_args()

    if len(args.gbehcdln) != 8:
        parser.error("--gbehcdln must contain exactly 8 values")

    run(
        args.gbehcdln,
        args.io_dtype,
        args.cumsum_delta_dtype,
        args.acc_dtype,
        args.fuse_scale_d,
        args.tolerance,
        args.print_rtol_stats,
        args.ref_lower_precision,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )
    print("PASS")
|