Files
cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py
2025-07-03 08:07:53 -04:00

200 lines
6.4 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Type, Tuple
from enum import Enum
from cutlass.utils.layout import LayoutEnum
from cutlass.cutlass_dsl import (
Float16,
BFloat16,
Float8E5M2,
Float8E4M3FN,
Numeric,
NumericMeta,
dsl_user_op,
)
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu.common import CopyUniversalOp
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp
from cutlass.cute.nvgpu.warpgroup import (
MmaF16BF16Op,
MmaF8Op,
OperandMajorMode,
OperandSource,
)
@dsl_user_op
def sm90_get_smem_store_op(
    layout_d: LayoutEnum,
    elem_ty_d: Type[Numeric],
    elem_ty_acc: Type[Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """Select the widest vectorized smem store atom allowed by the gmem layout of D.

    :param layout_d: Layout enum of the output tensor D.
    :type layout_d: LayoutEnum
    :param elem_ty_d: Element type of the output tensor D.
    :type elem_ty_d: Type[Numeric]
    :param elem_ty_acc: Element type of the accumulator.
    :type elem_ty_acc: Type[Numeric]
    :return: A copy atom built from ``StMatrix8x8x16bOp`` for 16-bit output
        types, otherwise from ``CopyUniversalOp``.
    :rtype: cute.CopyAtom
    :raises TypeError: If either element type is not a Numeric type.
    """
    for name, ty in (("elem_ty_d", elem_ty_d), ("elem_ty_acc", elem_ty_acc)):
        if not isinstance(ty, NumericMeta):
            raise TypeError(f"{name} must be a Numeric, but got {ty}")
    # 16-bit outputs can take the stmatrix path (x4 variant); everything else
    # falls back to a plain per-element copy.
    if elem_ty_d.width == 16:
        store_op = StMatrix8x8x16bOp(layout_d.is_m_major_c(), 4)
    else:
        store_op = CopyUniversalOp()
    return cute.make_copy_atom(store_op, elem_ty_d, loc=loc, ip=ip)
class SmemCapacity(Enum):
    """Shared-memory capacity per SM, keyed by architecture."""

    # (228 - 1) KiB — presumably 1 KiB is reserved by the driver/runtime;
    # NOTE(review): confirm against the Hopper tuning guide.
    SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024


# Maps a compute-capability tag (e.g. "sm90") to its SMEM capacity in bytes.
SMEM_CAPACITY = {
    "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
}
def make_trivial_tiled_mma(
    a_dtype: Type[Numeric],
    b_dtype: Type[Numeric],
    a_leading_mode: OperandMajorMode,
    b_leading_mode: OperandMajorMode,
    acc_dtype: Type[Numeric],
    atom_layout_mnk: Tuple[int, int, int],
    tiler_mn: Tuple[int, int],
    a_source: OperandSource = OperandSource.SMEM,
    *,
    loc=None,
    ip=None,
) -> cute.TiledMma:
    """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.

    By default, the MMA atom is created with SMEM operand source for A.

    :param a_dtype: Data type of operand A.
    :type a_dtype: type[Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[Numeric]
    :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
    :type a_leading_mode: warpgroup.OperandMajorMode
    :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
    :type b_leading_mode: warpgroup.OperandMajorMode
    :param acc_dtype: Data type of the accumulator.
    :type acc_dtype: type[Numeric]
    :param atom_layout_mnk: A integer tuple describing the tiling of Atom across threads.
    :type atom_layout_mnk: Tuple[int, int, int]
    :param tiler_mn: The shape (M, N) of the cta tiler.
    :type tiler_mn: Tuple[int, int]
    :return: A tiled MMA atom.
    :rtype: cute.TiledMma
    :raises TypeError: If the data type is not supported.
    """
    # F16/BF16 MMA requires A and B to use the same 16-bit type.
    if a_dtype in {Float16, BFloat16}:
        # NOTE(review): cutlass.const_expr presumably forces these checks to be
        # evaluated at trace/compile time rather than in generated IR — confirm.
        if cutlass.const_expr(a_dtype != b_dtype):
            raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}")
        # Width check is redundant when the types match, but kept as a guard.
        if cutlass.const_expr(a_dtype.width != b_dtype.width):
            raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}")
        mma_op = MmaF16BF16Op(
            a_dtype,
            acc_dtype,
            # MMA instruction shape: (tiler_m, tiler_n, 16) — K-extent 16 for 16-bit types.
            (*tiler_mn, 16),
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    # FP8 MMA allows mixed E4M3/E5M2 operand combinations.
    elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in {
        Float8E4M3FN,
        Float8E5M2,
    }:
        mma_op = MmaF8Op(
            a_dtype,
            b_dtype,
            acc_dtype,
            # K-extent doubles to 32 for 8-bit operand types.
            (*tiler_mn, 32),
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    else:
        raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}")
    # Tile the single MMA atom across threads per atom_layout_mnk.
    return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk)
def get_smem_layout_atom(
    layout: LayoutEnum,
    element_type: Type[Numeric],
    major_mode_size: int,
    *,
    loc=None,
    ip=None,
):
    """Select the optimal shared memory layout atom based on parameters.

    Picks the largest swizzle (128B > 64B > 32B > none) whose contiguous-bit
    requirement evenly divides the major-mode extent in bits.

    :param layout: Layout enum of the tensor
    :type layout: LayoutEnum
    :param element_type: Data type of the elements
    :type element_type: type[cutlass.Numeric]
    :param major_mode_size: Size of the major mode dimension
    :type major_mode_size: int
    :return: Selected shared memory layout atom kind
    :rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind
    """
    assert major_mode_size % 8 == 0
    kind = cute.nvgpu.warpgroup.SmemLayoutAtomKind
    major_mode_size_bits = major_mode_size * element_type.width
    # Candidate swizzles, widest first, with the matching fallback per major mode.
    if layout.sm90_mma_major_mode() == OperandMajorMode.MN:
        candidates = ((1024, kind.MN_SW128), (512, kind.MN_SW64), (256, kind.MN_SW32))
        fallback = kind.MN_INTER
    else:
        candidates = ((1024, kind.K_SW128), (512, kind.K_SW64), (256, kind.K_SW32))
        fallback = kind.K_INTER
    for num_contiguous_bits, atom_kind in candidates:
        if major_mode_size_bits % num_contiguous_bits == 0:
            return atom_kind
    return fallback