Rename python/cutlass to python/cutlass_cppgen (#2652)

This commit is contained in:
Jack Kosaian
2025-09-18 13:26:57 -05:00
committed by Haicheng Wu
parent 4260d4aef9
commit 177a82e251
71 changed files with 1 additions and 1 deletions

View File

@ -0,0 +1,213 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import logging
import os
import sys
import cutlass_library
def _cuda_install_path_from_nvcc() -> str:
    """
    Infers the CUDA installation path from the location of the NVCC executable.

    :return: path to the CUDA installation (the directory containing ``bin/nvcc``)
    :rtype: str
    :raises Exception: if nvcc cannot be found on PATH or the inferred path does not exist
    """
    import shutil

    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC.
    # shutil.which honors PATH and is portable, unlike shelling out to a
    # hard-coded /usr/bin/which binary.
    nvcc_path = shutil.which('nvcc')
    if nvcc_path is None:
        raise Exception('Unable to find nvcc in PATH.')

    # Strip the trailing "/bin/nvcc" to recover the installation root.
    cuda_install_path = nvcc_path.split('/bin/nvcc')[0]
    if not os.path.isdir(cuda_install_path):
        raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, '
                        f'and default path of {cuda_install_path} does not exist.')

    return cuda_install_path
# Root of the CUTLASS source tree. Overridable via the CUTLASS_PATH environment
# variable; defaults to the source tree bundled with cutlass_library.
CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path)

# Alias CUTLASS_PATH as source_path
source_path = CUTLASS_PATH
# Cached NVCC version string (populated lazily by nvcc_version()).
_NVCC_VERSION = None

def nvcc_version():
    """
    Returns the release version of NVCC on the system (e.g., "12.4"),
    caching the result after the first successful invocation.

    :return: NVCC release version string
    :rtype: str
    :raises Exception: if ``nvcc --version`` cannot be run
    """
    global _NVCC_VERSION
    if _NVCC_VERSION is None:
        import subprocess

        # Attempt to get NVCC version
        result = subprocess.run(['nvcc', '--version'], capture_output=True)
        if result.returncode != 0:
            # Error message fixed: it was previously missing the closing backtick.
            raise Exception('Unable to run `nvcc --version`')

        # Parse the release number out of output such as
        # "... release 12.4, V12.4.131" -> "12.4"
        _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0]
    return _NVCC_VERSION
# Cached CUDA installation path (populated lazily by cuda_install_path()).
_CUDA_INSTALL_PATH = None

def cuda_install_path():
    """
    Helper method for on-demand fetching of the CUDA installation path. This allows
    the import of CUTLASS to proceed even if NVCC is not available, preferring to
    raise this error only when an operation that needs NVCC is being performed.

    The "CUDA_INSTALL_PATH" environment variable takes precedence; NVCC-based
    detection runs only when it is unset. The result is cached after the first call.
    """
    global _CUDA_INSTALL_PATH
    if _CUDA_INSTALL_PATH is None:
        # Bug fix: os.getenv's second argument is evaluated eagerly, so the
        # previous os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc())
        # ran NVCC detection even when the environment variable was set. Check
        # the environment first and fall back to detection only when needed.
        _CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH")
        if _CUDA_INSTALL_PATH is None:
            _CUDA_INSTALL_PATH = _cuda_install_path_from_nvcc()
    return _CUDA_INSTALL_PATH
# File name used for the on-disk cache of compiled kernels.
CACHE_FILE = "compiled_cache.db"

from cutlass_library import (
    DataType,
    EpilogueScheduleType,
    KernelScheduleType,
    MathOperation,
    LayoutType,
    OpcodeClass,
    TileDescription,
    TileSchedulerType,
)

# Handle to this module object, used to store module-level mutable state
# (logger, RMM flag, lazily-built caches) in a single place.
this = sys.modules[__name__]
this.logger = logging.getLogger(__name__)

# RMM is only supported for Python 3.9+
if (sys.version_info.major == 3 and sys.version_info.minor > 8) or sys.version_info.major > 3:
    try:
        import rmm
        this.use_rmm = True
    except ImportError:
        # RMM is optional; fall back to manual CUDA memory management.
        this.use_rmm = False
else:
    this.use_rmm = False
def set_log_level(level: int):
    """
    Sets the log level

    :param level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options
    :type level: int
    """
    this.logger.setLevel(level)

# Only surface errors by default.
set_log_level(logging.ERROR)
from cutlass_cppgen.library_defaults import OptionRegistry
from cutlass_cppgen.backend.utils.device import device_cc

# Lazily-constructed registry of kernel options (see get_option_registry).
this._option_registry = None

def get_option_registry():
    """
    Helper method for on-demand initialization of the options registry. This avoids building
    the registry when CUTLASS is imported.
    """
    if this._option_registry is None:
        this.logger.info("Initializing option registry")
        # Build the registry for the compute capability of the current device.
        this._option_registry = OptionRegistry(device_cc())
    return this._option_registry

this.__version__ = '4.2.1'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.utils.lazy_import import lazy_import

# Lazily-constructed RMM memory pool (see get_memory_pool).
this.memory_pool = None

def get_memory_pool():
    """
    Helper method for on-demand memory pool. This avoids allocating the memory pool unnecessarily
    when CUTLASS is imported.
    """
    # Only allocate when RMM is available: 1 GiB initial size, 4 GiB maximum.
    if this.use_rmm and this.memory_pool is None:
        this.memory_pool = create_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32)
    return this.memory_pool
# Lazy imports of the CUDA Python bindings so that importing this module does
# not require a working CUDA installation.
base_cuda = lazy_import("cuda")
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")

# Cached CUDA device ID and NVCC version (populated by initialize_cuda_context
# and check_cuda_versions).
this._device_id = None
this._nvcc_version = None
def check_cuda_versions():
    """
    Checks that the version of the Python CUDA bindings is greater than or equal
    to the NVCC version.

    :raises Exception: if the Python CUDA version is older than NVCC's, or if
        the NVCC version string is malformed
    """
    # Strip any additional information from the CUDA version
    _cuda_version = base_cuda.__version__.split("rc")[0]

    # Check that Python CUDA version exceeds NVCC version
    this._nvcc_version = nvcc_version()

    _cuda_list = _cuda_version.split('.')
    _nvcc_list = this._nvcc_version.split('.')
    for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
        if int(val_cuda) > int(val_nvcc):
            # Bug fix: a more significant component already exceeds NVCC's, so
            # lower-order components are irrelevant. Without this break,
            # comparing CUDA 12.0 against NVCC 11.8 would incorrectly raise
            # on the second component (0 < 8).
            break
        if int(val_cuda) < int(val_nvcc):
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")

    # NVCC may report one extra (patch) component, e.g., 12.4.131 vs. 12.4.
    if len(_nvcc_list) > len(_cuda_list):
        if len(_nvcc_list) != len(_cuda_list) + 1:
            raise Exception(f"Malformatted NVCC version of {this._nvcc_version}")
        if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")
def initialize_cuda_context():
    """
    Lazily initializes the CUDA context and selects the CUDA device to use,
    caching the result in ``this._device_id``. Safe to call repeatedly.
    """
    check_cuda_versions()

    # Already initialized.
    if this._device_id is not None:
        return

    if this.use_rmm:
        # This also covers initializing the CUDA context
        get_memory_pool()

    # Allow the device to be selected via environment variable.
    device_id = os.getenv("CUTLASS_CUDA_DEVICE_ID")
    if device_id is None:
        if not this.use_rmm:
            # Manually call cuInit() and create context by making a runtime API call
            err, = cudart.cudaFree(0)
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError(f"cudaFree failed with error {err}")

        err, device_count = cuda.cuDeviceGetCount()
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise Exception(f"cuDeviceGetCount failed with error {err}")
        if device_count <= 0:
            raise Exception("No CUDA devices found")
        # Default to the first device when none is specified.
        device_id = 0

    this._device_id = int(device_id)
def device_id() -> int:
    """
    Returns the ID of the CUDA device in use, initializing the CUDA context
    first if needed.

    :return: active CUDA device ID
    :rtype: int
    """
    initialize_cuda_context()
    return this._device_id

View File

@ -0,0 +1,48 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.arguments import *
from cutlass_cppgen.backend.c_types import *
from cutlass_cppgen.backend.compiler import ArtifactManager
from cutlass_cppgen.backend.conv2d_operation import *
from cutlass_cppgen.backend.epilogue import *
from cutlass_cppgen.backend.frontend import *
from cutlass_cppgen.backend.gemm_operation import *
from cutlass_cppgen.backend.library import *
from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool
from cutlass_cppgen.backend.operation import *
from cutlass_cppgen.backend.reduction_operation import *
from cutlass_cppgen.backend.type_hint import *
from cutlass_cppgen.backend.utils import *
from cutlass_cppgen.backend.utils.device import device_cc

# Module-level singleton used to compile and cache generated kernel artifacts.
compiler = ArtifactManager()

View File

@ -0,0 +1,136 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from math import prod
from typing import Union
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
import cutlass_cppgen
from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
class ArgumentBase:
    """
    Base class for operation arguments.

    Converts each of the A/B/C/D operands into a ``cuda.CUdeviceptr``,
    allocating device buffers where needed and remembering output host
    tensors so results can be copied back in :meth:`sync`.
    """

    def __init__(
        self,
        A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        **kwargs,
    ) -> None:
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        self.bias = kwargs.get("bias", False)
        # Default to the null (legacy default) CUDA stream.
        self.stream = kwargs.get("stream", cuda.CUstream(0))

        # RMM buffers used to track tensor lifetime
        self.buffers = {}
        # Host tensor to copy the computed result back
        self.host_tensors = {}

        self.ptr_A = self.tensor_to_ptr(A, "A")
        self.ptr_B = self.tensor_to_ptr(B, "B")
        self.ptr_C = self.tensor_to_ptr(C, "C")
        self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True)

        # Record the number of elements in C (used elsewhere when C acts as a
        # bias); raw device pointers carry no shape information.
        if C is not None:
            if not isinstance(C, cuda.CUdeviceptr):
                self.tensor_c_numel = prod(C.shape)

    def tensor_to_ptr(self, tensor, name, is_output=False):
        """
        Convert and remember the input tensor to cuda.CUdeviceptr used by cuda python

        For numpy.ndarray, it also remembers the host buffer for synchronization

        :raises TypeError: if ``tensor`` is not one of the supported frontend types
        """
        if tensor is None:
            # Null device pointer for absent operands.
            return cuda.CUdeviceptr(0)
        if is_numpy_tensor(tensor):
            if is_output:
                assert name
            self.buffers[name] = NumpyFrontend.argument(tensor, is_output)
            if is_output:
                # Keep the host tensor so sync() can copy the result back.
                self.host_tensors[name] = tensor
            return self.buffers[name].ptr
        elif is_torch_tensor(tensor):
            return TorchFrontend.argument(tensor)
        elif isinstance(tensor, cuda.CUdeviceptr):
            return tensor
        elif is_cupy_tensor(tensor):
            return CupyFrontend.argument(tensor)
        else:
            # Fixed message: cupy tensors and raw CUdeviceptrs are also accepted,
            # contrary to the previous "only numpy and torch" wording.
            raise TypeError(
                "Unsupported frontend type. Supported types: numpy.ndarray, "
                "torch.Tensor, cupy.ndarray, and cuda.CUdeviceptr")

    def sync(self, stream_sync=True):
        """
        Synchronizes the device, copies output buffers back into their host
        tensors, and frees device-side memory.

        :param stream_sync: whether to call cudaDeviceSynchronize before copying
        """
        if stream_sync:
            (err,) = cudart.cudaDeviceSynchronize()
            # Fixed: compare against the runtime-API success code to match the
            # cudart call above (previously compared against cuda.CUresult).
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError("CUDA Error %s" % str(err))

        for key in self.host_tensors.keys():
            host_tensor = self.host_tensors[key]
            (err,) = cuda.cuMemcpyDtoH(
                host_tensor,
                self.buffers[key].ptr,
                host_tensor.size * host_tensor.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

        self.free()

    def free(self):
        """
        Frees allocated device-side memory
        """
        # Free any device memory allocated manually; under RMM, buffer lifetime
        # is managed by the pool instead.
        if not cutlass_cppgen.use_rmm:
            for name, buf in self.buffers.items():
                if isinstance(buf, DevicePtrWrapper):
                    err, = cudart.cudaFree(buf.ptr)
                    if err != cudart.cudaError_t.cudaSuccess:
                        raise RuntimeError(f"cudaFree failed with error {err}")

            if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper):
                err, = cudart.cudaFree(self.workspace_buffer.ptr)
                if err != cudart.cudaError_t.cudaSuccess:
                    raise RuntimeError(f"cudaFree failed with error {err}")
                del self.workspace_buffer

View File

@ -0,0 +1,625 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from cutlass_library import (
DataType,
KernelScheduleType,
TileSchedulerType
)
from cutlass_cppgen.backend.library import DataTypeSizeBytes
class GemmCoord_(ctypes.Structure):
    """Ctypes representation of a GEMM problem size: extents m, n, and k."""

    _fields_ = [
        ("m", ctypes.c_int),
        ("n", ctypes.c_int),
        ("k", ctypes.c_int)
    ]

    def __init__(self, m, n, k) -> None:
        # Assign all three extents in a single pass.
        for field, extent in zip(("m", "n", "k"), (m, n, k)):
            setattr(self, field, extent)
class GemmCoordBatched_(ctypes.Structure):
    """
    Wrapper around a GemmCoord that also carries a batch count. Used for
    encoding batched GEMM inputs to CUTLASS 3 GEMMs.
    """

    _fields_ = [
        ("m", ctypes.c_int),
        ("n", ctypes.c_int),
        ("k", ctypes.c_int),
        ("batch_count", ctypes.c_int)
    ]

    def __init__(self, gemm_coord, batch_count) -> None:
        # Copy the problem extents from the GemmCoord-like object, then record
        # the batch count.
        for extent in ("m", "n", "k"):
            setattr(self, extent, getattr(gemm_coord, extent))
        self.batch_count = batch_count
class MatrixCoord_(ctypes.Structure):
    """Ctypes (row, column) coordinate pair, e.g. a matrix problem size."""
    _fields_ = [
        ("row", ctypes.c_int),
        ("column", ctypes.c_int)
    ]
class dim3_(ctypes.Structure):
    """Ctypes mirror of CUDA's dim3: three integer extents (x, y, z)."""
    _fields_ = [
        ("x", ctypes.c_int),
        ("y", ctypes.c_int),
        ("z", ctypes.c_int)
    ]
class StrideBatched_(ctypes.Structure):
    """
    CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The
    variable dimensions represent the stride along non-unit-stride dimension of the row/column major
    layout, and the batch stride. This structure encodes the two variable dimensions.
    """
    _fields_ = [
        ("major_stride", ctypes.c_int64),  # stride along the non-unit-stride dimension
        ("batch_stride", ctypes.c_int64)   # stride between consecutive batch elements
    ]
class GenericMainloopArguments3x_(ctypes.Structure):
    """
    Structure representing the superset of possible mainloop arguments.
    This structure should not be passed to kernels directly, but, rather,
    be used as an input to one of the more specific schedule arguments, which
    will each select those arguments relevant to the particular schedule.
    """
    _fields_ = [
        ("ptr_A", ctypes.c_void_p),
        ("stride_A", StrideBatched_),
        ("ptr_B", ctypes.c_void_p),
        ("stride_B", StrideBatched_),
        # Dropped by schedules that do not use it (see get_mainloop_arguments_3x).
        ("mma_promotion_interval", ctypes.c_int)
    ]
class _PersistentTileSchedulerArguments(ctypes.Structure):
    """Arguments accepted by the persistent (default) tile scheduler."""
    _fields_ = [
        ("max_swizzle_size", ctypes.c_int),
        ("raster_order_option", ctypes.c_int),
    ]
class _PersistentTileSchedulerStreamKArguments(ctypes.Structure):
    """Arguments accepted by the stream-K tile scheduler."""
    _fields_ = [
        ("splits", ctypes.c_int),
        ("max_swizzle_size", ctypes.c_int),
        ("raster_order_option", ctypes.c_int),
        ("reduction_mode", ctypes.c_int),
        ("decomposition_mode", ctypes.c_int),
    ]
def get_tile_scheduler_arguments_3x(
    tile_scheduler: TileSchedulerType,
    splits: int = 1):
    """
    Builds the scheduler-argument structure matching the requested tile scheduler.

    :param tile_scheduler: tile scheduler to encode arguments for
    :type tile_scheduler: cutlass_library.TileSchedulerType
    :param splits: number of stream-K splits (only used by the stream-K scheduler)
    :type splits: int

    :return: populated scheduler-argument structure (None for unrecognized schedulers)
    """
    swizzle = 1
    raster = 0  # Heuristic

    if tile_scheduler == TileSchedulerType.StreamK:
        return _PersistentTileSchedulerStreamKArguments(
            splits,
            swizzle,
            raster,
            0,  # Deterministic reduction
            0,  # Heuristic decomposition
        )
    if tile_scheduler in (TileSchedulerType.Default, TileSchedulerType.Persistent):
        return _PersistentTileSchedulerArguments(swizzle, raster)
def get_mainloop_arguments_3x(
    kernel_schedule: KernelScheduleType,
    element_A,
    element_B,
    alignment_A: int,
    alignment_B: int) -> ctypes.Structure:
    """
    Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters.

    :param kernel_schedule: type of kernel schedule to be used in the mainloop
    :type kernel_schedule: cutlass_library.KernelScheduleType
    :param element_A: data type of operand A
    :param element_B: data type of operand B
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int

    :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters
    :rtype: ctypes.Structure

    NOTE(review): all parameters are currently unused; they are retained so that
    schedule-specific structures can be selected later (see the comment before
    the return statement).
    """

    class _MainloopArgumentsTma(ctypes.Structure):
        # Operand pointers/strides plus the MMA promotion interval.
        _fields_ = [
            ("ptr_A", ctypes.c_void_p),
            ("stride_A", StrideBatched_),
            ("ptr_B", ctypes.c_void_p),
            ("stride_B", StrideBatched_),
            ("mma_promotion_interval", ctypes.c_int)
        ]

        @staticmethod
        def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
            # Select only the fields relevant to this schedule from the generic superset.
            return _MainloopArgumentsTma(
                args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
                args.mma_promotion_interval
            )

    class _MainloopArgumentsMultistage(ctypes.Structure):
        # Same as the TMA variant, minus the MMA promotion interval.
        # NOTE(review): currently never returned (see comment below).
        _fields_ = [
            ("ptr_A", ctypes.c_void_p),
            ("stride_A", StrideBatched_),
            ("ptr_B", ctypes.c_void_p),
            ("stride_B", StrideBatched_),
        ]

        @staticmethod
        def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
            return _MainloopArgumentsMultistage(
                args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
            )

    # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure.
    # Should that become not the case, this is the place to return custom ctypes
    # structures based on selected kernel schedule.
    return _MainloopArgumentsTma
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
    """
    Assembles the ctypes argument structures for a CUTLASS 3.x GEMM kernel.

    :param mainloop_arguments: ctypes structure for the mainloop (e.g., from get_mainloop_arguments_3x)
    :param epilogue_functor: epilogue functor providing the epilogue parameter type(s)
    :param scheduler_args: instance of a tile-scheduler argument structure; its type is embedded in the result
    :param default_epilogue: if False and the functor provides an EVT parameter type, that type is used

    :returns: tuple of (_GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo)
    """
    # Prefer the epilogue-visitor-tree parameter type when a non-default epilogue
    # is requested and the functor provides one.
    if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
        _EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
    else:
        _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    if hasattr(epilogue_functor, "visitor"):
        # Visitor-based epilogue: C and D are wrapped in functor-specific argument types.
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
                ("epilogue", _EpilogueOutputOpParams),
                ("arg_C", epilogue_functor.arg_c_type),
                ("arg_D", epilogue_functor.arg_d_type)
            ]

            def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None:
                # NOTE(review): stride_c/stride_d are accepted but not forwarded;
                # the arg_c_type/arg_d_type constructors receive only the pointers.
                self.epilogue = output_op
                self.arg_C = epilogue_functor.arg_c_type(ptr_c)
                self.arg_D = epilogue_functor.arg_d_type(ptr_d)
    else:
        # Standard epilogue: raw C/D pointers with batched strides.
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
                ("epilogue", _EpilogueOutputOpParams),
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", StrideBatched_),
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", StrideBatched_),
            ]

    class _HardwareInfo(ctypes.Structure):
        # Device/SM/cluster description passed alongside the kernel arguments.
        _fields_ = [
            ("device_id", ctypes.c_int),
            ("sm_count", ctypes.c_int),
            ("max_active_clusters", ctypes.c_int),
            ("cluster_shape", dim3_),
            ("cluster_shape_fallback", dim3_),
        ]

    class _GemmArguments(ctypes.Structure):
        # Top-level argument structure for 3.x GEMM kernels.
        _fields_ = [
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoordBatched_),
            ("mainloop", mainloop_arguments),
            ("epilogue", _EpilogueArguments),
            ("hw_info", _HardwareInfo),
            ("scheduler", type(scheduler_args)),
        ]

    return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo
def get_gemm_arguments(epilogue_functor):
    """
    Returns the ctypes argument structure for a (non-3.x) universal GEMM kernel.

    :param epilogue_functor: epilogue functor providing the epilogue parameter type
    :returns: tuple of (_GemmArguments, _EpilogueOutputOpParams)
    """
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _GemmArguments(ctypes.Structure):
        _fields_ = [
            # Arguments from UniversalArgumentsBase
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoord_),
            ("batch_count", ctypes.c_int),
            ("batch_stride_D", ctypes.c_longlong),
            # Remaining arguments
            ("epilogue", _EpilogueOutputOpParams),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("batch_stride_A", ctypes.c_longlong),
            ("batch_stride_B", ctypes.c_longlong),
            ("batch_stride_C", ctypes.c_longlong),
            ("stride_a", ctypes.c_longlong),
            ("stride_b", ctypes.c_longlong),
            ("stride_c", ctypes.c_longlong),
            ("stride_d", ctypes.c_longlong),
            ("lda", ctypes.c_longlong),
            ("ldb", ctypes.c_longlong),
            ("ldc", ctypes.c_longlong),
            ("ldd", ctypes.c_longlong),
            ("ptr_gather_A_indices", ctypes.c_void_p),
            ("ptr_gather_B_indices", ctypes.c_void_p),
            ("ptr_scatter_D_indices", ctypes.c_void_p)
        ]

    return _GemmArguments, _EpilogueOutputOpParams
def get_gemm_arguments_streamk(epilogue_functor):
    """
    Returns the ctypes argument structure for a stream-K GEMM kernel.

    :param epilogue_functor: epilogue functor providing the epilogue parameter type
    :returns: tuple of (_GemmArguments, _EpilogueOutputOpParams)
    """
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _GemmArguments(ctypes.Structure):
        _fields_ = [
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoord_),
            ("batch_count", ctypes.c_int),
            ("epilogue", _EpilogueOutputOpParams),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("batch_stride_A", ctypes.c_longlong),
            ("batch_stride_B", ctypes.c_longlong),
            ("batch_stride_C", ctypes.c_longlong),
            ("batch_stride_D", ctypes.c_longlong),
            ("stride_a", ctypes.c_longlong),
            ("stride_b", ctypes.c_longlong),
            ("stride_c", ctypes.c_longlong),
            ("stride_d", ctypes.c_longlong),
            ("lda", ctypes.c_longlong),
            ("ldb", ctypes.c_longlong),
            ("ldc", ctypes.c_longlong),
            ("ldd", ctypes.c_longlong),
            # Number of SMs available to the stream-K decomposition.
            ("avail_sms", ctypes.c_int)
        ]

    return _GemmArguments, _EpilogueOutputOpParams
###########################################################################################
# GEMM Grouped
###########################################################################################

def get_gemm_grouped_arguments(epilogue_functor):
    """
    Returns the ctypes argument structure for a grouped GEMM kernel.

    :param epilogue_functor: epilogue functor providing the epilogue parameter type
    :returns: tuple of (_GEMMGroupedArguments, _EpilogueOutputOpParams)
    """
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _GEMMGroupedArguments(ctypes.Structure):
        _fields_ = [
            ("problem_sizes", ctypes.c_void_p),
            ("problem_count", ctypes.c_int),
            ("threadblock_count", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("lda", ctypes.c_void_p),
            ("ldb", ctypes.c_void_p),
            ("ldc", ctypes.c_void_p),
            ("ldd", ctypes.c_void_p),
            ("host_problem_sizes", ctypes.c_void_p)
        ]

    return _GEMMGroupedArguments, _EpilogueOutputOpParams
############################################################################################
# Convolution2D
############################################################################################

class Conv2DProblemSize_(ctypes.Structure):
    """
    Ctypes description of a 2-D convolution problem: input extents (N, H, W, C),
    output extents (P, Q), filter extents (K, R, S), padding/stride/dilation,
    convolution mode, split-K slice count, and group count.
    """
    _fields_ = [
        ("N", ctypes.c_int),
        ("H", ctypes.c_int),
        ("W", ctypes.c_int),
        ("C", ctypes.c_int),
        ("P", ctypes.c_int),
        ("Q", ctypes.c_int),
        ("K", ctypes.c_int),
        ("R", ctypes.c_int),
        ("S", ctypes.c_int),
        ("pad_h", ctypes.c_int),
        ("pad_w", ctypes.c_int),
        ("stride_h", ctypes.c_int),
        ("stride_w", ctypes.c_int),
        ("dilation_h", ctypes.c_int),
        ("dilation_w", ctypes.c_int),
        ("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1
        ("split_k_slices", ctypes.c_int),
        ("groups", ctypes.c_int)
    ]

    def __init__(self, problem_size) -> None:
        # Copy every field by name from the given problem-size object.
        for field_name, _ in self._fields_:
            setattr(self, field_name, getattr(problem_size, field_name))
class Layout4D(ctypes.Structure):
    """Ctypes capture of a layout's three stride extents."""

    _fields_ = [("stride", ctypes.c_int * 3)]

    def __init__(self, tensor_ref):
        # Read the first three stride extents out of the layout object.
        extents = tensor_ref.stride()
        self.stride = (extents.at(0), extents.at(1), extents.at(2))
class TensorRef_(ctypes.Structure):
    """Ctypes pairing of a raw data pointer with its 4-D layout."""

    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("layout", Layout4D)
    ]

    def __init__(self, tensor_ref):
        # Capture the data pointer and convert the layout into its ctypes form.
        self.ptr = tensor_ref.data()
        self.layout = Layout4D(tensor_ref.layout())
class TensorRef2D_(ctypes.Structure):
    """Pointer plus leading-dimension stride for a 2-D tensor reference."""
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("stride", ctypes.c_int)
    ]
def get_conv2d_arguments(epilogue_functor):
    """
    Returns the ctypes argument structure for a 2-D convolution kernel.

    :param epilogue_functor: epilogue functor providing the epilogue parameter type
    :returns: tuple of (_Conv2dArguments, _EpilogueOutputOpParams)
    """
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _Conv2dArguments(ctypes.Structure):
        _fields_ = [
            ("conv_kind", ctypes.c_int),
            ("problem_size", Conv2DProblemSize_),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("tensor_C_numel", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("split_k_mode", ctypes.c_int)
        ]

    return _Conv2dArguments, _EpilogueOutputOpParams
############################################################################################
# Reduction
############################################################################################

def get_reduction_params(epilogue_functor):
    """
    Returns the ctypes parameter structure for a reduction kernel.

    :param epilogue_functor: epilogue functor providing the epilogue parameter type
    :returns: tuple of (_ReductionParams, _EpilogueOutputParams)
    """
    _EpilogueOutputParams = epilogue_functor.epilogue_type

    class _ReductionParams(ctypes.Structure):
        _fields_ = [
            ("problem_size", MatrixCoord_),
            ("partitions", ctypes.c_int),
            ("partition_stride", ctypes.c_longlong),
            ("workspace", TensorRef2D_),
            ("destination", TensorRef2D_),
            ("source", TensorRef2D_),
            ("output_op", _EpilogueOutputParams),
        ]

    return _ReductionParams, _EpilogueOutputParams
###########################################################################################
# Epilogue Visitor Type Factory
###########################################################################################

class Empty(ctypes.Structure):
    """Zero-byte structure standing in for compile-time-constant tuple entries."""
    _fields_ = []

    def __init__(self, *arg) -> None:
        # Accept (and ignore) any constructor arguments.
        pass
class EmptyByte(ctypes.Structure):
    """
    Single-byte placeholder returned when a tuple would otherwise be zero-sized
    (see tuple_factory).
    """
    _fields_ = [
        ("byte", ctypes.c_byte)
    ]

    def __init__(self, *arg) -> None:
        # Accept (and ignore) any constructor arguments.
        pass
class EBO:
    """
    Identifies an empty (storage-free) entry of a cute::Tuple by its position
    and value/type, mirroring C++ empty-base optimization.
    """

    def __init__(self, index: int, type) -> None:
        self.index = index
        self.type = type

    def __eq__(self, other) -> bool:
        return (
            isinstance(other, EBO)
            and self.index == other.index
            and self.type == other.type
        )

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self) -> int:
        return hash((self.index, self.type))

    def __str__(self) -> str:
        return f"<{self.index}, {self.type}>"
def tuple_factory_(input_tuple, dtype, constants=[0,1]):
    """
    The factory function generating cute::Tuple with input tuple

    :param input_tuple: the input tuple
    :type input_tuple: tuple
    :param dtype: the data type for non-constant values
    :type dtype: str, "int32_t", "int", "int64_t"
    :param constant: the values that will be treated as constants
    :type constant: list[int]

    :return: ctype structure representing the cute::Tuple
    :return: the empty base classes of the tuple
    """
    # The empty base classes of the current tuple
    empty_bases = []
    # The first non empty base class
    # NOTE(review): first_non_empty_base is computed below but never used or
    # returned — candidate for removal.
    first_non_empty_base = None
    # The ctype fields of the current tuple
    ctype_fields = []

    for idx, entry in enumerate(input_tuple):
        # For nested tuples
        if isinstance(entry, tuple):
            # Recurse to build the ctypes structure of the sub-tuple.
            sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
            if ctypes.sizeof(sub_tuple_ctype) == 0:
                # The empty tuple base class is also an empty EBO
                empty_bases.append(EBO(idx, entry))
            else:
                if first_non_empty_base is None:
                    first_non_empty_base = sub_empty_bases
                ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
        else:
            if entry in constants:
                # Constant entries take no runtime storage.
                empty_bases.append(EBO(idx, entry))
                ctype_fields.append((f"entry_{idx}", Empty))
            else:
                ctype_fields.append((f"entry_{idx}", dtype))
                if first_non_empty_base is None:
                    first_non_empty_base = []

    # Create the ctype tuple
    class TupleType(ctypes.Structure):
        _fields_ = ctype_fields

        def __init__(self, args) -> None:
            # Populate each field positionally from args, converting through
            # the field's ctypes type.
            fields = self._fields_
            assert len(fields) == len(args)
            for field, arg in zip(fields, args):
                name = field[0]
                field_type = field[1]
                setattr(self, name, field_type(arg))

    return TupleType, empty_bases
def tuple_factory(input_tuple, dtype: str, constants=[0,1]):
    """
    The factory function generating cute::Tuple with input tuple

    :param input_tuple: the input tuple
    :type input_tuple: tuple
    :param dtype: the data type for non-constant values
    :type dtype: str, "int32_t", "int", "int64_t"
    :param constant: the values that will be treated as constants
    :type constant: list[int]

    :return: ctype structure representing the cute::Tuple
    :return: the empty base classes of the tuple
    """
    # Step 1: convert the dtype name into its ctypes equivalent.
    ctype_map = {
        "int64_t": ctypes.c_longlong,
        "int": ctypes.c_int32,
        "int32_t": ctypes.c_int32,
    }
    if dtype not in ctype_map:
        raise NotImplementedError(f"Type {dtype} is not supported")

    tuple_type, _ = tuple_factory_(input_tuple, ctype_map[dtype], constants)

    # A zero-sized tuple gets a one-byte placeholder for correct alignment.
    if ctypes.sizeof(tuple_type) == 0:
        return EmptyByte
    return tuple_type
def visitor_factory(node_types, node_names):
    """
    Creates the argument type of epilogue visitor type

    :param node_types: list of argument types under ctypes
    :param node_names: list of argument names under str

    :return: tuple type in ctypes.Structure
    """
    ctypes_field = []

    # Struct is used when number of nodes < 4
    # Because the Sm90VisitorImplBase has specification up to 4 nodes
    # in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
    if len(node_types) <= 4:
        for idx, node_type in enumerate(node_types):
            if ctypes.sizeof(node_type) == 0:
                # Special case for empty struct
                # 1 byte placeholder is used for correct alignment
                ctypes_field.append((node_names[idx], ctypes.c_byte))
            else:
                ctypes_field.append((node_names[idx], node_type))

        class VisitorType(ctypes.Structure):
            _fields_ = ctypes_field

            def __init__(self, kwargs) -> None:
                # Initialize each field from kwargs, skipping the one-byte
                # placeholders since they carry no data.
                for field in self._fields_:
                    fname, ftype = field
                    if ftype != ctypes.c_byte:
                        setattr(self, fname, ftype(kwargs))

    # For cases with more than 4 nodes, tuple is used
    else:
        for idx, node_type in enumerate(node_types):
            ctypes_field.append((node_names[idx], node_type))

        class VisitorType(ctypes.Structure):
            _fields_ = ctypes_field

            def __init__(self, kwargs) -> None:
                for field in self._fields_:
                    fname, ftype = field
                    setattr(self, fname, ftype(kwargs))

    return VisitorType

View File

@ -0,0 +1,462 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
import json
import os
import sqlite3
import subprocess
import tempfile
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
nvrtc = lazy_import("cuda.nvrtc")
from cutlass_library import SubstituteTemplate
import cutlass_cppgen
from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.backend.utils.device import device_cc
IncludeTemplate = r"""#include "${include}"
"""
def compile_with_nvcc(cmd, source, error_file):
    """
    Run an nvcc compilation command, persisting the compiler log on failure.

    :param cmd: the nvcc command line as a list of arguments
    :param source: the source being compiled (recorded in the error log)
    :param error_file: path to which the error log is written on failure
    :raises Exception: if compilation fails
    """
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        # Build the log and raise inside the handler so the log variable's
        # lifetime is not coupled to a success flag checked afterwards
        error_log = "Compilation error for the following kernel: \n"
        error_log += source
        error_log += "\nError Message:\n"
        error_log += e.output.decode()
        with open(error_file, "w") as error_out:
            error_out.write(error_log)
        # Print the error log to stdout if log level is set to warning or higher
        # verbosity. Otherwise, simply point to the error log file.
        logger.warning(error_log)
        raise Exception(f"Invalid Kernel. See '{error_file}' for details.") from e
class CompilationOptions:
    """
    Compilation options for device kernels.

    :param flags: compiler flags (e.g. ``-std=c++17``)
    :type flags: list[str]
    :param arch: target compute capability (e.g. 80 for sm_80)
    :type arch: int
    :param include_paths: additional include directories
    :type include_paths: list[str], optional
    """

    def __init__(self, flags, arch, include_paths=None):
        self.includes = []
        # Copy defensively: a mutable default (or the caller's list) must not
        # be shared between instances or mutated behind our back
        self.include_paths = list(include_paths) if include_paths is not None else []
        self.flags = flags
        self.arch = arch

    def get_str(self):
        """Return the options as a single space-separated string for nvcc."""
        opts = list(self.flags)
        for incl in self.include_paths:
            opts.append(f"--include-path={incl}")

        arch_flag = f"-arch=sm_{self.arch}"
        # sm_90 and newer use the architecture-specific "a" target with nvcc >= 12
        if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12:
            arch_flag += "a"
        opts.append(arch_flag)

        return " ".join(opts)

    def get(self):
        """Return the options as a list of byte strings for NVRTC."""
        options = [bytes(str.encode(flag)) for flag in self.flags]
        for incl in self.include_paths:
            options.append(bytes(str.encode(f" --include-path={incl}")))

        arch_flag = f" -arch=sm_{self.arch}"
        # NOTE(review): unlike get_str(), the "a" suffix is applied without an
        # nvcc version check here — NVRTC ships with the toolkit, so the
        # version gate is presumably unnecessary; confirm intended.
        if self.arch in [90, 100, 101, 103, 120, 121]:
            arch_flag += "a"
        options.append(bytes(str.encode(arch_flag)))

        return options
def convertToBinaryData(filename):
    """Read *filename* and return its entire contents as a bytes blob."""
    with open(filename, "rb") as blob_file:
        return blob_file.read()
def CDLLBin(host_binary):
    """
    Load a host binary blob as a shared library via a temporary .so file.

    :param host_binary: raw bytes of a compiled shared object
    :return: the loaded ``ctypes.CDLL`` handle
    """
    # NOTE(review): this mutates the process-wide tempfile directory to the
    # current working directory as a side effect — confirm that is intended.
    tempfile.tempdir = "./"
    temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True)
    with open(temp_so.name, "wb") as file:
        file.write(host_binary)
    # The library is mapped into the process here; the temporary file is
    # removed when temp_so is finalized (delete=True), which on POSIX does
    # not unload an already-loaded library.
    host_lib = ctypes.CDLL(temp_so.name)
    return host_lib
class ArtifactManager:
    """
    Artifact manager.

    Compiles CUTLASS kernels with either NVRTC or NVCC, caches the compiled
    device cubins and host launcher libraries in an SQLite database
    (``CACHE_FILE``) keyed by the emitted kernel source, and reloads cached
    artifacts on subsequent runs to avoid recompilation.
    """

    def __init__(self) -> None:
        connection = sqlite3.connect(CACHE_FILE)
        cursor = connection.cursor()

        # Create the table if it does not already exist
        sqlite_create_table_query = """
        CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE,
                                                       cubin BLOB NOT NULL,
                                                       hostbin BLOB NOT NULL,
                                                       op_name TEXT NOT NULL,
                                                       op_attrs TEXT NOT NULL)
        """
        cursor.execute(sqlite_create_table_query)
        connection.commit()
        cursor.close()

        self._nvrtc_compile_options = ["-std=c++17", "-default-device"]
        self._nvcc_compile_options = [
            "-std=c++17",
            "--expt-relaxed-constexpr",
            "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
        ]
        # Default to the nvcc backend
        self.nvcc()

        # In-memory caches: op_key -> CUfunction / dict of host functions
        self.compiled_cache_device = {}
        self.compiled_cache_host = {}

    def nvrtc(self):
        """Select NVRTC as the compilation backend."""
        self.backend = "nvrtc"
        self.default_compile_options = self._nvrtc_compile_options

    def nvcc(self):
        """Select NVCC as the compilation backend."""
        self.backend = "nvcc"
        self.default_compile_options = self._nvcc_compile_options

    def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
        """
        Persist a compiled operation (cubin + host library) into the on-disk cache.

        :param op_key: cache key of the operation
        :param cubin: compiled device cubin image (bytes)
        :param hostfile: path to the compiled host shared library
        :param op_name: kernel name
        :param op_attrs: JSON-serializable attribute list (param size first,
            followed by extra host-function suffixes)
        """
        connection = sqlite3.connect(CACHE_FILE)
        cursor = connection.cursor()
        sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""

        hostbin = convertToBinaryData(hostfile)

        data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))

        cursor.execute(sqlite_insert_blob_query, data_tuple)
        connection.commit()
        cursor.close()

    def load_operation(self, op_key, extra_funcs):
        """
        Load a cached operation into the in-memory caches.

        :param op_key: cache key of the operation
        :param extra_funcs: mapping from extra host-function suffix to its
            ctypes return type (or None to keep the ctypes default)
        :return: True on a cache hit, False otherwise
        """
        connection = sqlite3.connect(CACHE_FILE)
        cursor = connection.cursor()
        sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
        cursor.execute(sqlite_fetch_blob_query, (op_key,))
        record = cursor.fetchall()
        if len(record) == 0:
            return False

        for row in record:
            key, cubin_image, host_binary, operation_name, op_attr = row
            op_attr = json.loads(op_attr)

            # Load the cubin and extract the kernel handle
            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("Cuda Error: {}".format(err))

            # NOTE(review): err from cuModuleGetFunction is not checked here
            err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
            self.compiled_cache_device[key] = kernel

            # Rebuild the host-side helpers from the cached shared library
            compiled_host_fns = {}
            host_lib = CDLLBin(host_binary)

            func_name = operation_name + "_get_params"
            func = getattr(host_lib, func_name)
            # op_attr[0] holds the size in bytes of the kernel Params struct
            func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
            compiled_host_fns["get_args"] = func

            func_name = operation_name + "_shared_memory_size"
            func = getattr(host_lib, func_name)
            compiled_host_fns["shared_memory_capacity"] = func()

            # Remaining string attributes are the extra host-function suffixes
            for attr in op_attr:
                if isinstance(attr, str):
                    func_name = operation_name + "_" + attr
                    func = getattr(host_lib, func_name)

                    # Set the return type of the function
                    if attr in extra_funcs and extra_funcs[attr] is not None:
                        func.restype = extra_funcs[attr]
                    compiled_host_fns[attr] = func

            self.compiled_cache_host[key] = compiled_host_fns

        return True

    def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
        """
        Compile a list of kernels and return the resulting artifacts.

        :param operation_list: runtime modules to emit and compile
        :param compilation_options: device (kernel) compilation options
        :param host_compilation_options: host-library compilation options
        :return: (cubin image bytes, loaded host CDLL, host .so temp file)
        """
        source_buffer_device = ""
        source_buffer_host = ""

        # 1. includes
        includes = []
        for operation in operation_list:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes
        for incl in includes:
            source_buffer_device += SubstituteTemplate(
                IncludeTemplate,
                {"include": incl},
            )
        for incl in includes_host:
            source_buffer_host += SubstituteTemplate(
                IncludeTemplate,
                {"include": incl},
            )

        # 2. operations: each contributes its emitted type plus the kernel
        # entry point (device) / helper functions (host)
        for operation in operation_list:
            source_buffer_device += operation.emit()
            source_buffer_host += operation.emit()
            values = {
                "operation_name": operation.name(),
                "operation_suffix": operation.emitter.operation_suffix,
            }
            source_buffer_device += SubstituteTemplate(
                operation.KernelTemplate,
                values,
            )
            source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)

        if self.backend == "nvrtc":
            # 3. compile with NVRTC
            err, program = nvrtc.nvrtcCreateProgram(
                str.encode(source_buffer_device),
                bytes(str.encode("module.cu")),
                0, [], [])
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError("NVRTC Error: {}".format(err))

            # Compile program
            options = compilation_options.get()
            err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                error_string = "NVRTC Error: {}\n".format(err)

                # Get log from compilation
                err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError("NVRTC Error: {}".format(err))

                log = b" " * logSize
                err, = nvrtc.nvrtcGetProgramLog(program, log)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError("NVRTC Error: {}".format(err))

                raise RuntimeError(error_string + log.decode() + source_buffer_device)

            # Get data from compilation
            err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError("NVRTC Error: {}".format(err))

            cubin_image = b" " * dataSize
            err, = nvrtc.nvrtcGetCUBIN(program, cubin_image)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError("NVRTC Error: {}".format(err))

        else:  # with nvcc backend
            # emit code
            tempfile.tempdir = "./"
            temp_cu = tempfile.NamedTemporaryFile(
                prefix="kernel", suffix=".cu", delete=True)
            temp_cubin = tempfile.NamedTemporaryFile(
                prefix="kernel", suffix=".cubin", delete=True)
            with open(temp_cu.name, "w") as file:
                file.write(source_buffer_device)

            # compile with nvcc
            cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}"
            values = {
                "cuda_install_path": cuda_install_path(),
                "options": compilation_options.get_str(),
                "srcfile": temp_cu.name,
                "tarfile": temp_cubin.name,
            }
            cmd = SubstituteTemplate(cmd_template, values)
            compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt")

            # load the cubin image
            with open(temp_cubin.name, "rb") as file:
                cubin_image = file.read()

        # Compile the host code into a shared library
        tempfile.tempdir = "./"
        temp_src = tempfile.NamedTemporaryFile(
            prefix="host_src", suffix=".cu", delete=True)

        # Write the host source
        with open(temp_src.name, "w") as outfile:
            outfile.write(source_buffer_host)

        temp_dst = tempfile.NamedTemporaryFile(
            prefix="host_func", suffix=".so", delete=True)

        # Set up host compilation arguments
        cmd = []
        cmd.append(f"{cuda_install_path()}/bin/nvcc")
        cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"])
        cmd.extend(host_compilation_options.get_str().split(" "))
        cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"])

        # Compile and load the library
        compile_with_nvcc(cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt")
        host_lib = ctypes.CDLL(temp_dst.name)

        return cubin_image, host_lib, temp_dst

    def add_module(self, operations, compile_options=None, bypass_cache=False):
        """
        Insert a new compiled device module.

        Cached kernels (in-memory or on-disk) are reused; the rest are
        emitted, compiled, loaded, and inserted into the cache.

        :param operations: operations whose ``rt_module`` should be compiled
        :param compile_options: device compilation options; defaults to the
            backend's default options for the current device architecture
        :param bypass_cache: if True, skip the on-disk cache lookup
        """
        include_paths = [
            cuda_install_path() + "/include",
            CUTLASS_PATH + "/include",
            CUTLASS_PATH + "/tools/util/include",
            # Fixed: directory was renamed from python/cutlass to python/cutlass_cppgen
            CUTLASS_PATH + "/python/cutlass_cppgen/cpp/include",
        ]

        cutlass_cppgen.initialize_cuda_context()
        arch = device_cc()

        host_compile_options = CompilationOptions(
            self._nvcc_compile_options, arch, include_paths)
        if compile_options is None:
            compile_options = CompilationOptions(
                self.default_compile_options, arch, include_paths)

        # save the cubin
        operation_key = []
        operation_list = []
        for operation in operations:
            # step 1: use the emitted kernel source as the cache key
            key = operation.rt_module.emit() + operation.procedural_name() + self.backend

            # step 2: check whether the operation is already cached
            compiled_kernel = self.compiled_cache_device.get(key)
            if compiled_kernel is None and not bypass_cache:
                hit = self.load_operation(key, getattr(operation.rt_module, "extra_funcs", {}))
                if hit:
                    compiled_kernel = self.compiled_cache_device.get(key)
                    assert compiled_kernel is not None
            if compiled_kernel is not None:
                # Cache hit: attach the kernel and host helpers to the module
                operation.rt_module.kernel = compiled_kernel
                compiled_host_fns = self.compiled_cache_host.get(key)
                assert compiled_host_fns is not None
                for fname in compiled_host_fns.keys():
                    setattr(operation.rt_module, fname, compiled_host_fns[fname])
                operation.rt_module.initialize()
            else:
                # Cache miss: queue for compilation
                operation_list.append(operation.rt_module)
                operation_key.append(key)

        if len(operation_list) > 0:
            cubin_image, host_lib, host_file = self.emit_compile_(
                operation_list, compile_options, host_compile_options)

            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("Cuda Error: {}".format(err))

            operation_names = []
            operation_attrs = []
            for operation, key in zip(operation_list, operation_key):
                # get device kernels
                err, operation.kernel = cuda.cuModuleGetFunction(
                    module,
                    bytes(str.encode(operation.name()))
                )
                operation_names.append(operation.name())
                self.compiled_cache_device[key] = operation.kernel

                # get host functions
                compiled_host_fns = {}
                op_attr = []

                # get param size
                func_name = operation.name() + "_get_param_size"
                func = getattr(host_lib, func_name)
                param_size = func()

                func_name = operation.name() + "_get_params"
                func = getattr(host_lib, func_name)
                # NOTE(review): ctypes expects `argtypes`; setting `argtype`
                # is a plain attribute and disables type checking — confirm
                # whether this is intentional before changing it.
                func.argtype = operation.argtype
                func.restype = ctypes.POINTER(ctypes.c_char * param_size)
                setattr(operation, "get_args", func)
                compiled_host_fns["get_args"] = func

                # set shared memory size
                func_name = operation.name() + "_shared_memory_size"
                func = getattr(host_lib, func_name)
                setattr(operation, "shared_memory_capacity", func())
                compiled_host_fns["shared_memory_capacity"] = func()

                # set the maximum dynamic shared size
                operation.initialize()

                # get extra functions
                op_attr.append(param_size)
                if hasattr(operation, "extra_funcs"):
                    for suffix, ret_type in operation.extra_funcs.items():
                        func_name = operation.name() + "_" + suffix
                        func = getattr(host_lib, func_name)
                        if ret_type is not None:
                            func.restype = ret_type
                        setattr(operation, suffix, func)
                        compiled_host_fns[suffix] = func
                        op_attr.append(suffix)

                operation_attrs.append(op_attr)
                self.compiled_cache_host[key] = compiled_host_fns

            # Persist every newly compiled operation to the on-disk cache
            for key, op_name, op_attrs in zip(operation_key, operation_names, operation_attrs):
                self.insert_operation(
                    key, cubin_image, host_file.name, op_name, op_attrs)

View File

@ -0,0 +1,700 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
import ctypes
from typing import Union
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import SubstituteTemplate
import numpy as np
from cutlass_library import (
ConvKindNames,
ConvKindTag,
DataTypeNames,
DataTypeSize,
DataTypeTag,
IteratorAlgorithmNames,
IteratorAlgorithmTag,
LayoutTag,
LayoutType,
MathOperation,
MathOperationTag,
OpcodeClass,
OpcodeClassNames,
OpcodeClassTag,
OperationKind,
ShortDataTypeNames,
ShortLayoutTypeNames,
SplitKMode,
StrideSupport,
StrideSupportTag,
SwizzlingFunctor,
SwizzlingFunctorTag,
get_complex_from_real,
)
from cutlass_cppgen.backend.arguments import ArgumentBase
from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments
from cutlass_cppgen.backend.library import (
EmissionType,
TensorDescription,
TileDescription,
)
from cutlass_cppgen.backend.memory_manager import device_mem_alloc
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.backend.utils.device import to_device_ptr
from cutlass_cppgen.shape import GemmCoord
class Conv2dArguments(ArgumentBase):
    """
    Argument wrapper for Conv2d. It encodes problem information and
    user-provided tensors into the kernel's argument.

    :param operation: the Conv2d operation to take the argument
    :type operation: :class:`cutlass_cppgen.backend.Conv2dOperation`
    :param problem_size: the Conv2d problem size
    :type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize`
    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
    :type split_k_mode: cutlass_library.library.SplitKMode, optional
    :param output_op: output operator, optional
    :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`
    """

    def __init__(self, operation, problem_size, A, B, C, D,
                 split_k_mode=SplitKMode.Serial, **kwargs, ) -> None:
        self.operation = operation
        self.conv_kind = operation.conv_kind
        # Layouts and element types are taken from the operation's tensor
        # descriptions rather than the runtime tensors
        self.layout_A = operation.A.layout
        self.layout_B = operation.B.layout
        self.layout_C = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        if self.layout_C == LayoutType.TensorNC32HW32:
            raise Exception("Layout type TensorNC32HW32 is not currently supported")

        super().__init__(A, B, C, D, **kwargs)

        # Split-K only takes effect when more than one slice is requested;
        # otherwise fall back to serial mode with a single slice
        if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
            self.split_k_mode = split_k_mode
            self.split_k_slices = kwargs["split_k_slices"]
        else:
            self.split_k_mode = SplitKMode.Serial
            self.split_k_slices = 1

        # A user-supplied output op is ignored under parallel split-K; the
        # default linear combination (alpha=1, beta=0) is used instead
        if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel:
            self.output_op = kwargs["output_op"]
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        self.problem_size = problem_size
        self.problem_size.split_k_slices = self.split_k_slices

        self.initialize()

    def get_arguments(self):
        """Build the ctypes argument struct consumed by the host helpers."""
        # tensor_c_numel is a sentinel read by the emitted host code to
        # decide whether C should be re-interpreted as a bias vector;
        # -1 disables that check (NOTE(review): presumably set by
        # ArgumentBase when C is 1-D — confirm against ArgumentBase)
        tc_numel = -1
        if hasattr(self, "tensor_c_numel"):
            tc_numel = self.tensor_c_numel

        self.c_arguments = self.operation.argument_type(
            int(self.conv_kind),
            self.problem_size.ctype,
            int(to_device_ptr(self.ptr_A)),
            int(to_device_ptr(self.ptr_B)),
            int(to_device_ptr(self.ptr_C)),
            int(to_device_ptr(self.ptr_D)),
            tc_numel,
            self.output_op,
            int(self.split_k_mode)
        )

    def initialize(self):
        """Plan the launch configuration, build kernel params, and allocate
        any required device workspace."""
        self.launch_config = self.operation.rt_module.plan(self)

        self.get_arguments()

        # Allocate and initialize device workspace
        device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments)
        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
        else:
            workspace_ptr = None

        self.semaphore = 0
        if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel:
            # Parallel split-K writes its output into the workspace buffer
            self.ptr_D = workspace_ptr
            # Reset arguments now that ptr_D has been updated
            self.get_arguments()
        elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial:
            # Serial split-K uses the workspace as an inter-block semaphore
            self.semaphore = workspace_ptr

        params_ = self.operation.rt_module.get_args(
            self.c_arguments, ctypes.c_void_p(int(self.semaphore)))
        self.host_workspace = bytearray(params_.contents)
        self.device_workspace = None

    def sync(self):
        """
        Synchronize the arguments. If the input tensor is in host,
        copy it from device to host.
        """
        return super().sync()
class Conv2dRT(ExecutableOperation):
    """
    Conv2dRT manages the CUTLASS runtime components: the emitted kernel
    entry point (``KernelTemplate``), the host-side helper functions
    (``HostTemplate``), and the launch configuration.
    """

    # Device-side entry point: forwards Params and dynamic shared memory to
    # the generated kernel type
    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];
  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
  ${operation_name}${operation_suffix} op;
  op(params, *shared_storage);
}
"""

    # Host-side helpers compiled into a shared library and called via ctypes
    HostTemplate = r"""
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  using ElementA = typename ${operation_name}_base::ElementA;
  using ElementB = typename ${operation_name}_base::ElementB;
  using ElementC = typename ${operation_name}_base::ElementC;
  using LayoutA = typename ${operation_name}_base::LayoutA;
  using LayoutB = typename ${operation_name}_base::LayoutB;
  using LayoutC = typename ${operation_name}_base::LayoutC;
  using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp;

  struct ${operation_name}_TemporaryArgs {
    int conv_kind;
    cutlass::conv::Conv2dProblemSize problem_size;
    ElementA* ptr_A;
    ElementB* ptr_B;
    ElementC* ptr_C;
    ElementC* ptr_D;
    int tensor_c_numel;
    typename EpilogueOutputOp::Params epilogue_params;
    int split_k_mode;
  };

  typename ${operation_name}${operation_suffix}::Arguments
  construct_arguments(${operation_name}_TemporaryArgs args) {
    cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
    auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
    auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
    auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
    auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);

    auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
    if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
      // C is interpreted as bias
      tc_C = {0, 0, 0, 0};
    }

    cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
    cutlass::TensorRef<ElementB, LayoutA> tref_B(args.ptr_B, LayoutB::packed(tc_B));
    cutlass::TensorRef<ElementC, LayoutA> tref_C(args.ptr_C, LayoutC::packed(tc_C));
    cutlass::TensorRef<ElementC, LayoutA> tref_D(args.ptr_D, LayoutC::packed(tc_D));

    return {
      args.problem_size,
      tref_A,
      tref_B,
      tref_C,
      tref_D,
      args.epilogue_params,
      static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
    };
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) {
    auto arguments = construct_arguments(args);
    typename ${operation_name}${operation_suffix}::Params* params;
    params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore);

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
      output[i] = bytes[i];

    return output;
  }

  dim3 ${operation_name}_get_grid_shape(
    int conv_kind,
    cutlass::conv::Conv2dProblemSize problem_size,
    cutlass::gemm::GemmCoord tile_size,
    int split_k_slices
  ) {
    using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle;
    auto tiled_shape = Swizzle::get_tiled_shape(
      static_cast<cutlass::conv::Operator>(conv_kind),
      problem_size,
      tile_size,
      split_k_slices);

    return Swizzle::get_grid_shape(tiled_shape);
  }

  size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) {
    auto arguments = construct_arguments(args);

    // Temporarily define device::-level Conv2d so that we can call get_workspace_size
    using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>;
    return DeviceConv::get_workspace_size(arguments);
  }
}
"""

    def __init__(self, operation: "Conv2dOperation"):
        super().__init__(operation)

        # Extra host functions emitted by HostTemplate and their ctypes
        # return types; resolved by ArtifactManager after compilation
        self.extra_funcs = {
            "get_grid_shape": dim3_,
            "get_workspace_size": ctypes.c_uint64
        }
        self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
        self.conv_kind = operation.conv_kind

        self.operation: Conv2dOperation = operation

        self.emitter = EmitConv2dInstance("_type")

        self.threads = operation.tile_description.num_threads

        self.swizzle_functor = operation.swizzling_functor

    def emit(self):
        """Emit the C++ source instantiating this operation."""
        return self.emitter.emit(self.operation)

    def plan(self, arguments: Conv2dArguments):
        """Compute the kernel launch configuration (grid, block, smem)."""
        tile_size = GemmCoord(
            self.operation.tile_description.threadblock_shape[0],
            self.operation.tile_description.threadblock_shape[1],
            self.operation.tile_description.threadblock_shape[2],
        )

        # get_grid_shape is one of the compiled host helpers (extra_funcs)
        grid = self.get_grid_shape(
            int(self.conv_kind),
            arguments.problem_size.ctype,
            tile_size.ctype,
            arguments.split_k_slices
        )

        return LaunchConfiguration(
            [grid.x, grid.y, grid.z], [self.threads, 1, 1],
            self.shared_memory_capacity)

    def initialize(self):
        """Raise the kernel's dynamic shared-memory limit to its requirement."""
        err, = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            value=self.shared_memory_capacity)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {err}")
class Conv2dOperation:
"""
CUTLASS Conv2d operation description.
:param conv_kind: convolution operator
:type conv_kind: :class:`cutlass_library.library.ConvKind`
:param iterator_algorithm: Selects among several implementation
variants trading off performance with simplicity
:type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm`
:param arch: GPU compute capability (sm_xx)
:type arch: int
:param tile_description: tile description
:type tile_description: :class:`cutlass_cppgen.backend.TileDescription`
:param A: tensor A description
:type A: :class:`cutlass_cppgen.backend.TensorDescription`
:param B: tensor B description
:type B: :class:`cutlass_cppgen.backend.TensorDescription`
:param C: tensor C description
:type C: :class:`cutlass_cppgen.backend.TensorDescription`
:param D: tensor D description
:type D: :class:`cutlass_cppgen.backend.TensorDescription`
:param element_epilogue: element type for computation in epilogue \
:type element_epilogue: cutlass_library.library.DataType
:param stride_support: distinguish among partial specializations that \
accelerate certain problems where convolution stride is unit \
:type stride_support: :class:`cutlass_library.library.StrideSupport`
:param epilogue_functor: convolution epilogue functor
:type epilogue_functor: :class:`EpilogueFunctor`
:param swizzling_functor: threadblock swizzling functor
"""
def __init__(
self,
conv_kind,
iterator_algorithm,
arch: int,
tile_description: TileDescription,
A: TensorDescription,
B: TensorDescription,
C: TensorDescription,
stride_support,
epilogue_functor,
swizzling_functor=SwizzlingFunctor.Identity1,
emission_type=EmissionType.Kernel,
**kwargs
):
self.operation_kind: OperationKind = OperationKind.Conv2d
self.arch: int = arch
self.tile_description: TileDescription = tile_description
self.conv_kind = conv_kind
self.A: TensorDescription = A
self.B: TensorDescription = B
self.C: TensorDescription = C
self.epilogue_functor = epilogue_functor
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.emission_type = emission_type
self.rt_module: Conv2dRT = Conv2dRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
"""
Launch the cuda kernel with input arguments
:param arguments: conv2d arguments
:type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments`
"""
# launch the kernel
err = self.rt_module.run(
arguments.host_workspace,
arguments.device_workspace,
arguments.launch_config,
arguments.stream
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"CUDA Error {err}")
return err
#
# Get function name
#
def procedural_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
return self.configuration_name()
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
opcode_class_name = OpcodeClassNames[
self.tile_description.math_instruction.opcode_class
]
threadblock = "%dx%d_%dx%d" % (
self.tile_description.threadblock_shape[0],
self.tile_description.threadblock_shape[1],
self.tile_description.threadblock_shape[2],
self.tile_description.stages,
)
if self.stride_support == StrideSupport.Unity:
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
else:
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
return SubstituteTemplate(
configuration_name,
{
"arch": str(self.arch),
"opcode_class": opcode_class_name,
"extended_name": self.extended_name(),
"threadblock": threadblock,
"layout": self.layout_name(),
"alignment": "%d" % self.A.alignment
},
)
def extended_name(self):
"""Append data types if they differ from compute type."""
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
else:
extended_name = "${core_name}"
extended_name = SubstituteTemplate(extended_name, {
"element_a": DataTypeNames[self.A.element],
"element_c": DataTypeNames[self.C.element],
"core_name": self.core_name(),
})
return extended_name
def layout_name(self):
return "%s" % (ShortLayoutTypeNames[self.A.layout])
def core_name(self):
    """The basic operation kind is prefixed with a letter indicating the accumulation type."""
    math_inst = self.tile_description.math_instruction
    intermediate_type = ""
    inst_shape = ""

    if math_inst.opcode_class == OpcodeClass.TensorOp:
        inst_shape = "%dx%dx%d" % tuple(math_inst.instruction_shape)
        # Record the instruction's input type when it differs from both the
        # operand type and the accumulator type.
        if math_inst.element_a != self.A.element and \
                math_inst.element_a != self.accumulator_type():
            intermediate_type = DataTypeNames[math_inst.element_a]

    return "%s%s%s%s_%s" % (
        ShortDataTypeNames[self.accumulator_type()],
        inst_shape,
        intermediate_type,
        ConvKindNames[self.conv_kind],
        IteratorAlgorithmNames[self.iterator_algorithm],
    )
def is_complex(self):
    """Whether the underlying math operation is a complex-valued multiply-add."""
    return self.tile_description.math_instruction.math_operation in (
        MathOperation.multiply_add_complex,
        MathOperation.multiply_add_complex_gaussian,
    )
def accumulator_type(self):
    """Accumulator data type, promoted to its complex counterpart for complex ops."""
    real_accum = self.tile_description.math_instruction.element_accumulator
    return get_complex_from_real(real_accum) if self.is_complex() else real_accum
def device_op(self):
    """
    Returns a new Conv2dOperation object that is constructed with emission type
    ``EmissionType.Device``.

    :return: operation ready for device-level code emission
    :rtype: Conv2dOperation
    """
    # Every other configuration field (operands, tile description, functors)
    # is carried over unchanged; only the emission type switches.
    return Conv2dOperation(
        self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description,
        self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor,
        emission_type=EmissionType.Device)
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################
class EmitConv2dInstance:
    """
    Emits C++ source instantiating a single CUTLASS 2.x Conv2d operator, either
    as a bare kernel type (``EmissionType.Kernel``) or wrapped in a
    device-level ``ImplicitGemmConvolution`` (``EmissionType.Device``).
    """

    def __init__(self, operation_suffix=""):
        # Optional suffix appended to the emitted struct's name
        self.operation_suffix = operation_suffix
        # Headers required by the emitted instantiation
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/conv/kernel/default_conv2d_fprop.h",
            "cutlass/conv/kernel/default_conv2d_dgrad.h",
            "cutlass/conv/kernel/default_conv2d_wgrad.h",
            "cutlass/conv/device/implicit_gemm_convolution.h"
        ]
        # Template used when emitting only the kernel-level type
        self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
  ${element_a},
  ${layout_a},
  ${element_b},
  ${layout_b},
  ${element_c},
  ${layout_c},
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  ${epilogue_functor},
  ${swizzling_functor},
  ${stages},
  ${math_operator},
  ${iterator_algorithm},
  ${stride_support},
  ${align_a},
  ${align_b}
>::Kernel;

struct ${operation_name}${operation_suffix}:
  public ${operation_name}_base { };
"""
        # Template additionally wrapping the kernel in the device-level
        # ImplicitGemmConvolution operator
        self.template_device = """
// Conv2d operation ${operation_name}

using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
  ${element_a},
  ${layout_a},
  ${element_b},
  ${layout_b},
  ${element_c},
  ${layout_c},
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  ${epilogue_functor},
  ${swizzling_functor},
  ${stages},
  ${math_operator},
  ${iterator_algorithm},
  ${stride_support},
  ${align_a},
  ${align_b}
>::Kernel;

using DeviceKernel =
typename cutlass::conv::device::ImplicitGemmConvolution<Conv2d${conv_kind_name}Kernel>;
"""

    def emit(self, operation):
        """
        Render the instantiation template for ``operation``.

        :param operation: Conv2dOperation describing the kernel to instantiate
        :return: C++ source string for the instantiation
        """
        # Warp shape is the threadblock shape divided elementwise by the warp count
        warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
                          operation.tile_description.warp_count[idx]) for idx in range(3)]

        # Widest 128-bit vectorized access, capped by C's alignment (in elements)
        epilogue_vector_length = int(min(
            operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])

        values = {
            "operation_name": operation.procedural_name(),
            "operation_suffix": self.operation_suffix,
            "conv_kind": ConvKindTag[operation.conv_kind],
            "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
            "element_a": DataTypeTag[operation.A.element],
            "layout_a": LayoutTag[operation.A.layout],
            "element_b": DataTypeTag[operation.B.element],
            "layout_b": LayoutTag[operation.B.layout],
            "element_c": DataTypeTag[operation.C.element],
            "layout_c": LayoutTag[operation.C.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
            "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
            "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
            "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
            "stride_support": StrideSupportTag[operation.stride_support],
            # Complex ops always use the complex multiply-add tag regardless of
            # the math_operation recorded in the tile description.
            "math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation],
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
        }

        # Choose the kernel-only or device-level template per the emission type
        if operation.emission_type == EmissionType.Kernel:
            conv2d_template = self.template
        else:
            conv2d_template = self.template_device

        return SubstituteTemplate(conv2d_template, values)

View File

@ -0,0 +1,541 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from cutlass_library import SubstituteTemplate
import numpy as np
from cutlass_library import DataType, DataTypeTag
from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory
from cutlass_cppgen.backend.frontend import NumpyFrontend
from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
# Mapping from CUTLASS data types to the ctypes type used to store a scalar of
# that type when packing epilogue arguments. Half-precision types (f16, bf16)
# have no native ctypes equivalent and are stored bit-for-bit in a 16-bit
# unsigned integer (see to_ctype_value below).
dtype2ctype = {
    DataType.f16: ctypes.c_uint16,
    DataType.bf16: ctypes.c_uint16,
    DataType.f32: ctypes.c_float,
    DataType.f64: ctypes.c_double,
    DataType.s8: ctypes.c_int8,
    DataType.s32: ctypes.c_int32
}

# Torch is an optional dependency: import it (and its functional API, used by
# the activation reference implementations below) only when available.
if is_torch_available():
    import torch
    import torch.nn.functional as F
def get_scalar(value):
    """
    Returns a scalar value from a container (e.g., np.ndarray).

    ``value`` may be a size-1 numpy array, a size-1 torch tensor, or a plain
    scalar (returned unchanged).

    :raises Exception: if a tensor containing more than one element is passed
    """
    if is_numpy_tensor(value):
        if value.size != 1:
            raise Exception("Scalars used in epilogue must be of size 1")
        return value.reshape(-1)[0]
    elif is_torch_tensor(value):
        # Bug fix: unlike numpy's ``size`` attribute, ``torch.Tensor.size`` is
        # a *method*, so ``value.size != 1`` compared a bound method against 1
        # and was always True -- size-1 torch tensors incorrectly raised here.
        if value.numel() != 1:
            raise Exception("Scalars used in epilogue must be of size 1")
        return value.reshape(-1)[0]
    else:
        return value
def to_ctype_value(value, dtype):
    """
    Converts ``value`` to the corresponding storage needed for the ctype that
    will store ``value``.
    """
    scalar = get_scalar(value)
    if dtype != DataType.f16:
        return scalar
    # f16 has no native ctypes type: reinterpret the half-precision bit
    # pattern as a little-endian 16-bit integer.
    return int.from_bytes(np.float16(scalar).tobytes(), "little")
#################################################################################################
#
# Epilogue Functors
#
#################################################################################################
class EpilogueFunctorBase:
    """
    Base class for thread-level epilogue functors
    """

    def __init__(self) -> None:
        pass

    def emit(self, tag, template_argument):
        """
        Render the C++ instantiation ``tag<arg0, arg1, ...>``.

        :param tag: fully qualified C++ class name of the functor
        :param template_argument: list of already-rendered template argument strings
        :return: the C++ type instantiation as a string
        """
        template = """${tag}<${arguments}>"""
        # Join with ", " instead of the original manual concatenation loop.
        arguments = ", ".join(template_argument)
        values = {
            "tag": tag,
            "arguments": arguments,
        }
        return SubstituteTemplate(template, values)
class LinearCombination(EpilogueFunctorBase):
    """
    Apply a linear combination operator to an array of elements
    D = alpha * accumulator + beta * source

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
    Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
    when there are not enough data to store

    :param element_accumulator: Accumulator data type

    :param element_epilogue: data type used to compute linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombination"

    def __init__(
        self, element_output, epilogue_vector_length,
        element_accumulator=None, element_epilogue=None) -> None:
        super().__init__()

        # Accumulator and epilogue-compute types default to the output type
        if element_accumulator is None:
            element_accumulator = element_output
        if element_epilogue is None:
            element_epilogue = element_output

        self.element_output = element_output
        self.element_accumulator = element_accumulator
        self.element_epilogue = element_epilogue
        self.epilogue_vector_length = epilogue_vector_length

        # C++ template arguments, in declaration order
        self.template_arguments = [
            DataTypeTag[element_output],
            str(epilogue_vector_length),
            DataTypeTag[element_accumulator],
            DataTypeTag[element_epilogue],
        ]

        c_element_epilogue = dtype2ctype[self.element_epilogue]
        # Local binding captured by the nested Structure __init__ closures
        element_epilogue = self.element_epilogue

        class _EpilogueOutputOpParamsEVT(ctypes.Structure):
            """
            Epilogue params when using the default linear combination of EVT, which
            does not currently use {alpha,beta}_ptr_array
            """
            # Stride type for the dalpha/dbeta fields
            stride_type = tuple_factory((0, 0, 1), "int64_t", [0])

            # NOTE(review): field order must mirror the C++ arguments struct;
            # do not reorder.
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("dalpha", stride_type),
                ("dbeta", stride_type),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                # Only alpha/beta are set explicitly; remaining fields keep
                # their ctypes zero defaults.
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        class _EpilogueOutputOpParams(ctypes.Structure):
            # Non-EVT variant: carries {alpha,beta}_ptr_array pointers instead
            # of dalpha/dbeta strides.
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("alpha_ptr_array", ctypes.c_void_p),
                ("beta_ptr_array", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

            def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
                # Convert to the EVT layout, preserving alpha/beta values
                return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)

        self.epilogue_type = _EpilogueOutputOpParams
        self.epilogue_type_evt = _EpilogueOutputOpParamsEVT

    def emit(self):
        """Render the C++ instantiation of this functor."""
        return super().emit(self.tag, self.template_arguments)
class LinearCombinationClamp(LinearCombination):
    """
    Applies a linear combination operator to an array of elements then clamps
    the output before converting to the output element type.

    D = alpha * accumulator + beta * source + uniform

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
    Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
    when there are not enough data to store

    :param element_accumulator: Accumulator data type

    :param element_epilogue: data type used to compute linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationClamp"

    def __init__(
        self, element_output, epilogue_vector_length,
        element_accumulator=None, element_epilogue=None) -> None:
        # Base constructor
        super().__init__(
            element_output,
            epilogue_vector_length,
            element_accumulator,
            element_epilogue,
        )

        c_element_epilogue = dtype2ctype[self.element_epilogue]
        # Local binding captured by the nested Structure __init__ closure
        element_epilogue = self.element_epilogue

        # Replaces the base-class epilogue params: the clamp variant has no
        # ptr_array / stride fields.
        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        self.epilogue_type = _EpilogueOutputOpParams
class FastLinearCombinationClamp(EpilogueFunctorBase):
    """
    Applies a linear combination operator to an array of elements then clamps
    the output before converting to the output element type.

    D = alpha * accumulator + beta * source

    Note: The below method only when problem_size_K <= 256 for signed int8 gemm
    or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
    above.

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
    Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
    when there are not enough data to store
    """

    tag = "cutlass::epilogue::thread::FastLinearCombinationClamp"

    def __init__(self, element_output, epilogue_vector_length, *args) -> None:
        super().__init__()

        self.template_arguments = [
            DataTypeTag[element_output], str(epilogue_vector_length)
        ]

        # This functor is fixed to s32 accumulation with f32 epilogue compute
        self.element_accumulator = DataType.s32
        self.element_epilogue = DataType.f32

        # get epilogue output op
        c_element_epilogue = dtype2ctype[self.element_epilogue]
        # Local binding captured by the nested Structure __init__ closure
        element_epilogue = self.element_epilogue

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        self.epilogue_type = _EpilogueOutputOpParams

    def emit(self):
        """Render the C++ instantiation of this functor."""
        return super().emit(self.tag, self.template_arguments)
class LinearCombinationGeneric(LinearCombination):
    """
    Applies a linear combination operator followed by an activation function
    to an array of elements.

    D = activation(alpha * accumulator + beta * source)

    :param activation_functor: input activation functor

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
    Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
    when there are not enough data to store

    :param element_accumulator: Accumulator data type

    :param element_epilogue: data type used to compute linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationGeneric"

    def __init__(
        self, activation_functor,
        element_output, epilogue_vector_length,
        element_accumulator=None, element_epilogue=None) -> None:
        super().__init__(
            element_output,
            epilogue_vector_length,
            element_accumulator,
            element_epilogue,
        )

        # The activation functor is the first C++ template argument
        self.template_arguments = [
            activation_functor.emit()] + self.template_arguments

        self.activation_functor = activation_functor

        # Bug fix: do NOT re-assign self.element_epilogue from the raw
        # constructor argument. The base class already resolved a None
        # element_epilogue to element_output; overwriting it with the raw
        # (possibly None) value made epilogue_output_op() fail with a
        # KeyError on dtype2ctype[None] whenever the default was used.

        # get epilogue output op
        self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue)
class ActivationFunctor:
    """
    Base class for frequently used activation functions
    """

    @staticmethod
    def numpy(x: np.ndarray):
        # Numpy reference implementation; subclasses' metaclasses provide it
        raise NotImplementedError()

    @classmethod
    def emit(cls):
        # Render the C++ tag for this activation (binding_type is set by
        # each concrete subclass)
        return ActivationOpTag[cls.binding_type]

    @staticmethod
    def epilogue_output_op(element_epilogue):
        """
        Build the ctypes argument structure for a plain linear-combination
        epilogue using this activation.

        :param element_epilogue: data type used to compute the linear combination
        :return: ctypes.Structure subclass holding alpha/beta parameters
        """
        c_element_epilogue = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            # NOTE(review): field order must mirror the C++ arguments struct;
            # do not reorder.
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        return _EpilogueOutputOpParams
class ActivationMeta(type):
    """
    Metaclass making activation classes directly callable: ``act(x, ...)``
    dispatches to the numpy or torch reference implementation according to
    the tensor type of ``x``.
    """

    @classmethod
    def __call__(cls, x, *args):
        if is_numpy_tensor(x):
            return cls.numpy(x, *args)
        if is_torch_tensor(x):
            return cls.torch(x, *args)
        raise NotImplementedError("Unsupported tensor type")

    @classmethod
    def numpy(cls, *args):
        # Overridden by concrete metaclasses; the class name is suffixed
        # with "Meta", hence the [:-4] slice.
        raise NotImplementedError(f"Numpy reference for {cls.__name__[:-4]} is not implemented.")

    @classmethod
    def torch(cls, *args):
        raise NotImplementedError(f"PyTorch reference for {cls.__name__[:-4]} is not implemented.")
##############################################################################
# identity operator
class identityMeta(ActivationMeta):
    """Reference implementations for the identity (pass-through) activation."""

    @classmethod
    def numpy(cls, x):
        return x

    @classmethod
    def torch(cls, x):
        return x


class identity(ActivationFunctor, metaclass=identityMeta):
    """Identity activation functor."""
    binding_type = ActivationOp.Identity
##############################################################################
# ReLu operator
class reluMeta(ActivationMeta):
    """Reference implementations for ReLU, max(x, 0)."""

    @classmethod
    def numpy(cls, x):
        # np.where keeps the original zero-for-NaN behavior intact
        return np.where(x > 0, x, 0)

    @classmethod
    def torch(cls, x):
        return F.relu(x)


class relu(ActivationFunctor, metaclass=reluMeta):
    """ReLU activation functor."""
    binding_type = ActivationOp.ReLU
##############################################################################
# Leaky ReLu operator
class leakyReLUMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x, leaky_alpha):
        # Positive part passes through; negative part is scaled by leaky_alpha
        return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha

    @classmethod
    def torch(cls, x, leaky_alpha):
        return F.leaky_relu(x, leaky_alpha)


class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta):
    binding_type = ActivationOp.LeakyReLU

    @staticmethod
    def epilogue_output_op(element_epilogue):
        """
        Build the ctypes argument structure for a linear-combination epilogue
        with leaky-ReLU activation; extends the base layout with the
        ``leaky_alpha`` slope parameter.
        """
        c_element_epilogue = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            # NOTE(review): field order must mirror the C++ arguments struct;
            # do not reorder.
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("leaky_alpha", c_element_epilogue)
            ]

            def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)
                # Host-side scalars are used directly; pointer fields are nulled
                self.alpha_ptr = 0
                self.beta_ptr = 0
                self.leaky_alpha = to_ctype_value(leaky_alpha, element_epilogue)

        return _EpilogueOutputOpParams
##############################################################################
# Tanh operator
class tanhMeta(ActivationMeta):
    """Reference implementations for the hyperbolic tangent activation."""

    @classmethod
    def numpy(cls, x):
        return np.tanh(x)

    @classmethod
    def torch(cls, x):
        return torch.tanh(x)


class tanh(ActivationFunctor, metaclass=tanhMeta):
    """Tanh activation functor."""
    binding_type = ActivationOp.Tanh
##############################################################################
# Sigmoid operator
class sigmoidMeta(ActivationMeta):
    """Reference implementations for the logistic sigmoid, 1 / (1 + e^-x)."""

    @classmethod
    def numpy(cls, x):
        exp_neg = np.exp(-x)
        return 1.0 / (1.0 + exp_neg)

    @classmethod
    def torch(cls, x):
        return F.sigmoid(x)


class sigmoid(ActivationFunctor, metaclass=sigmoidMeta):
    """Sigmoid activation functor."""
    binding_type = ActivationOp.Sigmoid
##############################################################################
# SiLu operator
class siluMeta(ActivationMeta):
    """Reference implementations for SiLU (swish), x * sigmoid(x)."""

    @classmethod
    def numpy(cls, x):
        # Bug fix: sigmoid must be evaluated at x (was called with no
        # arguments, which raised a TypeError).
        return x * sigmoidMeta.numpy(x)

    @classmethod
    def torch(cls, x):
        # Bug fix: ActivationMeta.__call__ dispatches to ``cls.torch``; this
        # method was misnamed ``silu``, so torch inputs hit the base-class
        # NotImplementedError instead.
        return F.silu(x)


class silu(ActivationFunctor, metaclass=siluMeta):
    binding_type = ActivationOp.SiLU
##############################################################################
# Hardswish operator
class hardswishMeta(ActivationMeta):
    """Reference implementations for hardswish, x * relu6(x + 3) / 6."""

    @classmethod
    def numpy(cls, x):
        relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0)
        return x * relu6 / 6.0

    @classmethod
    def torch(cls, x):
        return F.hardswish(x)


class hardswish(ActivationFunctor, metaclass=hardswishMeta):
    """Hardswish activation functor."""
    binding_type = ActivationOp.HardSwish
##############################################################################
# GELU operator
class geluMeta(ActivationMeta):
    """Reference implementations for the (exact, erf-based) GELU activation."""

    @classmethod
    def numpy(cls, x):
        # scipy is only needed for the numpy reference, so import lazily
        from scipy.special import erf
        return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))

    @classmethod
    def torch(cls, x):
        return F.gelu(x)


class gelu(ActivationFunctor, metaclass=geluMeta):
    """GELU activation functor."""
    binding_type = ActivationOp.Gelu

View File

@ -0,0 +1,34 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend

View File

@ -0,0 +1,38 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter
import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes
from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter
import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes

View File

@ -0,0 +1,159 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Epilogue Visitor Emitter
"""
from cutlass_library import DataTypeTag
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
class FusionCallbacks:
    """
    Emits the C++ fusion-callback type declarations for an epilogue visitor
    tree (EVT) held in a DAG IR.
    """

    def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
        """
        Emit the EVT fusion callbacks
        :param dag_ir: the DAG IR holding the epilogue visitor
        :param cc: compute capability
        :param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
            For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API
            so that their shared memory can be explicitly reused
            For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes.
        """
        self.dag_ir = dag_ir
        self.emit_CD = emit_CD
        self.cc = cc
        # Architectures newer than Sm90 reuse the Sm90 EVT type names
        self.evt_cc = 90 if cc >= 90 else cc

        # Pre-Sm90 EVT types live in the ``threadblock`` namespace; Sm90+
        # types live in ``fusion``.
        if self.cc < 90:
            self.namespace = "threadblock"
        else:
            self.namespace = "fusion"

    #
    # Helper functions
    #
    def get_visitor_name(self, node: str):
        """
        Get the visitor name
        """
        meta = self.dag_ir.get_node_meta(node)
        # Non-leaf, non-topological nodes are wrapped in an EVT connector type
        # (named "EVT<node>" by emit_evt); leaves and DAG nodes are referenced
        # by their own name.
        if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
            return f"EVT{meta.name_camel}"
        else:
            return meta.name_camel

    def emit(self):
        """
        Emit the callback type declarations.

        :return: (C++ type-declaration string, name of the top-level callback type)
        """
        node_metas = self.dag_ir.node_metas_topological_order()
        epilogue_str = ""

        # Step 1: emit individual node type decl
        # emit the EVT & DAG connector
        for meta in node_metas:
            if not meta.disabled:
                epilogue_str += self.emit_node(meta)
            # When C & D are handled by the collective API, node D gets no
            # connector of its own.
            if not self.emit_CD and meta.name == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                epilogue_str += self.emit_dag(meta)
            else:
                epilogue_str += self.emit_evt(meta)

        # Step 2: post-processing & get callback name
        if not self.emit_CD:
            if not self.dag_ir.has_node("C"):
                # Without a C node, declare ElementC void and alias StrideC
                epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
            output_node = self.dag_ir.get_all_inputs("D")[0]
            # The callback is the src of node D
            callback_name = self.get_visitor_name(output_node)
        else:
            # The callback is the last node in the topological order
            callback_name = self.get_visitor_name(node_metas[-1].name)
        return epilogue_str, callback_name

    def emit_evt(self, node):
        """Emit the Sm*EVT connector type wiring ``node`` to its inputs."""
        # Leaf nodes (no inputs) need no connector
        if self.dag_ir.in_degree(node.name) == 0:
            return ""
        evt_tmp = f"""
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
    {node.name_camel},
"""
        sorted_children = self.dag_ir.get_all_inputs(node.name)
        evt_node_strs = [f"    {self.get_visitor_name(child_name)}" for child_name in sorted_children]
        evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
        return evt_tmp

    def emit_dag(self, node):
        """Emit the Sm*TopologicalVisitor type for a TopoVisitorNode subgraph."""
        subgraph = node.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()

        # Emit the Edge Tuple: one cute::seq per node (excluding the output
        # node), listing the indices of its inputs, ordered by edge weight.
        edge_tuples = "cute::tuple<\n"
        for n in subgraph_nodes[:-1]:
            in_edges = subgraph.in_edges(n)
            edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
            sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
            edge_tuple = "    cute::seq<"
            edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
            edge_tuple += ", ".join(edge_str) + ">,\n"
            edge_tuples += edge_tuple
        edge_tuples += "    >"

        # Emit the node list (disabled nodes are referenced through their
        # visitor wrapper, enabled ones directly)
        dag_nodes = ""
        dag_node_strs = []
        for n in subgraph_nodes[:-1]:
            n_meta = subgraph.get_node_meta(n)
            if n_meta.disabled:
                dag_node_strs.append(f"    {self.get_visitor_name(n)}")
            else:
                dag_node_strs.append(f"    {n_meta.name_camel}")
        dag_nodes = ",\n".join(dag_node_strs)

        return f"""
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
    {DataTypeTag[node.subgraph.element_compute]},
    {edge_tuples},
{dag_nodes}
>;
"""

    def emit_node(self, node):
        """Emit the underlying type declaration(s) of ``node``."""
        if isinstance(node, TopoVisitorNode):
            # Recursively emit every enabled node inside the subgraph
            emission = ""
            for node in node.subgraph.node_metas_topological_order():
                if not node.disabled:
                    emission += self.emit_node(node)
            return emission
        else:
            return node.underlying_impl.type_decl

View File

@ -0,0 +1,116 @@
#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm100 Epilogue Visitor
"""
from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
from cutlass_cppgen.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend.evt.ir.node import TupleEmitter
class Sm100CollectiveEpilogue:
    """Collective epilogue emitter for SM100 (Blackwell).

    Gathers tile/type information from the operation description and the
    fusion-callback DAG, and emits the Sm100EpilogueDescriptor together with
    the fusion callback declarations.
    """

    def __init__(self, tile_description,
                 kernel_schedule,
                 epilogue_schedule,
                 element_accumulator,
                 element_d,
                 fusion_callbacks) -> None:
        """Record the shapes, element types, and schedule used at emit time."""
        self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(
            tile_description, tile_description.cluster_shape, kernel_schedule)
        self.element_accumulator = element_accumulator
        dag_ir = fusion_callbacks.dag_ir
        # C is optional in the DAG; DataType.void disables the source operand.
        self.element_c = dag_ir.get_node_meta("C").element if dag_ir.has_node("C") else DataType.void
        self.element_d = element_d
        self.schedule = epilogue_schedule
        self.fusion_callbacks = fusion_callbacks
        self.opclass = tile_description.math_instruction.opcode_class

    @property
    def CtaTileMNK(self) -> str:
        """C++ cute::Shape spelling of the CTA tile (M, N, K)."""
        m, n, k = self.cta_tile_mnk[0], self.cta_tile_mnk[1], self.cta_tile_mnk[2]
        return f"cute::Shape<_{m}, _{n}, _{k}>"

    @property
    def EpilogueTileType(self) -> str:
        """Let the collective choose the epilogue tile automatically."""
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """C++ tag for the configured epilogue schedule."""
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """Return (callback_name, C++ declarations for descriptor + callbacks)."""
        tuple_emitter = TupleEmitter("int64_t")  # NOTE(review): appears unused — confirm before removing
        dag_ir = self.fusion_callbacks.dag_ir
        stride_D_str = dag_ir.get_node_meta("D").underlying_impl.stride_mnl
        # Without a C operand, reuse D's stride type as a placeholder.
        if dag_ir.has_node("C"):
            stride_C_str = dag_ir.get_node_meta("C").underlying_impl.stride_mnl
        else:
            stride_C_str = stride_D_str
        callback_decl, callback_name = self.fusion_callbacks.emit()
        descriptor_decl = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
    {OpcodeClassTag[self.opclass]},
    {self.CtaTileMNK}, {self.EpilogueTileType},
    {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
    {self.Schedule}, {stride_C_str}, {stride_D_str},
    false /* IsPerColScaleSupported */,
    false /* IsBlockScaleSupported */
>;

{callback_decl}
"""
        return callback_name, descriptor_decl
class Sm100Emitter:
    """Top-level emitter for SM100 epilogue visitor trees."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """Build the fusion callbacks for *graph* and wire up the collective epilogue."""
        callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
        tile = operation.tile_description
        self.collective_epilogue = Sm100CollectiveEpilogue(
            tile_description=tile,
            kernel_schedule=tile.kernel_schedule,
            epilogue_schedule=tile.epilogue_schedule,
            element_accumulator=tile.math_instruction.element_accumulator,
            element_d=callbacks.dag_ir.get_node_meta("D").element,
            fusion_callbacks=callbacks,
        )

    def emit(self):
        """Return (callback_name, declaration string) for the epilogue."""
        return self.collective_epilogue.emit()

View File

@ -0,0 +1,134 @@
#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from pycute import product
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
from cutlass_cppgen.backend.library import FloatRoundStyleTag
# The SM100 (Blackwell) epilogue visitor reuses the SM90 (Hopper) node
# implementations directly; only the aux load/store nodes defined below
# need SM100-specific descriptors.
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
class Sm100AuxLoadImpl(AuxLoadImpl):
    """SM100 auxiliary-tensor load: an Sm100AuxLoadDescriptor feeding an Sm90AuxLoad node."""

    @property
    def descriptor(self) -> str:
        """Name of the generated C++ descriptor type for this node."""
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """Return the C++ declaration of the aux-load descriptor type."""
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """Lazily build and cache the full C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = self.decl_descriptor() + f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """Return the (byte count, 128) pair for this node's smem buffer.

        Sized by the C-side stage count; the second element is presumably
        the required alignment — confirm against the collective.
        """
        return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
class Sm100AuxStoreImpl(AuxStoreImpl):
    """SM100 auxiliary-tensor store: an Sm100AuxStoreDescriptor feeding an Sm90AuxStore node."""

    @property
    def descriptor(self) -> str:
        """Name of the generated C++ descriptor type for this node."""
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """Return the C++ declaration of the aux-store descriptor type."""
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """Lazily build and cache the full C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = self.decl_descriptor() + f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """Return the (byte count, 128) pair for this node's smem buffer.

        Sized by the D-side stage count; the second element is presumably
        the required alignment — confirm against the collective.
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)

View File

@ -0,0 +1,47 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm80 Epilogue Visitor
"""
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend import GemmOperationUniversal
class Sm80Emitter:
    """Top-level emitter for SM80 epilogue visitor trees."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """Build fusion callbacks for *graph*; SM80 emits C/D inside the callbacks."""
        self.fusion_callbacks = FusionCallbacks(graph, cc=80)

    def emit(self):
        """Return (callback_name, declaration string) for the epilogue."""
        decl, name = self.fusion_callbacks.emit()
        return name, decl

View File

@ -0,0 +1,258 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)
from cutlass_cppgen.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
class Sm80AccumulatorImpl(AccumulatorImpl):
    """SM80 accumulator fetch (VisitorAccFetch) — no operands, no configuration."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
        return self._type_decl
class Sm80AuxLoadImpl(AuxLoadImpl):
    """SM80 auxiliary-tensor load (VisitorAuxLoad)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
    # Loading the source (C) tensor uses the same VisitorAuxLoad codegen as
    # an auxiliary load on SM80, so no overrides are needed.
    pass
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
    """SM80 scalar broadcast (VisitorScalarBroadcast)."""

    def __init__(self, node: LoadNode) -> None:
        """Default to a single scalar combined multiplicatively when count > 1."""
        super().__init__(node)
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl
class Sm80RowBroadcastImpl(RowBroadcastImpl):
    """SM80 row-vector broadcast (VisitorRowBroadcast)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
    """SM80 column-vector broadcast (VisitorColBroadcast)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm80ComputeImpl(ComputeImpl):
    """SM80 elementwise compute node (VisitorCompute)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl
class Sm80AuxStoreImpl(AuxStoreImpl):
    """SM80 auxiliary-tensor store (VisitorAuxStore)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm80StoreDImpl(Sm80AuxStoreImpl):
    # Storing the output (D) tensor uses the same VisitorAuxStore codegen as
    # an auxiliary store on SM80, so no overrides are needed.
    pass
class Sm80ColumnReductionImpl(ColumnReductionImpl):
    """SM80 reduction along columns (VisitorColReduction)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm80RowReductionImpl(RowReductionImpl):
    """SM80 reduction along rows (VisitorRowReduction)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm80ScalarReductionImpl(ScalarReductionImpl):
    """SM80 reduction to a scalar (VisitorScalarReduction)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl

View File

@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm90 Epilogue Visitor
"""
from cutlass_library import DataTypeTag, EpilogueScheduleTag
from cutlass_cppgen.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
class CollectiveEpilogue:
    """Collective epilogue emitter for SM90 (Hopper).

    Emits the EpilogueDescriptor and the fusion callback declarations for a
    visitor-based epilogue.
    """

    def __init__(self, tile_description,
                 schedule,
                 element_c,
                 element_d,
                 fusion_callbacks) -> None:
        """Record the CTA tile, element types, and schedule used at emit time."""
        self.cta_tile_mnk = tile_description.threadblock_shape
        self.element_c = element_c
        self.element_d = element_d
        self.schedule = schedule
        self.fusion_callbacks = fusion_callbacks

    @property
    def CtaTileMNK(self) -> str:
        """C++ cute::Shape spelling of the CTA tile (M, N, K)."""
        m, n, k = self.cta_tile_mnk[0], self.cta_tile_mnk[1], self.cta_tile_mnk[2]
        return f"cute::Shape<_{m}, _{n}, _{k}>"

    @property
    def EpilogueTileType(self) -> str:
        """Let the collective choose the epilogue tile automatically."""
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """C++ tag for the configured epilogue schedule."""
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """Return (callback_name, C++ declarations for descriptor + callbacks)."""
        callback_decl, callback_name = self.fusion_callbacks.emit()
        descriptor_decl = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    {self.CtaTileMNK}, {self.EpilogueTileType},
    {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
    {self.Schedule}
>;

{callback_decl}
"""
        return callback_name, descriptor_decl
class Sm90Emitter:
    """Top-level emitter for SM90 epilogue visitor trees."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """Build the fusion callbacks for *graph* and wire up the collective epilogue."""
        callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
        self.collective_epilogue = CollectiveEpilogue(
            tile_description=operation.tile_description,
            schedule=operation.tile_description.epilogue_schedule,
            element_c=operation.C.element,
            # NOTE(review): D reuses C's element type here — confirm intended.
            element_d=operation.C.element,
            fusion_callbacks=callbacks,
        )

    def emit(self):
        """Return (callback_name, declaration string) for the epilogue."""
        return self.collective_epilogue.emit()

View File

@ -0,0 +1,329 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from pycute import product
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
ComputeNode,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl,
StoreNode,
StoreDImpl,
)
from cutlass_cppgen.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
class Sm90AccumulatorImpl(AccumulatorImpl):
    """SM90 accumulator fetch (Sm90AccFetch) — no operands, no configuration."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
        return self._type_decl
class Sm90LoadSrcImpl(LoadSrcImpl):
    """SM90 source (C) operand: declares ElementC/StrideC plus an Sm90SrcFetch node."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
        return self._type_decl
class Sm90AuxLoadImpl(AuxLoadImpl):
    """SM90 auxiliary-tensor load: an AuxLoadDescriptor feeding an Sm90AuxLoad node."""

    @property
    def descriptor(self) -> str:
        """Name of the generated C++ descriptor type for this node."""
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """Return the C++ declaration of the aux-load descriptor type."""
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """Lazily build and cache the full C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = self.decl_descriptor() + f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """Return the (byte count, 128) pair for this node's smem buffer.

        Sized by the C-side stage count; the second element is presumably
        the required alignment — confirm against the collective.
        """
        return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
    """SM90 scalar broadcast (Sm90ScalarBroadcast)."""

    def __init__(self, node: LoadNode) -> None:
        """Default to a single scalar combined multiplicatively when count > 1."""
        super().__init__(node)
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl
class Sm90RowBroadcastImpl(RowBroadcastImpl):
    """SM90 row-vector broadcast (Sm90RowBroadcast)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
    """SM90 column-vector broadcast (Sm90ColBroadcast)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm90ComputeImpl(ComputeImpl):
    """SM90 elementwise compute node (Sm90Compute)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl
class Sm90AuxStoreImpl(AuxStoreImpl):
    """SM90 auxiliary-tensor store: an AuxStoreDescriptor feeding an Sm90AuxStore node."""

    @property
    def descriptor(self) -> str:
        """Name of the generated C++ descriptor type for this node."""
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """Return the C++ declaration of the aux-store descriptor type."""
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """Lazily build and cache the full C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = self.decl_descriptor() + f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """Return the (byte count, 128) pair for this node's smem buffer.

        Sized by the D-side stage count; the second element is presumably
        the required alignment — confirm against the collective.
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
class Sm90StoreDImpl(StoreDImpl):
    @property
    def type_decl(self):
        """
        Return the string defining the type.

        Unlike the other SM90 nodes, D is written by the collective epilogue
        itself, so only the ElementD/StrideD aliases are declared and the
        result is not cached in ``self._type_decl``.
        """
        return f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""
class Sm90ColumnReductionImpl(ColumnReductionImpl):
    """SM90 reduction along columns (Sm90ColReduction)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration.

        The register reduce op is passed twice — presumably for both the
        register- and shuffle-level reductions; confirm against Sm90ColReduction.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm90RowReductionImpl(RowReductionImpl):
    """SM90 reduction along rows (Sm90RowReduction)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration.

        The register reduce op is passed twice — presumably for both the
        register- and shuffle-level reductions; confirm against Sm90RowReduction.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
class Sm90ScalarReductionImpl(ScalarReductionImpl):
    """SM90 reduction to a scalar (Sm90ScalarReduction)."""

    @property
    def type_decl(self):
        """Lazily build and cache the C++ type declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
        return self._type_decl

View File

@ -0,0 +1,168 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
"""
import ctypes
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import DataType
import numpy as np
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
import cutlass_cppgen.backend.evt.backend
from cutlass_cppgen.backend.frontend import TensorFrontend
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.backend.evt.passes.util import cc_map
class EpilogueFunctorVisitor(EpilogueFunctorBase):
    """
    Apply an epilogue functor described by the epilogue EVT (Epilogue Visitor Tree).

    Wraps an already-traced visitor frontend so it can serve as the epilogue of
    a compiled kernel: selects the C++ emitter matching the compute capability,
    builds the ctypes argument structure the generated kernel expects, and
    forwards shared-memory-size queries to the visitor.

    :param cc: compute capability
    :param visitor: user-provided visitor frontend (already traced; exposes
        ``dag_ir``, ``epilogue_thread_type``, ``return_names``, ``reduction_names``)
    :param element_compute: data type used for internal epilogue computation
    """
    def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
        # Type of Emitter based on CC (e.g. Sm80Emitter / Sm90Emitter)
        self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
        # Visitor Types
        self.visitor = visitor
        self.graph = visitor.dag_ir
        # Data types
        self.element_epilogue = element_compute # element compute
        # Output element type comes from the impl selected for store node "D"
        self.element_output = self.graph.get_node_meta('D').underlying_impl.element
        # Epilogue Thread Type (ctypes structure produced by tracing)
        epilogue_thread_type = self.visitor.epilogue_thread_type
        if cc_map[cc] in [90, 100]:
            # SM90+/SM100 kernels carry the C/D argument types separately
            self.arg_c_type = self.visitor.arg_c_type
            self.arg_d_type = self.visitor.arg_d_type
        # Captured by the closures of the nested _Arguments class below
        output_names = self.visitor.return_names
        reduction_names = self.visitor.reduction_names
        # Epilogue stages specialized for sm80 kernel
        if cc == 80:
            if hasattr(self.visitor, "epilogue_stages"):
                self.epilogue_stages = self.visitor.epilogue_stages
                assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"
        # Epilogue Argument Type
        class _Arguments(ctypes.Structure):
            """
            ctypes structure filled from the user-facing kwargs dict.

            Concepts:
            class _EpilogueArguments(ctypes.Structure):
                _fields_ = [
                    ("epilogue", _Arguments), <- this class
                    ("ptr_C", ctypes.c_void_p),
                    ("stride_C", StrideBatched_),
                    ("ptr_D", ctypes.c_void_p),
                    ("stride_D", StrideBatched_)
                ]
            """
            _fields_ = [
                ("output_op", epilogue_thread_type)
            ]
            def __init__(self, kwargs: dict) -> None:
                # The user-input kwargs is a dict of (name: tensors)
                # We first convert all of them to device pointers
                ptr_kwargs = {}
                for key in kwargs.keys():
                    # Pure outputs (not also reductions) get a writable buffer
                    is_output = key in output_names and key not in reduction_names
                    ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
                # Initialize the thread arguments
                self.output_op = epilogue_thread_type(ptr_kwargs)
            def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
                """
                Helper function for extracting device pointer
                """
                # Skip the special tensors
                # NOTE(review): this checks the raw `cc`, while __init__ above
                # checks cc_map[cc] — confirm the asymmetry is intentional.
                if cc in [90, 100]:
                    if tensor_name in ["C", "D"]:
                        # C/D are passed through dedicated kernel arguments
                        return 0
                if tensor_name not in kwargs.keys():
                    raise ValueError(f"Tensor {tensor_name} is not provided.")
                tensor = kwargs[tensor_name]
                # For float scalar constant, directly return the value
                if isinstance(tensor, float):
                    return tensor
                # The tensor frontend returns a device buffer for np.ndarray
                # and device ptr for other frontends
                buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
                if is_numpy_tensor(tensor):
                    # Remember the host tensor for later synchronization
                    setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
                    setattr(self, f"{tensor_name}_host", tensor)
                    return int(buffer_or_ptr.ptr)
                else:
                    return int(buffer_or_ptr)
            def sync(self):
                """
                Synchronize the results from device to host
                """
                # Only tensors registered via get_tensor_ptr (numpy inputs)
                # have a *_host/*_buffer pair to copy back
                for name in output_names:
                    if hasattr(self, f"{name}_host"):
                        host_tensor = getattr(self, f"{name}_host")
                        tensor_ptr = getattr(self, f"{name}_buffer").ptr
                        (err,) = cuda.cuMemcpyDtoH(
                            host_tensor,
                            tensor_ptr,
                            host_tensor.size * host_tensor.itemsize,
                        )
                        if err != cuda.CUresult.CUDA_SUCCESS:
                            raise RuntimeError("CUDA Error %s" % str(err))
        # Exposed to the kernel launcher as the epilogue argument type
        self.epilogue_type = _Arguments
    def emit(self, operation):
        """
        Emit the C++ code for the epilogue of `operation`.
        """
        emitter = self.emit_cls(operation, self.graph)
        return emitter.emit()
    def get_smem_size(self, tile_description):
        """
        Get the shared memory size in bytes
        """
        return self.visitor.get_smem_size(tile_description)

View File

@ -0,0 +1,33 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend

View File

@ -0,0 +1,272 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Python EVT Frontend
"""
from typing import Union
from cutlass_library import DataType
from cutlass_cppgen.backend.evt.ir import (
ComputeNode,
DAGIR,
LayoutNode,
LoadNode,
StoreNode,
)
from cutlass_cppgen.backend.evt.passes import (
EVTGraphDrawer,
EVTPassManager,
GetSmemSize,
PassDAG2Tree,
PassGetArgumentType,
PassGetImpl,
PassFixElementD,
PassLayoutManipulateElimination,
PassPreprocessRed,
PassShapeTypePropagation,
)
from cutlass_cppgen.backend.evt.passes.util import cc_map
from cutlass_cppgen.backend.utils import device_cc
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
from cutlass_cppgen.utils.datatypes import library_type
class EVTFrontendBase:
    """
    Base class of Python EVT (Epilogue Visitor Tree) frontends.

    A frontend parses a user-provided epilogue description into the DAG IR
    (``self.dag_ir``), runs the compilation pass pipeline over it, and exposes
    the resulting epilogue thread/argument types for kernel compilation.
    """
    # Layout-manipulation ops recognized by frontends (name -> op)
    layout_fns = {
        "permute": permute,
        "reshape": reshape
    }

    def __init__(self, cc, element_compute=DataType.f32, additional_passes=None, **kwargs) -> None:
        """
        :param cc: compute capability
        :param element_compute: data type used for internal epilogue computation
        :param additional_passes: optional extra passes appended after the
            default pass pipeline (default: none)
        """
        self.cc = cc
        self.element_compute = library_type(element_compute)
        self.dag_ir = DAGIR(self.cc, self.element_compute)
        # Counters used to generate unique auto-assigned node names
        self.compute_cnt = 0
        self.layout_cnt = 0
        self.imm_cnt = 0
        # Default to None (not []) to avoid a shared mutable default argument
        if additional_passes is None:
            additional_passes = []
        self.pass_manager = EVTPassManager(
            self.dag_ir,
            [
                PassPreprocessRed,
                PassGetArgumentType,
                PassShapeTypePropagation,
                PassLayoutManipulateElimination,
                PassGetImpl,
                PassDAG2Tree,
                PassFixElementD
            ] + additional_passes)
        # SM80 epilogues are staged; other targets leave the stage count unset
        if self.cc == 80:
            self._epilogue_stages = 1
        else:
            self._epilogue_stages = None

    @property
    def epilogue_stages(self):
        # Number of epilogue stages (SM80 only; None otherwise)
        return self._epilogue_stages

    @epilogue_stages.setter
    def epilogue_stages(self, stages):
        self._epilogue_stages = stages

    def parse(self, *args, **kwargs):
        """
        Parse the user input into the DAG IR. Must be overloaded by subclasses.
        """
        raise NotImplementedError("The 'parse' function must be overloaded in frontend class")

    def trace(self, *args, **kwargs):
        """
        Parse the input, verify the resulting DAG IR, and run the pass pipeline.
        """
        # Parse the input
        self.parse(*args, **kwargs)
        # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
        if (self.cc >= 90):
            if (self.dag_ir.out_degree("D") != 0):
                raise RuntimeError(
                    f"On SM90 or higher, D is expected to be a output node with 0 users to "
                    f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
        # Run the passes
        self.pass_manager()
        # Set the epilogue type
        self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
        if cc_map[self.cc] in [90, 100]:
            self.arg_c_type = self.dag_ir.arg_c_type
            self.arg_d_type = self.dag_ir.arg_d_type
        self.reduction_names = self.dag_ir.reduction_names

    #
    # Helper functions for DAG IR manipulation
    #
    def add_node(self, node):
        """Add a node to the DAG IR."""
        self.dag_ir.add_node(node)

    def add_edge(self, src, tgt, weight=0):
        """Add an edge src -> tgt with the given weight to the DAG IR."""
        self.dag_ir.add_edge(src, tgt, weight=weight)

    def set_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.tensor = {"tensor": example}

    def set_store_tensor(self, node_name, example):
        """
        Add an example store tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.store_tensor = {"tensor": example}

    def mark_output(self, node_name):
        """
        Mark a store node as output
        :raises ValueError: if the node is not a StoreNode
        """
        meta = self.dag_ir.get_node_meta(node_name)
        if not isinstance(meta, StoreNode):
            raise ValueError(
                f"Only StoreNodes can be marked as output. "
                f"Got {type(meta).__name__}: {node_name}")
        meta.is_output = True

    # Add node with specific type
    def add_load_node(self, name, example):
        """
        Add a Load node to DAG IR
        :param name: name of the loaded variable
        :type name: str
        :param example: example input
        :type example: np.ndarray|torch.Tensor|cupy.ndarray|float
        :raises ValueError: if name or example is missing, or the accumulator
            example has an unsupported rank
        """
        if name is None:
            raise ValueError("Name is not provided.")
        if example is None:
            raise ValueError(f"Example input for {name} is not provided.")
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": example}
        # Special logics for accumulator
        if name == "accum":
            if load_node.tensor.rank == 2:
                # Prepend a batch dimension of size 1
                new_shape = tuple([1, ] + list(load_node.tensor.shape))
                load_node.tensor.broadcast(new_shape)
            elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
                raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
        self.add_node(load_node)

    def add_imm(self, value: Union[float, int]):
        """
        Add an immediate scalar value to DAG IR
        :param value: the value of the immediate scalar
        :type value: float
        :return: the auto-generated name of the immediate node
        :raises ValueError: if the value cannot be converted to float
        """
        try:
            value = float(value)
        except (TypeError, ValueError) as e:
            # Narrowed from a bare `except:` and chains the original cause
            raise ValueError(f"{type(value).__name__} cannot be converted to float.") from e
        # '.' is not a valid identifier character, so replace it
        name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
        self.imm_cnt += 1
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)
        return name

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.
        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.
        :param op: the layout op
        :type op: evt_ops
        :param kwargs: keyword arguments of the layout op (e.g. permute indices)
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        """Add a Store node with the given name to the DAG IR."""
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualization of the DAG IR
    #
    def visualize(self, name="dag_ir"):
        """
        Visualize the dag ir with svg file
        :param name: the name of the graph
        :raises RuntimeError: if graphviz 'dot' is unavailable
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        try:
            for name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{name}.svg")
        except Exception as e:
            # Narrowed from a bare `except:`; keep the original cause chained
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            ) from e

    #
    # Get shared memory size
    #
    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue in bytes
        """
        smem_size = GetSmemSize(self.dag_ir)(tile_description)
        return smem_size

View File

@ -0,0 +1,194 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python AST frontend that parses input into DAG IR
"""
import ast
import inspect
import textwrap
from cutlass_library import DataType
import cutlass_cppgen
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass_cppgen.backend.library import FunctionalOp
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
    """
    EVT frontend that builds the DAG IR by walking the Python AST of the
    user-defined ``__call__`` epilogue function.
    """
    def __init__(self, cc, element_compute=DataType.f32, **kwargs):
        super().__init__(cc, element_compute, **kwargs)
        # Flags
        # If this state is True, visit_Constant returns values without creating imm node
        self.no_imm = False
        # True only while visiting a `return` statement: the returned value
        # must be a plain name or tuple of names, never an expression
        self.visiting_return = False

    def parse(self, example_inputs):
        """
        Parse the source of ``self.__call__`` into the DAG IR.
        :param example_inputs: dict mapping argument/output names to example tensors
        """
        self.example_inputs = example_inputs
        self.source = textwrap.dedent(inspect.getsource(self.__call__))
        self.ast = ast.parse(self.source)
        self.visit(self.ast)

    #
    # Helper functions
    #
    @staticmethod
    def ast_op_to_bindings(op):
        """
        Map an AST operator type or a called function name to its EVT binding.
        :raises KeyError: for unsupported ops
        """
        mapping = {
            ast.Add: FunctionalOp.Plus,
            ast.Sub: FunctionalOp.Minus,
            ast.Mult: FunctionalOp.Multiplies,
            ast.Div: FunctionalOp.Divides,
            "maximum": FunctionalOp.Maximum,
            "minimum": FunctionalOp.Minimum,
            "identity": identity.binding_type,
            "relu": relu.binding_type,
            "tanh": tanh.binding_type,
            "sigmoid": sigmoid.binding_type,
            "silu": silu.binding_type,
            "hardswish": hardswish.binding_type,
            "gelu": gelu.binding_type,
            "multiply_add": FunctionalOp.MultiplyAdd,
            "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
            "exp": FunctionalOp.Exp
        }
        return mapping[op]

    #
    # Visiting different node types
    #
    def visit_FunctionDef(self, node: ast.FunctionDef):
        # Visit args and register load nodes
        for arg in node.args.args:
            self.visit(arg)
        for expr in node.body:
            self.visit(expr)

    def visit_arg(self, node: ast.arg):
        # Name of the argument
        name = node.arg
        try:
            example_tensor = self.example_inputs[name]
        except KeyError as e:
            # Narrowed from a bare `except:`; keep the original cause chained
            raise RuntimeError(f"Example input for {name} is not provided.") from e
        self.add_load_node(name, example_tensor)

    def visit_Name(self, node: ast.Name):
        return node.id

    def visit_Constant(self, node: ast.Constant):
        if self.no_imm:
            # Inside layout-fn keyword args: return the raw value instead of
            # registering an immediate load node
            return node.value
        else:
            name = self.add_imm(node.value)
            return name

    def visit_Tuple(self, node: ast.Tuple):
        results = []
        for elt in node.elts:
            results.append(self.visit(elt))
        return tuple(results)

    def visit_keyword(self, node: ast.keyword):
        return {node.arg: self.visit(node.value)}

    def visit_BinOp(self, node: ast.BinOp):
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = self.ast_op_to_bindings(type(node.op))
        name = self.add_compute_node(op)
        # Add edges
        # The edge weights are used to sort the input args
        self.add_edge(lhs, name, weight=0)
        self.add_edge(rhs, name, weight=1)
        return name

    def visit_Assign(self, node: ast.Assign):
        # (annotation fixed: this visitor receives ast.Assign, not ast.BinOp)
        target = self.visit(node.targets[0])
        value = self.visit(node.value)
        # Create the assign node
        self.add_store_node(target)
        # Add edges
        self.add_edge(value, target)
        return target

    def visit_Call(self, node: ast.Call):
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        func = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]
        if func in self.layout_fns.keys():
            # Parse kwargs
            # By default, visiting imm automatically creates a load node
            # However, in function call, keyword args are used to set
            # specific function attributes such as indices for permute
            # So no_imm is set to True temporarily
            self.no_imm = True
            try:
                kwargs = {}
                for kw in node.keywords:
                    kwargs.update(self.visit(kw))
            finally:
                # Always restore the flag, even if a keyword visit raises
                self.no_imm = False
            op = self.layout_fns[func]
            name = self.add_layout_node(op, kwargs)
        else:
            op = self.ast_op_to_bindings(func)
            name = self.add_compute_node(op)
        # Add edges
        for idx, arg in enumerate(args):
            self.add_edge(arg, name, weight=idx)
        return name

    def visit_Return(self, node: ast.Return):
        self.visiting_return = True
        results = self.visit(node.value)
        self.visiting_return = False
        self.return_names = results
        if not isinstance(results, tuple):
            results = (results,)
        for rst in results:
            try:
                example_tensor = self.example_inputs[rst]
            except KeyError as e:
                # Narrowed from a bare `except:`; keep the original cause chained
                raise RuntimeError(f"Example input for {rst} is not provided.") from e
            self.set_store_tensor(rst, example_tensor)
            self.mark_output(rst)

View File

@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
from cutlass_cppgen.backend.evt.ir.load_nodes import (
LoadNode,
AccumulatorImpl,
LoadSrcImpl,
AuxLoadImpl,
RowBroadcastImpl,
ColumnBroadcastImpl,
ScalarBroadcastImpl
)
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass_cppgen.backend.evt.ir.store_nodes import (
StoreNode,
StoreDImpl,
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)

View File

@ -0,0 +1,91 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python registration for compute nodes in EVT
"""
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
from cutlass_cppgen.backend.library import FloatRoundStyle
class ComputeImplBase(ImplBase):
    """
    Base class for compute implementations.

    The previous version defined an ``__init__`` that only forwarded to
    ``super().__init__(node)``; that redundant override is removed — plain
    inheritance yields identical construction behavior.
    """
class ComputeImpl(ComputeImplBase):
    """
    Concrete implementation backing a compute node.

    Mirrors the compute-related attributes of the node so emitters can read
    them directly from the impl.
    """
    # Attributes copied verbatim from the node at construction time
    _mirrored_attrs = ("fn", "element_output", "element_compute", "round_style")

    def __init__(self, node) -> None:
        super().__init__(node)
        for attr in self._mirrored_attrs:
            setattr(self, attr, getattr(node, attr))

    @staticmethod
    def match(node, problem_size: tuple):
        # A plain compute implementation is valid for any compute node
        return True
class ComputeNode(NodeBase):
    """
    Compute Node in DAG IR
    """
    # Candidate implementations tried during impl selection
    possible_impls = [
        ComputeImpl
    ]
    def __init__(
        self, name: str, fn, element_output,
        element_compute,
        round_style=FloatRoundStyle.ToNearest) -> None:
        # NOTE(review): `element_output` is accepted but not stored here; it is
        # instead derived in type_propagation (and may be overwritten by later
        # passes). Confirm the parameter is intentionally unused.
        super().__init__(name)
        self.op = "compute"
        self.fn = fn
        self.element_compute = element_compute
        self.round_style = round_style
    def type_propagation(self, *args, **kwargs):
        """
        Propagate element types: a compute node operates in `element_compute`,
        which also becomes its output type unless a prior pass has already set
        `element_output`.
        """
        self.element = self.element_compute
        # In general, the compute nodes have element_output = element_compute
        # In certain cases like producer of D it is overwritten by other passes
        if not hasattr(self, "element_output"):
            self.element_output = self.element

View File

@ -0,0 +1,254 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
DAG IR used by Python EVT
"""
import networkx as nx
from cutlass_library import DataType
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.library import ActivationOp
from cutlass_cppgen.backend.utils import device_cc
class DAGIR:
"""
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
It consists of a series of ``Node`` s, each representing epilogue visitor nodes.
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
"""
def __init__(self, cc, element_compute=DataType.f32) -> None:
# The EVT DAGIR is managed through the nextworkX Digraph class
self._graph = nx.DiGraph()
self.element_compute = element_compute
self.reduction_names = []
self.cc = cc
self.identity_counter = 0
#
# IR manipulator
#
def add_node(self, meta: NodeBase):
"""
Add a node to dag ir
"""
if self.has_node(meta.name):
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
self._graph.add_node(meta.name, meta=meta)
def add_edge(self, src: str, dst: str, weight: int=0):
"""
Add an edge src -> dst to dag ir with weight
"""
if not self.has_node(src):
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
if self._graph.has_edge(src, dst):
# The DiGraph doesn't support multiple edges between two nodes
# We insert an identity node in such case as a workaround
identity_name = f"autogen_identity_{self.identity_counter}"
self.identity_counter += 1
compute_node = ComputeNode(
name=identity_name, fn=ActivationOp.Identity,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
self.add_edge(src, identity_name, 0)
self.add_edge(identity_name, dst, weight)
else:
self._graph.add_edge(src, dst, weight=weight)
def remove_node(self, node: str):
"""
Remove node from dag ir
"""
self._graph.remove_node(node)
def remove_edge(self, src: str, dst: str):
"""
Remove edge src -> dst
"""
self._graph.remove_edge(src, dst)
#
# Helper functions for getting attrs
#
def has_node(self, node: str) -> bool:
"""
Check if the node is in the graph
"""
return self._graph.has_node(node)
def in_degree(self, node: str):
"""
Get the input degree of node
"""
return self._graph.in_degree(node)
def in_edges(self, node: str):
"""
Get the input edges of node
"""
return [edge for edge in self._graph.in_edges(node)]
def out_degree(self, node: str):
"""
Get the output degree of node
"""
return self._graph.out_degree(node)
def out_edges(self, node: str):
"""
Get the output edges of node
"""
return [edge for edge in self._graph.out_edges(node)]
def get_node_meta(self, node: str):
"""
Get the meta data of the node
"""
return self._graph.nodes[node]["meta"]
def get_edge_weight(self, src, dst):
"""
Get the edge weight of edge src->dst
"""
return self._graph.get_edge_data(src, dst)["weight"]
#
# High-level helper functions
#
def all_reachable_nodes(self, node: str):
"""
Get all the nodes reachable from the current node (exclude)
"""
return list(nx.dfs_preorder_nodes(self._graph, source=node))
def get_users(self, node: str):
"""
Get all users of the current node
"""
return [edge[1] for edge in self.out_edges(node)]
def get_all_inputs(self, node: str):
"""
Get all the input nodes sorted by edge weight
"""
in_edges = self.in_edges(node)
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
def get_all_inputs_meta(self, node: str):
"""
Get all the input node metas sorted by edge weight
"""
return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]
def replace_all_uses_with(self, node1, node2):
"""
Replace all uses of node1 with node2
"""
for edge in self.out_edges(node1):
weight = self.get_edge_weight(*edge)
user = edge[1]
self.add_edge(node2, user, weight)
self.remove_edge(node1, user)
self.remove_node(node1)
#
# Node accessor
#
def nodes_topological_order(self):
"""
Get the nodes in the unique lexicographical topological order
It generates a unique ordering of nodes by first sorting topologically
and then additionally by sorting lexicographically.
Although topological_sort alone also works, this generates a unique key
for each epilogue visitor pattern and ensures the compilation cache can be reused.
:return: list[str]
"""
return list(nx.lexicographical_topological_sort(self._graph))
def node_metas_topological_order(self):
"""
Get the node metas in topological order
:return: list[NodeBase]
"""
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
@property
def nodes(self):
"""
Get all nodes
:return: list[str]
"""
return list(self._graph.nodes)
@property
def nodes_meta(self):
"""
Get all node metas
:return: list[NodeBase]
"""
return [data[1]['meta'] for data in self._graph.nodes.data()]
@property
def edges(self):
    """
    Get all edges as (src, dst) pairs.

    :return: list[(str, str)]
    """
    return list(self._graph.edges)
#
# Path
#
def has_path(self, src: str, target: str) -> bool:
    """
    Return True if a path exists from src to target.

    :param src: name of the source node
    :param target: name of the target node
    """
    return nx.has_path(self._graph, src, target)

View File

@ -0,0 +1,324 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout algebras
"""
from pycute import Layout, composition, make_layout, flatten, product
def _infer_split(old_shape, new_shape):
    """
    Recursively infer a flat dimension list from which both `old_shape`
    and `new_shape` can be obtained by merging consecutive dimensions.

    The recursion consumes one trailing dimension per step: equal dims
    are kept, a larger old dim is split, a smaller old dim is merged.

    :raises ValueError: if one side is exhausted while the other still
        has a non-unit volume
    :raises NotImplementedError: if the trailing dims are not divisible
    """
    old_shape = _tuple_to_list(old_shape)
    new_shape = _tuple_to_list(new_shape)
    # Base cases: one (or both) side(s) exhausted
    if not old_shape and not new_shape:
        return []
    if not old_shape:
        if product(tuple(new_shape)) != 1:
            raise ValueError("Invalid reshape size")
        return new_shape
    if not new_shape:
        if product(tuple(old_shape)) != 1:
            raise ValueError("Invalid reshape size")
        return old_shape
    # Only the last dimension of each side is processed per step
    old_last = old_shape[-1]
    new_last = new_shape[-1]
    if old_last == new_last:
        # Exact match: keep the dimension as-is
        return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_last]
    if old_last > new_last and old_last % new_last == 0:
        # Split the old dimension; carry the quotient back
        return _infer_split(old_shape[:-1] + [old_last // new_last], new_shape[:-1]) + [new_last]
    if old_last < new_last and new_last % old_last == 0:
        # Merge: the new dimension absorbs the old one; carry the quotient
        return _infer_split(old_shape[:-1], new_shape[:-1] + [new_last // old_last]) + [old_last]
    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
def _infer_merge(flatten_shape, shape):
    """
    Regroup the flat dimension list `flatten_shape` so that each group's
    product matches the corresponding dim of the coarser `shape`.

    :param flatten_shape: flat list/tuple of dimensions
    :param shape: target shape; each dim must be a single flat dim or a
        product of consecutive flat dims
    :return: list where merged dims appear as sub-lists
    :raises NotImplementedError: when a target dim cannot be formed from
        the next flat dims
    """
    flatten_shape = _tuple_to_list(flatten_shape)
    shape = _tuple_to_list(shape)
    idx_flat = 0
    merged_shape = []
    for dim in shape:
        # Exact match
        if dim == flatten_shape[idx_flat]:
            merged_shape.append(dim)
            idx_flat += 1
        # Need group
        elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
            # Consume flat dims until the running product reaches `dim`.
            # NOTE(review): assumes every consumed flat dim divides the
            # residual and is > 1; an interior flat dim of 1 would keep
            # residual unchanged and walk idx_flat out of range
            # (IndexError) -- confirm inputs exclude interior 1s.
            residual = dim
            group = []
            while(residual > 1):
                group.append(flatten_shape[idx_flat])
                residual = residual // flatten_shape[idx_flat]
                idx_flat += 1
            merged_shape.append(group)
        else:
            raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
    return merged_shape
def _list_to_tuple(nested_list):
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
return tuple(_list_to_tuple(item) for item in nested_list)
return nested_list
def _tuple_to_list(nested_tuple):
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
return list(_tuple_to_list(item) for item in nested_tuple)
return nested_tuple
def _reverse_tuple(nested_tuple: tuple):
if isinstance(nested_tuple, tuple):
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
return nested_tuple
def _get_first_lhs_nonzero_stride(stride_list, idx):
for i in reversed(range(idx)):
if stride_list[i] != 0:
return i
else:
return None
def _get_first_rhs_nonzero_stride(stride_list, idx):
for i in range(idx+1, len(stride_list)):
if stride_list[i] != 0:
return i
else:
return None
def reshape(layout, new_shape):
    """
    General reshape of input layout.

    It takes two steps:
    1. split the dimensions of the old layout
    2. merge the splitted dimensions according to the new shape

    :param layout: pycute Layout to reshape
    :param new_shape: target shape (possibly nested tuple of ints)
    :return: a Layout with shape `new_shape` covering the same elements
    """
    #
    # Step 1: Split the dimensions of the old layout
    #
    # 1.1 Flat old and new shape
    old_flatten_shape = list(flatten(layout.shape))
    new_flatten_shape = list(flatten(new_shape))
    # 1.2 Infer the flatten splitted shape
    splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
    # 1.3 Unflat the splitted shape based on the old shape
    splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
    # 1.4 Infer the type of each split
    # If the split type is in row-major (R), the dimension list is reversed because
    # the cute::composition only support column-major split
    split_type = []  # the type of each split (ColumnMajor or RowMajor)
    permuted_splitted_shape = []
    old_flatten_stride = list(flatten(layout.stride))
    for idx, dim in enumerate(splited_shape):
        if not isinstance(dim, list):
            # Unsplit dimension: trivially treated as column-major
            permuted_splitted_shape.append(dim)
            split_type.append("C")
        else:
            # NOTE(review): these helpers return *indices* of the nearest
            # non-zero strides, yet below they are compared against stride
            # *values* (old_flatten_stride[idx]) -- confirm this mixed
            # comparison is intended.
            lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
            rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
            # Special case for single tuple
            # Use column-major by default
            if lhs_stride is None and rhs_stride is None:
                permuted_splitted_shape.append(dim)
                split_type.append("C")
            else:
                if lhs_stride is not None and rhs_stride is not None:
                    # We consider shape[idx]:stride[idx]
                    # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
                    if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                    # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
                    elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append([d for d in reversed(dim)])
                            split_type.append("R")
                    # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append([d for d in reversed(dim)])
                            split_type.append("R")
                    else:
                        raise NotImplementedError()
                elif lhs_stride is None:
                    # Case 1: dim's stride < dim+1's stride, expand in column major
                    if old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                else:
                    # Case 1: dim's stride > dim-1's stride
                    if old_flatten_stride[idx] < lhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
    # 1.4 Generate the splitted layout
    permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
    # 1.5 Reverse the permutation in 1.4 before merge
    splitted_shape = []
    splitted_stride = []
    for shape_dim, stride_dim, type in zip(
        permuted_splitted_layout.shape,
        permuted_splitted_layout.stride,
        split_type):
        if type == "C":
            splitted_shape.append(shape_dim)
            splitted_stride.append(stride_dim)
        else:
            # Row-major splits were reversed for composition; undo here
            splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
            splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
    splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
    #
    # Step 2: Merge the splitted dimensions according to the new shape
    #
    # 2.1 Merge layout
    merged_layout = composition(splitted_layout, Layout(new_shape))
    # 2.2 Cleaning up
    output_layout = composition(merged_layout, Layout(new_shape))
    return output_layout
def permutation(layout, permutation):
    """
    Permute the modes of `layout` according to `permutation`.

    :param layout: pycute Layout
    :param permutation: iterable of mode indices
    :return: a new Layout whose i-th mode is layout's permutation[i]-th mode
    """
    permuted_shape = []
    permuted_stride = []
    for index in permutation:
        permuted_shape.append(layout.shape[index])
        permuted_stride.append(layout.stride[index])
    return Layout(tuple(permuted_shape), tuple(permuted_stride))
def _broadcast(layout, new_shape):
    """
    Recursively broadcast `layout` so that its shape matches `new_shape`.

    Size-1 dimensions are expanded to the target size with stride 0;
    missing dimensions are padded with (1, 0) sub-layouts, trying the
    existing dims aligned on the left first, then aligned on the right.

    :param layout: pycute Layout to broadcast
    :param new_shape: target shape (int or nested tuple)
    :return: broadcasted Layout
    :raises NotImplementedError: if a dim is neither equal to the target
        nor 1
    """
    # Scalar case: a single dimension matched against an int target
    if len(layout) == 1 and isinstance(new_shape, int):
        old_dim = layout.shape
        old_stride = layout.stride
        new_dim = new_shape
        if old_dim == new_dim:
            return Layout(old_dim, old_stride)
        elif old_dim == 1:
            # Broadcast dim: stride 0 re-reads the same element
            return Layout(new_dim, 0)
        else:
            raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
    # Align the dimensions
    old_shape = layout.shape
    if isinstance(old_shape, int):
        old_shape = (old_shape,)
        sub_layouts = [layout,]
    else:
        sub_layouts = [sub_layout for sub_layout in layout]
    rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
    # Get the broadcasted layout: pad on the right first; on failure,
    # retry with the padding on the left.
    broadcast_layouts = []
    try:
        layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    except NotImplementedError:
        # Bug fix: reset the accumulator -- a mid-loop failure above left
        # partially-appended sub-layouts, which would be duplicated here.
        broadcast_layouts = []
        layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    return make_layout(*broadcast_layouts)
def broadcast(layout, new_shape):
    """
    Broadcast `layout` to `new_shape`.

    The broadcasted layout's shape equals `new_shape`; broadcasted
    dimensions get stride 0.

    :param layout: pycute Layout to broadcast
    :param new_shape: target shape
    :return: broadcasted Layout
    """
    return _broadcast(layout, new_shape)
def debroadcast(layout, dims):
    """
    Squeeze the given broadcast (stride-0) dimensions out of `layout`.

    :param layout: pycute Layout
    :param dims: indices of dimensions to remove; each must have stride 0
    :raises ValueError: if any requested dim has a non-zero stride
    :return: Layout without the squeezed dimensions
    """
    for dim in dims:
        if layout.stride[dim] != 0:
            raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
    keep = [i for i in range(len(layout.shape)) if i not in dims]
    new_shape = tuple(layout.shape[i] for i in keep)
    new_stride = tuple(layout.stride[i] for i in keep)
    return Layout(new_shape, new_stride)
def canonicalization_(shapes, strides):
    """
    Recursive helper for canonicalization: force the stride of every
    size-1 dimension to 0; all other (shape, stride) pairs pass through.

    :return: (canonical_shapes, canonical_strides)
    """
    if not isinstance(shapes, tuple):
        # Leaf: a size-1 dim always gets the canonical stride 0
        return (1, 0) if shapes == 1 else (shapes, strides)
    canonical = [canonicalization_(s, d) for s, d in zip(shapes, strides)]
    return tuple(c[0] for c in canonical), tuple(c[1] for c in canonical)
def canonicalization(layout):
    """
    Canonicalize the input layout:
    1. set the stride of every size-1 dimension to 0

    :param layout: pycute Layout
    :return: canonicalized Layout
    """
    new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
    return Layout(new_shape, new_stride)

View File

@ -0,0 +1,336 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout manipulation nodes and implementations
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
"""
from copy import deepcopy
from cutlass_library import LayoutType
from pycute import product, flatten
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class PermutationImpl:
    """
    Detailed implementation and helper functions for permutation
    """
    def __init__(self, node) -> None:
        """
        :param node: LayoutNode whose kwargs carry the "indices" permutation
        """
        assert "indices" in node.kwargs.keys()
        self.indices = list(node.kwargs["indices"])
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_impl(self):
        """
        Return a copy of this impl performing the inverse permutation.
        """
        inverse_impl = deepcopy(self)
        inverse_impl.indices = self.inverse_indices
        inverse_impl.inverse_indices = self.indices
        return inverse_impl

    def update(self, shape):
        """
        Extend the permutation to len(shape) dims by shifting the old
        indices and prepending identity dims for the new leading dims.

        NOTE(review): `indices` aliases self.indices, so the in-place
        offsetting below also mutates self.indices before the final
        reassignment -- harmless here, but worth knowing.
        """
        num_dim = len(shape)
        indices = self.indices
        num_old_dim = len(indices)
        # Add offset
        for i, idx in enumerate(indices):
            indices[i] = idx + num_dim - num_old_dim
        # Add broadcast dims
        for i in range(num_dim - num_old_dim):
            indices = [i,] + indices
        self.indices = indices
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_indices(self, indices):
        """
        Get the indices for inverse permutation
        """
        num_dim = len(indices)
        inverse_indices = [0] * num_dim
        for i in range(num_dim):
            inverse_indices[indices[i]] = i
        return inverse_indices

    def shape_propagation(self, input_node_meta):
        """
        Output shape is the input shape permuted by `indices`.
        """
        input_shape = input_node_meta.tensor.shape
        output_shape = tuple([input_shape[idx] for idx in self.indices])
        return output_shape

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on current shape
        """
        self.update(shape)
        # Broadcast the input to the inversely-permuted target shape so
        # that applying the permutation afterwards yields `shape`
        inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the permutation to the users of the current nodes
        """
        usr_meta.tensor.permute(self.inverse_indices)
        if hasattr(usr_meta, "store_tensor"):
            if usr_meta.store_tensor is not None:
                usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the permutation to inputs of the current nodes
        """
        input_meta.tensor.permute(self.indices)
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.permute(self.indices)
class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape
    """
    def __init__(self, node) -> None:
        """
        :param node: LayoutNode whose kwargs carry "new_shape"
        """
        self.node = node
        assert "new_shape" in node.kwargs.keys()
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        """
        Return a copy of this impl performing the inverse reshape.
        Requires shape_propagation to have set self.input_shape first.
        """
        inverse_impl = deepcopy(self)
        inverse_impl.output_shape = self.input_shape
        inverse_impl.input_shape = self.output_shape
        return inverse_impl

    def shape_propagation(self, input_node_meta):
        # Record the input shape; the output shape comes from the kwargs
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on current shape.
        """
        # Step 1: infer split
        # Find a common flat shape, then group it to match both the input
        # and the output shapes of this reshape
        flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
        split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
        # broadcast shape -> split_output_shape -> flatten_split_shape
        # Pad missing leading dims with 1 so the zips below line up
        if len(shape) - len(split_output_shape) > 0:
            for _ in range(len(shape) - len(split_output_shape)):
                split_output_shape = [1,] + split_output_shape
                flatten_split_shape = [1,] + flatten_split_shape
                split_input_shape = [1,] + split_input_shape
        # Per-flat-dim multiplier: 1 where sizes already match, the target
        # size where a size-1 dim must be expanded
        broadcast_factor = []
        for dim, old_dim in zip(shape, split_output_shape):
            if not isinstance(dim, list):
                dim = [dim,]
            if not isinstance(old_dim, list):
                old_dim = [old_dim,]
            if product(tuple(dim)) == product(tuple(old_dim)):
                broadcast_factor += [1] * len(old_dim)
            elif product(tuple(old_dim)) == 1:
                assert len(dim) == 1
                broadcast_factor.append(dim[0])
            else:
                raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
        # flatten_split_shape -> split_input_shape
        # Apply the factors to the input-side grouping
        factor_idx = 0
        broadcast_split_input_shape = []
        for dim in split_input_shape:
            if isinstance(dim, list):
                new_dim = []
                for d in dim:
                    new_dim.append(d * broadcast_factor[factor_idx])
                    factor_idx += 1
                broadcast_split_input_shape.append(new_dim)
            else:
                broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
                factor_idx += 1
        broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
        # Reshape to the split form, broadcast there, then reshape back
        node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
        node_meta.tensor.broadcast(broadcast_split_input_shape)
        # Last reshape op to clean up
        broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
        node_meta.tensor.reshape(broadcast_input_shape)
        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)

    def apply_to_user(self, user_meta: NodeBase):
        """
        Propagate the reshape to user nodes
        """
        user_meta.tensor.reshape(tuple(self.input_shape))
        if hasattr(user_meta, "store_tensor"):
            if user_meta.store_tensor is not None:
                user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the reshape to input nodes
        """
        input_meta.tensor.reshape(tuple(self.output_shape))
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flatten splitted shape that can be merged to both input_shape and output_shape

        :raises ValueError: if one side is exhausted with non-unit volume left
        :raises NotImplementedError: if trailing dims are not divisible
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return input_shape
        # This is done recursively by only process the last dimension at each time
        old_dim = input_shape[-1]
        new_dim = output_shape[-1]
        # Exact match
        if old_dim == new_dim:
            return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
        # Needs split
        if old_dim > new_dim and old_dim % new_dim == 0:
            residual = old_dim // new_dim
            return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
        # Needs merge
        if old_dim < new_dim and new_dim % old_dim == 0:
            residual = new_dim // old_dim
            return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        """
        Regroup `flatten_shape` (right-to-left) so that each group's
        product matches the corresponding dim of `shape`.

        NOTE(review): like the module-level _infer_merge, this assumes
        every consumed flat dim is > 1 and divides the residual.
        """
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        idx_flat = len(flatten_shape) - 1
        merged_shape = []
        for dim in reversed(shape):
            # Exact match
            if dim == flatten_shape[idx_flat]:
                merged_shape.append(dim)
                idx_flat -= 1
            # need group
            elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
                residual = dim
                group = []
                while(residual > 1):
                    group.append(flatten_shape[idx_flat])
                    residual = residual // flatten_shape[idx_flat]
                    idx_flat -= 1
                merged_shape.append(group[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
        return merged_shape[::-1]
class LayoutNode(NodeBase):
    """
    DAG IR node that manipulates the logical layout (permute / reshape)
    of its single input tensor.
    """
    # Maps the user-facing function name to its implementation class
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }

    def __init__(self, name: str, fn, kwargs: dict) -> None:
        """
        :param name: node name
        :param fn: the layout function; its __name__ selects the impl
        :param kwargs: arguments of the layout function
        """
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        self.kwargs = kwargs
        impl_cls = self.fn_to_impl[self.fn.__name__]
        self.underlying_impl = impl_cls(self)

    def get_inverse_node(self):
        """
        Return a copy of this node performing the inverse layout op.
        """
        inverse_node = deepcopy(self)
        inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse_node

    def shape_propagation(self, input_node_metas):
        """
        Infer the output tensor from the single input node's tensor.
        """
        if self._tensor is not None:
            return
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
        self._tensor = Tensor(
            element=self.element_output,
            shape=output_shape, layout_tag=LayoutType.RowMajor
        )
        return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The store nodes has element_output = element_input
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        target_shape = self.tensor.shape
        for input_meta in input_node_metas:
            self.underlying_impl.broadcast(target_shape, input_meta)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the layout change to a user node.
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the layout change to an input node.
        """
        self.underlying_impl.apply_to_input(input_meta)

View File

@ -0,0 +1,294 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Load nodes and implementations
"""
import ctypes
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
class LoadImplBase(ImplBase):
    """
    Base class for load node implementations
    """
    # Node names claimed by dedicated impls (accumulator, source C);
    # generic load impls must not match them
    reserved_names = ["accum", "C"]

    def __init__(self, node) -> None:
        super().__init__(node)
        # Element type in memory and the type produced by the load
        self.element = node.element
        self.element_output = node.element_output
        self.stride = node.tensor.stride
class AccumulatorImpl(LoadImplBase):
    """
    Accumulator node implementation
    """
    @staticmethod
    def match(node, problem_size: tuple):
        # Only the node literally named "accum" whose shape equals the
        # problem size maps to the accumulator
        return node.name == "accum" and node.tensor.shape == problem_size
class LoadSrcImpl(LoadImplBase):
    """
    Load C implementation
    """
    @property
    def name_camel(self) -> str:
        # The source operand is always emitted as TensorC
        return "TensorC"

    @property
    def argument_type_c(self):
        """
        ctypes structure holding the pointer and stride of tensor C.
        """
        stride_mnl = self.get_stride_mnl()
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", tuple_type)
            ]
            def __init__(self, ptr) -> None:
                self.ptr_C = ptr
                self.stride_C = tuple_type(stride_mnl)
        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        # Only the node literally named "C" with problem-size shape
        return node.name == "C" and node.tensor.shape == problem_size
class AuxLoadImpl(LoadImplBase):
    """
    Load arbitrary tensor
    """
    @property
    def argument_type(self):
        """
        ctypes structure: device pointer, fill value used when the
        pointer is null, and the stride tuple.
        """
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dAux", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                # The pointer is looked up by node name in the kwargs dict
                ptr = kwargs[name]
                self.ptr_aux = ptr
                self.null_default = to_ctype_value(0, element_type)
                self.dAux = tuple_type(stride_mnl)
        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match a full (non-broadcast) tensor: one of the last two strides
        is 1 and the other is non-zero.
        """
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        if (strideMN[0] == 1 and strideMN[1] != 0 or
            strideMN[0] != 0 and strideMN[1] == 1 ):
            return True
        else:
            return False
class RowBroadcastImpl(LoadImplBase):
    """
    Broadcast a row vector
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        # Vector strides are emitted as 32-bit ints
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        """
        ctypes structure: row-vector pointer, null fill value, and stride.
        """
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_row", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dRow", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr_row = ptr
                self.null_default = to_ctype_value(0, element_type)
                self.dRow = tuple_type(stride_mnl)
        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        # Row vector: stride 0 along M, unit stride along N
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        if strideMN == (0, 1):
            return True
        else:
            return False
class ColumnBroadcastImpl(LoadImplBase):
    """
    Broadcast a column vector
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        # Vector strides are emitted as 32-bit ints
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        """
        ctypes structure: column-vector pointer, null fill value, stride.
        """
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_col", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dCol", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                # NOTE(review): unlike AuxLoadImpl/RowBroadcastImpl, the
                # pointer is coerced with int() here -- confirm whether
                # the siblings should do the same for pointer-like args.
                self.ptr_col = int(ptr)
                self.null_default = to_ctype_value(0, element_type)
                self.dCol = tuple_type(stride_mnl)
        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        # Column vector: unit stride along M, stride 0 along N
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        if strideMN == (1, 0):
            return True
        else:
            return False
class ScalarBroadcastImpl(LoadImplBase):
    """
    Broadcast a scalar
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        """
        ctypes structure holding either an immediate scalar or a device
        pointer to the scalar.
        """
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element
        if self.tensor.is_constant:
            # Constant scalar captured at trace time; no runtime kwarg needed
            value = self.tensor.value
            class _Argument(ctypes.Structure):
                _fields_ = [
                    ("scalars", dtype2ctype[element_type]),
                    ("scalar_ptrs", ctypes.c_void_p),
                    ("dScalar", tuple_type)
                ]
                def __init__(self, kwargs) -> None:
                    self.scalars = to_ctype_value(value, element_type)
                    self.scalar_ptrs = 0
                    self.dScalar = tuple_type(stride_mnl)
        else:
            class _Argument(ctypes.Structure):
                _fields_ = [
                    ("scalars", dtype2ctype[element_type]),
                    ("scalar_ptrs", ctypes.c_void_p),
                    ("dScalar", tuple_type)
                ]
                def __init__(self, kwargs) -> None:
                    scalar_or_ptr = kwargs[name]
                    # NOTE(review): only Python floats are treated as
                    # immediate scalars; anything else (including int) is
                    # interpreted as a device pointer -- confirm callers
                    # never pass plain int scalars.
                    if isinstance(scalar_or_ptr, float):
                        self.scalars = to_ctype_value(scalar_or_ptr, element_type)
                        self.scalar_ptrs = 0
                    else:
                        self.scalar_ptrs = int(scalar_or_ptr)
                    self.dScalar = tuple_type(stride_mnl)
        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        # A scalar has zero stride along both M and N
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        if strideMN == (0, 0):
            return True
        else:
            return False
class LoadNode(NodeBase):
    """
    Load Node
    """
    # Counter used to generate unique default names (load0, load1, ...)
    cnt = 0

    # Candidate implementations; matched against the node elsewhere
    possible_impls = [
        AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
        RowBroadcastImpl, ColumnBroadcastImpl,
        ScalarBroadcastImpl
    ]

    def __init__(self, name: str) -> None:
        """
        :param name: node name; auto-generated as "load<N>" when None
        """
        if name is None:
            name = f"load{LoadNode.cnt}"
            LoadNode.cnt += 1
        super().__init__(name)
        self.op = "load"

    def type_propagation(self, *args, **kwargs):
        """
        Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        self.element = self.tensor.element
        self.element_output = self.tensor.element

View File

@ -0,0 +1,306 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base & visitor classes of DAGIR Nodes
"""
import ctypes
from re import sub
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class TupleEmitter:
    """
    Emit the cute tuple to C++ code
    """
    def __init__(self, stride_dtype):
        """
        :param stride_dtype: C++ type name used for non-constant
            (dynamic) stride entries, e.g. "int64_t"
        """
        self.stride_dtype = stride_dtype

    def emit(self, py_tuple):
        """
        Recursively emit the C++ type string for `py_tuple`.

        Ints 0 and 1 become compile-time cute::Int constants; any other
        int becomes the dynamic stride type; tuples become nested
        cute::Stride<...> types.

        :raises ValueError: for any input that is neither int nor tuple
        """
        if isinstance(py_tuple, int):
            if py_tuple in [0, 1]:
                return f"cute::Int<{py_tuple}>"
            else:
                return f"{self.stride_dtype}"
        elif isinstance(py_tuple, tuple):
            # Bug fix: join handles the empty tuple correctly
            # ("cute::Stride<>"); the previous trailing-comma slicing
            # produced a malformed "cute::Strid>" for ().
            return "cute::Stride<" + ", ".join(self.emit(item) for item in py_tuple) + ">"
        else:
            raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
class ImplBase:
    """
    Base class for Node Implementation.

    Wraps a DAG node and provides the shared machinery for emitting stride
    types and constructing ctypes argument structures.
    """

    def __init__(self, node) -> None:
        self.node = node
        self.name = node.name
        self.tensor = node.tensor
        self._type_decl = None
        # Strides default to 64-bit integers; subclasses may override.
        self.tuple_emitter = TupleEmitter("int64_t")

    @property
    def stride_dtype(self):
        # Single source of truth: forward to the emitter.
        return self.tuple_emitter.stride_dtype

    @stride_dtype.setter
    def stride_dtype(self, stride_dtype):
        self.tuple_emitter.stride_dtype = stride_dtype

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match function used in get_underlying_impl
        """
        raise NotImplementedError("The `match` function is not defined.")

    @property
    def argument_type(self):
        """
        Default class for Argument Type
        """
        class _Argument(ctypes.Structure):
            _fields_ = []

            def __init__(self, *args, **kwargs) -> None:
                pass

        return _Argument

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        spaced = sub(r"(_|-)+", " ", self.name)
        return spaced.title().replace(" ", "")

    def _mnl_stride_tuple(self):
        # Reorder (..., m, n) strides into the (m, n, reversed-batch) layout
        # expected by StrideMNL.
        batch_strides = list(_reverse_tuple(tuple(self.stride[:-2])))
        return _list_to_tuple([self.stride[-2], self.stride[-1]] + batch_strides)

    @property
    def stride_mnl(self):
        """
        Typename StrideMNL
        """
        return self.tuple_emitter.emit(self._mnl_stride_tuple())

    def get_non_constant_stride(self, py_tuple):
        """
        Recursively drop static (0/1) stride entries; empty sub-tuples that
        result are dropped as well.
        """
        if isinstance(py_tuple, int):
            # Static strides carry no runtime information.
            return py_tuple if py_tuple not in [0, 1] else None
        return tuple(
            filter(None, (self.get_non_constant_stride(entry) for entry in py_tuple))
        )

    def get_stride_mnl(self):
        """
        Get the non-zero stride mnl. This is used in argument construction
        """
        return self._mnl_stride_tuple()

    def get_smem_size(self, *args, **kwargs):
        """
        Get the shared memory size and alignment of current node
        """
        # This base impl requires no shared memory: size 0, alignment 1.
        return (0, 1)
class NoOpImpl(ImplBase):
    """
    The NoOpImpl does nothing but forward its input to users
    """

    def __init__(self, node) -> None:
        super().__init__(node)

    @staticmethod
    def match(node, problem_size: tuple):
        # Only store nodes can be elided; other ops fall through (None).
        if node.op != "store":
            return None
        # A store that is not an external output has no observable effect.
        return not node.is_output
class NodeBase:
    """
    Base class of DAG Node
    """
    def __init__(self, name: str) -> None:
        self.name = name
        # Concrete ImplBase instance selected later by get_underlying_impl.
        self.underlying_impl = None
        self._tensor = None
        # Whether the node is disabled for emit
        self.disabled = False

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        # Delegates to the impl; requires get_underlying_impl to have run.
        return self.underlying_impl.name_camel

    @property
    def tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._tensor

    @tensor.setter
    def tensor(self, kwargs):
        """
        Setting the tensor

        ``kwargs`` is a dict forwarded verbatim to the Tensor constructor.
        """
        self._tensor = Tensor(**kwargs)

    #
    # Helper functions for type/shape propagation
    #

    def shape_propagation(self, input_node_metas):
        """
        Infer shape from input nodes

        General Broadcasting Rules from NumPy
        When operating on two arrays, we compare their shapes element-wise.
        It starts with the trailing (i.e. rightmost) dimension and works its
        way left. Two dimensions are compatible when
        1. they are equal
        2. one of them is 1

        :raises RuntimeError: if two input shapes have incompatible dimensions
        """
        if self._tensor is not None:
            # Shape already known; nothing to infer.
            return
        shape = None
        for src in input_node_metas:
            src_shape = src.tensor.shape
            if shape is None:
                shape = src_shape
            else:
                # Left-pad the shorter shape with 1s so the ranks match.
                len_difference = len(shape) - len(src_shape)
                if len_difference > 0:
                    for _ in range(len_difference):
                        src_shape = [1, ] + list(src_shape)
                elif len_difference < 0:
                    for _ in range(-len_difference):
                        shape = [1, ] + list(shape)
                broadcasted_shape = []
                # Infer broadcast shape, walking from the trailing dimension.
                for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
                    if shape_dim == 1:
                        broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
                    elif src_dim == 1:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    elif shape_dim == src_dim:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    else:
                        # Incompatible dims: report every input shape for context.
                        error_msg = "Dimension mismatch between "
                        for src_ in input_node_metas:
                            error_msg += f"{src_.name}{src_.tensor.shape}, "
                        error_msg = error_msg[:-2] + "."
                        raise RuntimeError(error_msg)
                shape = tuple(broadcasted_shape)
        # The inferred output defaults to a row-major layout.
        self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)

    def type_propagation(self, *args, **kwargs):
        """
        Each node is associated with two data types: `element` and `element_output`.
        The `element_output` is the type of return array of the node. The `element`
        has specific meaning for different node types.
        * Load Node: data type of tensor in gmem
        * Compute Node: element compute
        * Store Node: data type of tensor in gmem
        This function must be overloaded in the derived classes
        """
        raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order.
        For example:
            C[l, m, n] = A[m, 1] + B[l, m, n]
        After the broadcast propagation, it will be come
            C[l, m, n] = A[l, m, n] + B[l, m, n]
        and each tensor will have a proper stride accessing the underlying tensor

        :raises RuntimeError: if this node's tensor has not been set yet
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        for child in input_node_metas:
            child.tensor.broadcast(self.tensor.shape)

    def get_underlying_impl(self, problem_size: tuple):
        """
        Get the underlying implementation of the current node.

        Tries each entry of ``possible_impls`` (defined by subclasses) in
        order; the first impl whose ``match`` succeeds is instantiated.

        :raises RuntimeError: if the node's layout is still unknown
        :raises NotImplementedError: if no impl matches this node
        """
        if self.tensor is None:
            raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
        for impl in self.possible_impls:
            if impl.match(self, problem_size):
                self.underlying_impl = impl(self)
                break
        if self.underlying_impl is None:
            raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
#
# Visitor Nodes & Impls
#
class TopoVisitorImpl(ImplBase):
    """
    Impl for topological visitor
    """
    def __init__(self, node) -> None:
        # The DAG node is represented by its subgraph's output node, but
        # keeps its own name and forwards the output element type.
        super().__init__(node.output_node)
        self.name = node.name
        self.element_output = node.output_node.element_output
class TopoVisitorNode(NodeBase):
    """
    DAG node implemented with the topological visitor.

    Wraps a non-tree ``subgraph`` whose final value is produced by
    ``output_node``.
    """
    def __init__(self, name: str, subgraph, output_node) -> None:
        super().__init__(name)
        self.subgraph = subgraph
        self.output_node = output_node
        self.op = "dag"
        # Unlike other nodes, the impl is fixed at construction time.
        self.underlying_impl = TopoVisitorImpl(self)

View File

@ -0,0 +1,277 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Store node and implementations
"""
import ctypes
from cutlass_library import DataType
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
class StoreImplBase(ImplBase):
    """
    Base class for store node implementation
    """
    # Node names with dedicated handling ("D") that the generic store/reduction
    # impls must not match.
    reserved_names = ["D"]

    def __init__(self, node) -> None:
        super().__init__(node)
        # `element`: dtype of the tensor in gmem; `element_output`: dtype
        # the node returns to downstream visitors.
        self.element = node.element
        self.element_output = node.element_output
        self.stride = node.store_tensor.stride
class StoreDImpl(StoreImplBase):
    """
    Store D implementation

    Handles the canonical epilogue output tensor "D".
    """

    @property
    def argument_type_d(self):
        """
        ctypes structure carrying D's device pointer and MNL stride.
        """
        stride_mnl = self.get_stride_mnl()
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", tuple_type)
            ]

            def __init__(self, ptr: int) -> None:
                self.ptr_D = ptr
                self.stride_D = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        # "D" must cover the whole problem extent to use the dedicated impl.
        return node.name == "D" and node.store_tensor.shape == problem_size
class AuxStoreImpl(StoreImplBase):
    """
    Store implementation for auxiliary (non-"D") output tensors written to gmem.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.round_style = FloatRoundStyle.ToNearest

    @property
    def argument_type(self):
        """
        ctypes structure carrying the aux tensor's device pointer and stride.
        """
        stride_mnl = self.get_stride_mnl()
        # Bind locally so the nested class closes over values, not `self`.
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("dAux", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                # kwargs: mapping from node name to device pointer.
                ptr = kwargs[name]
                self.ptr_aux = ptr
                self.dAux = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False
        strideMN = node.store_tensor.stride[-2:]
        # Accept stores whose (M, N) strides look column-major (1, n != 0)
        # or row-major (m != 0, 1); zero strides are left for the reduction
        # impls below — presumably intentional, confirm against the impl list.
        if (strideMN[0] == 1 and strideMN[1] != 0 or
            strideMN[0] != 0 and strideMN[1] == 1 ):
            return True
        else:
            return False
class ReductionImplBase(StoreImplBase):
    """
    Base class for store impls that reduce the result along one or more
    modes (row / column / scalar reductions) before writing to gmem.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.store_tensor.element
        self.element_compute = node.element_compute
        # Reduction functor applied within a thread's registers ...
        self.reg_reduce_fn = self.node.reg_reduce_fn
        # ... and functor used when combining partials in gmem.
        self.gmem_reduce_fn = self.node.gmem_reduce_fn
        self.round_style = node.round_style
        # Reduction strides are small constants; a plain `int` suffices.
        self.stride_dtype = "int"

    def get_reduce_identity(self):
        """
        Return the reduction identity of the current reduce_fn

        Maximum -> table minimum, Minimum -> table maximum, Multiplies -> 1,
        anything else -> 0.

        :raises Exception: if `element_compute` has no table entry for a
            Maximum/Minimum reduction
        """
        # NOTE(review): these are approximate extreme values used as
        # identities (e.g. f16 uses 2**15 rather than the true f16 max
        # 65504) — presumably adequate for initialization; confirm if
        # exactness matters.
        maxes = {
            DataType.f32: (2 ** 31) - 1,
            DataType.f16: (2 ** 15),
            DataType.s32: (2 ** 31) - 1,
            DataType.s8: (2 ** 7) - 1
        }
        mins = {
            DataType.f32: -maxes[DataType.f32],
            DataType.f16: -maxes[DataType.f16],
            DataType.s32: -maxes[DataType.s32],
            DataType.s8: -maxes[DataType.s8]
        }
        if self.reg_reduce_fn == FunctionalOp.Maximum:
            if self.element_compute not in mins:
                raise Exception(f"No min entry for data type {self.element_compute}")
            return to_ctype_value(mins[self.element_compute], self.element_compute)
        elif self.reg_reduce_fn == FunctionalOp.Multiplies:
            return to_ctype_value(1., self.element_compute)
        elif self.reg_reduce_fn == FunctionalOp.Minimum:
            if self.element_compute not in maxes:
                raise Exception(f"No max entry for data type {self.element_compute}")
            return to_ctype_value(maxes[self.element_compute], self.element_compute)
        else:
            return to_ctype_value(0., self.element_compute)

    @property
    def argument_type(self):
        """
        ctypes structure carrying the output pointer, the reduction identity,
        and the MNL stride.
        """
        # Fix: the original called get_reduce_identity() here and discarded
        # the result before calling it again below; the redundant call is
        # removed (the function is pure, so behavior is unchanged).
        stride_mnl = self.get_stride_mnl()
        # Bind locally so the nested class closes over values, not `self`.
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_compute = self.element_compute
        reduce_identity = self.get_reduce_identity()

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr", ctypes.c_void_p),
                ("reduce_identity", dtype2ctype[element_compute]),
                ("dMNL", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                # kwargs: mapping from node name to device pointer.
                ptr = kwargs[name]
                self.ptr = ptr
                self.reduce_identity = reduce_identity
                self.dMNL = tuple_type(stride_mnl)

        return _Argument
class ColumnReductionImpl(ReductionImplBase):
    """
    Reduction store matching outputs whose (M, N) stride is (1, 0).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        # Only named auxiliary outputs participate; "D" is reserved.
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False
        return node.store_tensor.stride[-2:] == (1, 0)
class RowReductionImpl(ReductionImplBase):
    """
    Reduction store matching outputs whose (M, N) stride is (0, 1).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        # Only named auxiliary outputs participate; "D" is reserved.
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False
        return node.store_tensor.stride[-2:] == (0, 1)
class ScalarReductionImpl(ReductionImplBase):
    """
    Reduction store matching outputs whose (M, N) stride is (0, 0).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        # Only named auxiliary outputs participate; "D" is reserved.
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False
        return node.store_tensor.stride[-2:] == (0, 0)
class StoreNode(NodeBase):
    """
    Store node
    """
    # Candidate impls tried in order by NodeBase.get_underlying_impl;
    # the first whose `match` succeeds wins.
    possible_impls = [
        AuxStoreImpl, RowReductionImpl,
        ColumnReductionImpl, ScalarReductionImpl,
        NoOpImpl, StoreDImpl
    ]
    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.op = "store"
        # Set externally when this store is one of the epilogue's outputs.
        self.is_output = False
        self._store_tensor = None

    @property
    def store_tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._store_tensor

    @store_tensor.setter
    def store_tensor(self, kwargs):
        """
        Setting the tensor

        ``kwargs`` is a dict forwarded verbatim to the Tensor constructor.
        """
        self._store_tensor = Tensor(**kwargs)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The store nodes has element_output = element_input

        :raises RuntimeError: if this node is an output but its store tensor
            has not been set
        """
        if self.is_output:
            if self.store_tensor is None:
                raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
            self.element = self.store_tensor.element
        assert len(input_node_metas) == 1, "Store node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        super().broadcast_propagation(input_node_metas)
        if self.is_output:
            # Keep the gmem tensor's layout in sync with the broadcast shape.
            self._store_tensor.broadcast(self.tensor.shape)

View File

@ -0,0 +1,137 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level class for tensor
"""
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
permutation,
reshape,
_reverse_tuple
)
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
class Tensor:
    """
    The tensor abstracts the data type

    A Tensor is described either by wrapping an existing object (``tensor``,
    which may be another Tensor or a framework tensor) or explicitly via
    ``element`` + ``shape`` + (``layout_tag`` or ``stride``). Layouts are
    stored internally in reversed (cute-style) order.
    """
    def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
        # Validate that exactly one description style is used.
        if element is not None and tensor is not None:
            raise Exception(f"Must not specify both element and tensor")
        elif shape is not None and tensor is not None:
            raise Exception(f"Must not specify both shape and tensor")
        elif layout_tag is not None and tensor is not None:
            raise Exception(f"Must not specify both layout_tag and tensor")
        elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
            raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
        elif stride is not None and tensor is not None:
            raise Exception(f"Must not specify both stride and tensor")
        elif stride is not None and layout_tag is not None:
            raise Exception(f"Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Directly copy all the attributes
            self.__dict__.update(vars(tensor))
        else:
            if tensor is None:
                self.element = library_type(element)
            else:
                # Infer dtype, layout tag, and shape from the framework tensor.
                self.element, layout_tag = get_datatype_and_layout(tensor)
                shape = get_tensor_shape(tensor)
            if stride is not None:
                # Shapes/strides are reversed into internal (cute) order.
                self.layout = Layout(shape[::-1], stride[::-1])
            else:
                if layout_tag == LayoutType.RowMajor:
                    self.layout = Layout(shape[::-1])
                elif layout_tag == LayoutType.ColumnMajor:
                    self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
                # NOTE(review): any other layout_tag leaves self.layout unset
                # and fails on the next line — presumably only Row/ColumnMajor
                # can reach here; confirm against callers.
            self.layout = canonicalization(self.layout)
        self.is_constant = is_constant
        # Save the tensor value if it is constant
        if is_constant and tensor is not None:
            self.value = tensor

    @property
    def shape(self):
        """
        Returns the RowMajor layout shape
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Returns the RowMajor layout stride
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Returns the rank of the tensor
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to shape
        """
        assert isinstance(shape, tuple)
        # `shape` is given row-major; reverse into internal order first.
        self.layout = broadcast(self.layout, _reverse_tuple(shape))

    def reshape(self, shape):
        """
        Reshape self.layout to shape
        """
        assert isinstance(shape, tuple)
        reverse_shape = _reverse_tuple(shape)
        self.layout = reshape(self.layout, reverse_shape)

    def permute(self, indices):
        """
        Permute self.layout according to indices
        """
        length = len(indices)
        # Convert row-major indices into the reversed internal ordering.
        indices = [length - idx - 1 for idx in indices]
        self.layout = permutation(self.layout, indices[::-1])

View File

@ -0,0 +1,42 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize

View File

@ -0,0 +1,143 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
import subprocess
from cutlass_library import DataTypeTag
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
# Graphviz fill colors used by EVTGraphDrawer, keyed by node op kind.
# NOTE: "AliceBlue" is deliberately wrapped in extra double quotes so the
# quotes survive into the emitted dot attribute.
_COLOR_MAP = {
    "load": '"AliceBlue"',
    "compute": "LemonChiffon1",
    "accumulator": "LightGrey",
    "store": "PowderBlue",
    "layout": "lightseagreen",
    "dag": "darkorange"
}
class EVTGraphDrawer:
    """
    Visualize a EVT DAGIR with graphviz

    One pydot graph is produced per DAG: the main graph under ``name`` plus
    one sub-graph for every nested "dag" node, keyed by that node's name.
    """
    def __init__(
        self,
        graph: DAGIR,
        name: str
    ):
        self._name = name
        self._dot_graphs = {}
        self._dot_graphs[name] = self._to_dot(graph, name)

    def _get_node_style(self, node):
        """
        Return the pydot attribute dict (shape/fill/font) for ``node``.

        :raises NotImplementedError: if ``node.op`` has no _COLOR_MAP entry
        """
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op in _COLOR_MAP:
            template["fillcolor"] = _COLOR_MAP[node.op]
        else:
            raise NotImplementedError("unknown node op")
        if node.disabled:
            # Grey out disabled nodes instead of hiding them.
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        """Build the graphviz record label summarizing ``node``."""
        label = "{" + f"name={node.name}|op={node.op}"
        if node.op == "layout":
            label += f"|fn={node.fn.__name__}"
            for key in node.kwargs:
                label += f"|{key}={node.kwargs[key]}"
        if node.underlying_impl is not None:
            label += f"|impl={type(node.underlying_impl).__name__}"
            if node.op == "load":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
            elif node.op == "compute":
                label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "store":
                label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "dag":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
        if node.tensor is not None:
            shape = node.tensor.shape
            stride = node.tensor.stride
            label += f"|shape={shape}|stride={stride}"
        if hasattr(node, "store_tensor"):
            if node.store_tensor is not None:
                store_shape = node.store_tensor.shape
                store_stride = node.store_tensor.stride
                # Fix: the stride key was previously the typo "stride_stride".
                label += f"|store_shape={store_shape}|store_stride={store_stride}"
        label += "}"
        return label

    def _to_dot(
        self,
        graph: DAGIR,
        name: str
    ):
        """Recursively convert ``graph`` (and its nested DAG nodes) to pydot."""
        import pydot
        # Fix: the keyword was previously misspelled "randir", so the
        # top-to-bottom rank direction was silently ignored by pydot.
        dot_graph = pydot.Dot(name, rankdir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_node = pydot.Node(
                node.name, label=label, **style
            )
            dot_graph.add_node(dot_node)
            if node.op == "dag":
                # Nested DAGs get their own graph, registered by node name.
                dot_subgraph = self._to_dot(node.subgraph, name=node.name)
                self._dot_graphs[node.name] = dot_subgraph
        # Add edges
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
        return dot_graph

    def get_dot_graph(self) -> "list[tuple[str, pydot.Dot]]":
        """Return all (name, graph) pairs, main graph and subgraphs alike."""
        # Fix: the annotation previously claimed pydot.Dot, but a list of
        # (name, graph) pairs is returned.
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]

    def get_dot_graph_by_name(self, name) -> "pydot.Dot":
        """Return the dot graph registered under ``name``."""
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> "pydot.Dot":
        """Return the top-level dot graph."""
        return self._dot_graphs[self._name]

View File

@ -0,0 +1,120 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Construct the epilogue visitor argument type
"""
from cutlass_cppgen.backend.c_types import visitor_factory
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.util import cc_map
class PassGetArgumentType(EVTPassBase):
    """
    Construct the epilogue visitor argument type
    """
    # Passes that must run first so every node has a layout, an impl,
    # and (for non-tree regions) a DAG subgraph.
    dependencies = [
        PassShapeTypePropagation, # The Layout of all nodes must be set
        PassDAG2Tree, # The type of each node must be set
        PassGetImpl # The DAG subgraphs must be set
    ]

    def requires(self) -> None:
        # Check "D" is in the node list
        if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
            raise SyntaxError(
                "Sm90+ EVT requires the epilogue to have a returned tensor D, "
                "but the variable 'D' is not found in the return values.")

    def call(self):
        """
        Walk the DAG in topological order, building each node's ctypes
        argument type, then publish the epilogue-level types through the
        cc-specific setter.
        """
        nodes = self.dag_ir.nodes_topological_order()
        # Maps node name -> ctypes argument type (possibly a visitor aggregate).
        self.argument_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.argument_types[node] = meta.underlying_impl.argument_type
            # On Sm90/Sm100, D's argument is published separately (see
            # sm90_set_argument_type), so skip visitor aggregation for it.
            if node == "D" and cc_map[self.cc] in [90, 100]:
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_argument_type(node)
            else:
                self.get_evt_argument_type(node)
        self.cc_specific_method(self.set_argument_type)()

    def get_evt_argument_type(self, node):
        """Aggregate a tree node's argument type with its inputs' types."""
        # Inputs are taken in the order get_all_inputs yields them
        # (presumably ordered by edge weight — confirm in DAGIR).
        input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(
                input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])

    def get_dag_argument_type(self, node):
        """Aggregate a DAG (topological-visitor) node's argument type."""
        meta = self.dag_ir.get_node_meta(node)
        subgraph = meta.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Visit the unvisited nodes in subgraph
        for n in subgraph_nodes:
            m = subgraph.get_node_meta(n)
            if m.disabled:
                continue
            else:
                self.argument_types[n] = m.underlying_impl.argument_type
        # The last topological node is the subgraph's output; it is excluded.
        input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])

    def set_argument_type(self):
        # Dispatch anchor: cc_specific_method resolves this to one of the
        # smXX_set_argument_type variants below.
        pass

    def sm90_set_argument_type(self):
        # The epilogue thread type is the visitor type of D's single input.
        self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
        # Get the tensorD argument type
        self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
        # Get the tensorC argument type
        if self.dag_ir.has_node("C"):
            self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
        else:
            # Without a C operand, reuse D's argument type.
            self.dag_ir.arg_c_type = self.dag_ir.arg_d_type

    def sm100_set_argument_type(self):
        # Sm100 shares the Sm90 convention.
        self.sm90_set_argument_type()

    def sm80_set_argument_type(self):
        # Sm80 epilogues use the final node's aggregated type directly.
        nodes = self.dag_ir.nodes_topological_order()
        self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]

View File

@ -0,0 +1,169 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
"""
from copy import deepcopy
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassDAG2Tree(EVTPassBase):
    """
    Convert the DAG IR to Tree by fusing subgraphs
    """
    dependencies = [
        PassShapeTypePropagation,
        PassGetImpl
    ]

    def call(self):
        """Fuse every non-tree region of the DAG into a single TopoVisitorNode.

        After this pass every remaining node has out-degree <= 1 (checked by
        ``ensures``), so the rest of the graph can use the tree visitor.
        """
        # Step 1: find the nodes that have multiple parents
        multi_parent_nodes = []
        for node in self.dag_ir.nodes_topological_order():
            if self.dag_ir.out_degree(node) > 1:
                multi_parent_nodes.append(node)
        # Step 2: find the lowest common ancestor (LCA) of all its parents
        for node in multi_parent_nodes:
            # A multi-parent node could be already fused by the previous node
            if not self.dag_ir.has_node(node):
                continue
            # A node uncovered by the previous fusions can have out degree change
            # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
            # Case 2: it has more than one edges to the previously fused subgraph, degree drops
            if self.dag_ir.out_degree(node) <= 1:
                continue
            # Otherwise, the node still has multiple users and must be fused
            reachable_nodes = []
            # Complexity: O(Dout*N)
            for parent in self.dag_ir.get_users(node):
                reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
            # get the common reachable objects
            common_items = set.intersection(*reachable_nodes)
            node_to_fuse = set.union(*reachable_nodes).difference(common_items)
            lca = None
            # If common ancestor exists, find the lowest one
            if len(common_items) > 0:
                topo_order = self.dag_ir.nodes_topological_order()
                topo_idx = -1
                for item in common_items:
                    if lca is None:
                        lca = item
                        topo_idx = topo_order.index(item)
                    else:
                        if topo_idx > topo_order.index(item):
                            lca = item
                            topo_idx = topo_order.index(item)
            else:
                # there is no common ancestor for all the parents, we pack all the reachable
                # nodes into a single DAG node as a fallback. The lca should be the input node of
                # one of the output nodes with out_degree = 0
                # NOTE(review): the loop variables below shadow the outer `node`;
                # harmless today because the outer `node` is not read again before
                # being rebound, but fragile under future edits.
                potential_output_nodes = []
                for node in node_to_fuse:
                    if self.dag_ir.out_degree(node) == 0:
                        potential_output_nodes.append(node)
                if len(potential_output_nodes) == 0:
                    raise RuntimeError(f"No output node with out degree = 0 found.")
                output_node = None
                if (self.dag_ir.cc >= 90):
                    # For SM90+, the lca should be the input node of D
                    if (not self.dag_ir.has_node("D")):
                        raise RuntimeError(f"D is not a node in the DAG IR.")
                    output_node = "D"
                else:
                    output_node = potential_output_nodes[0]
                if (output_node is None):
                    raise RuntimeError(f"No output node found.")
                lca = self.dag_ir.get_all_inputs(output_node)[0]
                node_to_fuse.remove(output_node)
            # The lca is the output node of the DAG node
            # Get the nodes to be fused
            node_to_fuse.add(lca)
            # Get all the input nodes
            all_input_nodes = []
            all_output_nodes = []
            for node in node_to_fuse:
                all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
                all_output_nodes.append(set(self.dag_ir.get_users(node)))
            all_input_nodes = set.union(*all_input_nodes)
            all_output_nodes = set.union(*all_output_nodes)
            new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
            # Create the subgraph. Boundary nodes (inputs/outputs of the fused
            # region) are copied in but marked `disabled`, so the topological
            # visitor knows they are implemented outside the fused node.
            subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
            subgraph = DAGIR(self.dag_ir.cc)
            for node in subgraph_.nodes:
                meta = deepcopy(self.dag_ir.get_node_meta(node))
                if node not in node_to_fuse:
                    meta.disabled = True
                subgraph.add_node(meta)
            for edge in subgraph_.edges:
                subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
            # Create the fused node
            dag_node = TopoVisitorNode(
                name=f"dag_{lca}", subgraph=subgraph,
                output_node=self.dag_ir.get_node_meta(lca))
            self.dag_ir.add_node(dag_node)
            # Add input edges
            for idx, node in enumerate(all_input_nodes):
                self.dag_ir.add_edge(node, dag_node.name, weight=idx)
            # Replace all uses with DAG node (only 1 output node)
            self.dag_ir.replace_all_uses_with(lca, dag_node.name)
            # Remove all fused nodes
            node_to_fuse.remove(lca)
            for node in node_to_fuse:
                self.dag_ir.remove_node(node)

    def ensures(self) -> None:
        # Ensure that after the pass, the resulting DAG becomes a tree
        for node in self.dag_ir.nodes:
            out_degree = self.dag_ir.out_degree(node)
            if out_degree > 1:
                raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")

View File

@ -0,0 +1,64 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Fix the element_output of the producer of D.
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
element converter, so the compute node producing D must have element_output = type(D).
"""
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassFixElementD(EVTPassBase):
    """
    In the Sm90 epilogue visitor, the node writing D to gmem has no internal
    element converter, so the compute node that produces D must be forced to
    element_output = type(D).
    """
    dependencies = [
        PassLayoutManipulateElimination
    ]

    def get_producer(self, node, element_D):
        """Walk upward through store nodes and set element_output on the
        compute node that ultimately produces ``node``."""
        node_meta = self.dag_ir.get_node_meta(node)
        if node_meta.op == "compute":
            node_meta.element_output = element_D
        elif node_meta.op == "store":
            # Stores are transparent here: recurse into the single input.
            producer = self.dag_ir.get_all_inputs(node)[0]
            self.get_producer(producer, element_D)

    def call(self):
        if not self.dag_ir.has_node("D"):
            return
        element_D = self.dag_ir.get_node_meta("D").store_tensor.element
        self.get_producer("D", element_D)

View File

@ -0,0 +1,90 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Infer the underlying implementation of each node.
While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have different underlying implementation based
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying impl of each node
"""
import cutlass_cppgen.backend.evt.backend as evt_backend
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.util import cc_map
class PassGetImpl(EVTPassBase):
    """
    While the frontend only distinguishes between Load/Store/Compute nodes,
    each node can have a different underlying implementation based on its
    layout (e.g. a LoadNode may be AuxLoad, Row/Col/Scalar broadcast, ...).
    This pass infers the underlying impl of each node.
    """
    dependencies = [
        PassShapeTypePropagation,  # The shape and type info are required for inference
        PassFixElementD
    ]

    def __init__(self, dag_ir: DAGIR) -> None:
        super().__init__(dag_ir)
        # Sub-pass used in ensures() to strip impls lowered to NoOp
        self.no_op_elimination = PassNoOpElimination(dag_ir)

    def requires(self) -> None:
        # Verify "accum" is in the arg list
        if not self.dag_ir.has_node("accum"):
            raise SyntaxError("Cannot find 'accum' in the argument list.")

    def call(self):
        # The epilogue loop structure is dictated by the accumulator shape
        accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
        problem_size = accumulator.tensor.shape
        for node_meta in self.dag_ir.node_metas_topological_order():
            node_meta.get_underlying_impl(problem_size)

    def ensures(self) -> None:
        # Some nodes are lowered to NoOp; eliminate them first
        self.no_op_elimination()
        # Wrap every impl in its cc-specific counterpart, e.g. Sm90<ImplName>
        cc = cc_map[self.cc]
        node_impl_ccs = getattr(evt_backend, f"sm{cc}_nodes")
        for node_meta in self.dag_ir.nodes_meta:
            impl_cls = getattr(
                node_impl_ccs,
                f"Sm{cc}" + node_meta.underlying_impl.__class__.__name__)
            node_meta.underlying_impl = impl_cls(node_meta)

View File

@ -0,0 +1,217 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Eliminate layout manipulation nodes
"""
from copy import deepcopy
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassLayoutManipulateElimination(EVTPassBase):
    """
    Eliminate layout manipulation nodes
    """
    dependencies = [PassShapeTypePropagation]

    def __init__(self, dag_ir: DAGIR) -> None:
        super().__init__(dag_ir)
        # Monotonic counter used to give copied layout nodes unique names
        self.copy_cnt = 0

    def call(self):
        """Pop layout nodes off a worklist, propagate each one away from
        'accum', then splice it out of the graph. Propagation may append
        fresh copies to the worklist, so loop until it drains."""
        self.layout_nodes_worklist = self.get_all_layout_nodes()
        # Run while loop until all layout nodes are eliminated
        while(len(self.layout_nodes_worklist) > 0):
            node = self.layout_nodes_worklist.pop(0)
            # Step 1: get the propagation direction
            direction = self.get_propagation_direction(node)
            self.visited = []
            getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
            # Eliminate the current node: its single input takes over its users
            input_node = self.dag_ir.get_all_inputs(node)[0]
            self.dag_ir.replace_all_uses_with(node, input_node)

    def get_all_layout_nodes(self):
        # Collected in reverse topological order
        layout_nodes = []
        for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
            if isinstance(node_meta, LayoutNode):
                layout_nodes.append(node_meta.name)
        return layout_nodes

    def get_propagation_direction(self, node: str):
        """
        The logic is propagating all layout nodes away from the accumulator node.
        Returns "inputs" or "users" depending on which sweep reaches "accum".
        """
        self.visited = []
        self.get_influenced_users(node)
        nodes_influenced_dir_users = self.visited
        self.visited = []
        self.get_influenced_inputs(node)
        nodes_influenced_dir_inputs = self.visited
        if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
            return "inputs"
        elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
            return "users"
        else:
            raise RuntimeError("Unsolved propagation direction")

    # Get all influenced nodes if we propagate along the user direction
    def get_influenced_users(self, node: str):
        if node in self.visited:
            return
        self.visited.append(node)
        users = self.dag_ir.get_users(node)
        for user in users:
            self.get_influenced_users(user)
        # Sibling inputs of this node's users are influenced in the
        # opposite (input) direction
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            for input in user_inputs:
                self.get_influenced_inputs(input)

    # Get all influenced nodes if we propagate along the input direction
    def get_influenced_inputs(self, node: str):
        if node in self.visited:
            return
        self.visited.append(node)
        inputs = self.dag_ir.get_all_inputs(node)
        for input in inputs:
            self.get_influenced_inputs(input)
        # Sibling users of this node's inputs are influenced in the
        # opposite (user) direction
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            for user in input_users:
                self.get_influenced_users(user)

    def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
        """Insert a fresh copy of the layout node between ``target`` and all
        of its inputs; the copy is queued for later elimination."""
        copied_node_meta = deepcopy(layout_node_meta)
        copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
        self.copy_cnt += 1
        copied_node_meta.name = copied_node
        self.dag_ir.add_node(copied_node_meta)
        # Add edges
        target_inputs = self.dag_ir.get_all_inputs(target)
        for src in target_inputs:
            self.dag_ir.remove_edge(src, target)
            self.dag_ir.add_edge(src, copied_node)
        self.dag_ir.add_edge(copied_node, target)
        self.layout_nodes_worklist.append(copied_node)

    def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
        """Insert a fresh copy of the layout node between ``target`` and all
        of its users; the copy is queued for later elimination."""
        copied_node_meta = deepcopy(layout_node_meta)
        copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
        self.copy_cnt += 1
        copied_node_meta.name = copied_node
        self.dag_ir.add_node(copied_node_meta)
        # Add edges
        users = self.dag_ir.get_users(target)
        for user in users:
            self.dag_ir.remove_edge(target, user)
            self.dag_ir.add_edge(copied_node, user)
        self.dag_ir.add_edge(target, copied_node)
        self.layout_nodes_worklist.append(copied_node)

    # Propagate the layout `node` along the user direction
    def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate layout node to users
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)
        node_meta = self.dag_ir.get_node_meta(node)
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # Layout node is not transparent with layout node
                self.add_copy_before(layout_node_meta, node)
                return
            else:
                layout_node_meta.apply_to_user(node_meta)
        users = self.dag_ir.get_users(node)
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        for user in users:
            self.propagate_to_users(layout_node_meta, user)
        # Sibling inputs of the users see the inverse transformation
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            for input in user_inputs:
                self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)

    # Propagate the layout `node` along the input direction
    def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate layout node to inputs
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)
        node_meta = self.dag_ir.get_node_meta(node)
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # Layout node is not transparent with layout node
                self.add_copy_after(layout_node_meta, node)
                return
            else:
                layout_node_meta.apply_to_input(node_meta)
        inputs = self.dag_ir.get_all_inputs(node)
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        for input in inputs:
            self.propagate_to_inputs(layout_node_meta, input)
        # Sibling users of the inputs see the inverse transformation
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            for user in input_users:
                self.propagate_to_users(layout_node_meta.get_inverse_node(), user)

View File

@ -0,0 +1,164 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Pass manager for DAG IR.
"""
from typing import Any
import networkx as nx
from cutlass_cppgen.backend.evt.ir import DAGIR
from cutlass_cppgen.backend.evt.passes.util import cc_map
class EVTPassBase:
    """Common base class for EVT compiler passes.

    Subclasses override ``call`` and optionally the ``requires``/``ensures``
    hooks, which run immediately before and after ``call`` when the pass
    instance is invoked.
    """
    dependencies = []

    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    def requires(self) -> None:
        """Pre-condition hook executed before the pass runs."""
        pass

    def call(self) -> None:
        """Run the pass over ``self.dag_ir``; must be overridden."""
        raise NotImplementedError(
            f"__call__ is not overwritten in Pass {self.__class__.__name__}")

    def ensures(self) -> None:
        """Post-condition hook executed after the pass runs."""
        pass

    def __call__(self) -> Any:
        self.requires()
        self.call()
        self.ensures()

    def cc_specific_method(self, func):
        """Dispatch ``func`` to its compute-capability-specific variant.

        Looks up ``sm{cc}_{func.__name__}`` on ``self`` (with the cc mapped
        through ``cc_map``) and returns that bound method, e.g.::

            class ExamplePass(EVTPassBase):
                def call(self):
                    # Automatically selects the smXX_func for the current cc
                    self.cc_specific_method(self.func)()

                def func(self):          # interface stub, can be empty
                    pass

                def sm90_func(self):     # Sm90-specific implementation
                    ...

                def sm80_func(self):     # Sm80-specific implementation
                    ...
        """
        _missing = object()
        variant = getattr(self, f"sm{cc_map[self.cc]}_{func.__name__}", _missing)
        if variant is _missing:
            raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
        return variant
class EVTPassManager(nx.DiGraph):
    """Topological-order pass manager.

    Each registered pass declares its prerequisites in ``dependencies``; the
    manager models the passes as a DAG and launches them in topological order.
    """

    def __init__(self, dag_ir: DAGIR, pass_list):
        super().__init__()
        self.dag_ir = dag_ir
        for pass_cls in pass_list:
            self.add_pass(pass_cls)
        self.sorted_passes = self.schedule()

    def get_callable(self, pass_name):
        """Return the pass instance registered under ``pass_name``."""
        return self.nodes[pass_name]["callable"]

    def add_pass(self, pass_cls):
        """Instantiate and register a pass with the manager.

        :param pass_cls: the class of pass
        :type pass_cls: derived class of EVTPassBase
        """
        self.add_node(pass_cls.__name__, callable=pass_cls(self.dag_ir))

    def schedule(self):
        """Wire dependency edges and topologically sort the registered passes."""
        for pass_name in self.nodes:
            pass_obj = self.get_callable(pass_name)
            for dependency_cls in pass_obj.dependencies:
                self.add_edge(
                    dependency_cls.__name__,
                    type(pass_obj).__name__)
        return list(nx.topological_sort(self))

    def __call__(self) -> Any:
        """Run every registered pass in scheduled order."""
        for pass_name in self.sorted_passes:
            self.get_callable(pass_name)()

View File

@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
No op elimination node
"""
from typing import Any
from cutlass_cppgen.backend.evt.ir import NoOpImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassNoOpElimination(EVTPassBase):
    """
    Removes every node whose underlying impl is NoOpImpl from the DAG IR.
    """
    dependencies = []

    def call(self) -> Any:
        for name in self.dag_ir.nodes_topological_order():
            node_meta = self.dag_ir.get_node_meta(name)
            if isinstance(node_meta.underlying_impl, NoOpImpl):
                # A no-op just forwards its single input, so rewire its
                # users straight to that input and drop the node.
                forwarded = self.dag_ir.get_all_inputs(name)[0]
                self.dag_ir.replace_all_uses_with(name, forwarded)

View File

@ -0,0 +1,97 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Preprocess the reduction nodes.
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
This pass fuses these into a single store node, and then replaces all uses of the
current node with the new store node.
"""
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassPreprocessRed(EVTPassBase):
    """
    Preprocess red nodes
    """

    def call(self):
        """Fuse each Compute(op=(reg_reduce_fn, gmem_reduce_fn)) with its
        single succeeding Store node, producing one reduction store."""
        # Step 1: find the compute nodes with op=red
        red_compute_nodes = []
        for node_meta in self.dag_ir.nodes_meta:
            if isinstance(node_meta, ComputeNode):
                if type(node_meta.fn) == tuple:
                    # To keep the frontend simple, the reduction nodes
                    # are parsed into compute nodes by default
                    # The simple heuristic to distinguish between compute
                    # and reduction node is that compute node is a single function,
                    # while the reduction node is a tuple of functions for
                    # in-register reduction and atomic global memory reduction
                    red_compute_nodes.append(node_meta.name)
        # Step 2: for each compute, merge it with the succeeding store
        for node in red_compute_nodes:
            # Verify
            users = self.dag_ir.get_users(node)
            inputs = self.dag_ir.get_all_inputs(node)
            # Has a single user
            assert len(users) == 1
            assert len(inputs) == 1
            user = users[0]
            input = inputs[0]
            user_meta = self.dag_ir.get_node_meta(user)
            # Must be a store node
            assert isinstance(user_meta, StoreNode)
            # With output degree == 0
            assert self.dag_ir.out_degree(user) == 0
            # Register the reduce op on the store node
            node_meta = self.dag_ir.get_node_meta(node)
            user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
            user_meta.element_compute = node_meta.element_compute
            user_meta.round_style = node_meta.round_style
            # Replace all uses: detach the compute node, make the store take
            # over every edge its input previously fed, then wire the input
            # directly into the store and drop the compute node
            self.dag_ir.remove_edge(input, node)
            input_users = self.dag_ir.get_users(input)
            for iu in input_users:
                weight = self.dag_ir.get_edge_weight(input, iu)
                self.dag_ir.add_edge(user, iu, weight)
                self.dag_ir.remove_edge(input, iu)
            self.dag_ir.add_edge(input, user)
            self.dag_ir.remove_node(node)
            # Register the reduction name
            self.dag_ir.reduction_names.append(user)

View File

@ -0,0 +1,59 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Shape and type propagation pass
"""
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
class PassShapeTypePropagation(EVTPassBase):
    """
    Propagate the shape and type of all nodes
    """

    dependencies = [PassPreprocessRed]

    def call(self):
        # Forward sweep in topological order: each node derives its element
        # type and logical shape from its inputs' metadata.
        for name in self.dag_ir.nodes_topological_order():
            meta: NodeBase = self.dag_ir.get_node_meta(name)
            producer_metas = self.dag_ir.get_all_inputs_meta(name)
            meta.type_propagation(producer_metas)
            meta.shape_propagation(producer_metas)
        # Backward sweep in reverse topological order: push broadcasting
        # requirements from consumers back toward producers.
        for name in reversed(self.dag_ir.nodes_topological_order()):
            meta: NodeBase = self.dag_ir.get_node_meta(name)
            producer_metas = self.dag_ir.get_all_inputs_meta(name)
            meta.broadcast_propagation(producer_metas)

View File

@ -0,0 +1,319 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Compute the shared memory size in bytes
"""
from math import gcd
import cutlass_library
from pycute import flatten, shape_div, product
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass_cppgen.backend.library import DataType, DataTypeSize
class GetSmemSize:
    """
    Get the size in byte of shared memory used by the kernel

    Callable object: ``GetSmemSize(dag_ir)(tile_description)`` dispatches to
    the ``sm{cc}_epilogue_smem_size`` method for the DAG IR's compute
    capability and returns the epilogue's shared-memory footprint in bytes.
    """
    def __init__(self, dag_ir: DAGIR) -> None:
        # DAG IR describing the epilogue visitor tree
        self.dag_ir = dag_ir
        # Compute capability; selects the sm90/sm100 code path in __call__
        self.cc = self.dag_ir.cc

    #
    # Sm90 epilogue specific
    #

    def sm90_epilogue_tile(self, tile_description):
        """
        Compute and record on ``self`` the SM90 epilogue subtile shape and the
        C/D shared-memory stage counts (consumed by
        ``sm90_or_sm100_epilogue_smem_size``).
        """
        # Get the epilogue tile size
        schedule = tile_description.epilogue_schedule
        if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
            element_d = self.dag_ir.get_node_meta("D").element
            # 8-bit D with N divisible by 64 prefers a wider (64) epilogue tile
            nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
            epi_tile_m = min(64, tile_description.threadblock_shape[0])
            # gcd ensures the epilogue tile evenly divides the CTA tile's N mode
            epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
            epilogue_tile_mn = (epi_tile_m, epi_tile_n)
        elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
            epi_tile_m = min(128, tile_description.threadblock_shape[0])
            epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
            epilogue_tile_mn = (epi_tile_m, epi_tile_n)
        else:
            raise NotImplementedError(f"Unsupported schedule: {schedule}")
        # Get the pipeline stages
        stages_d = 2
        # Number of epilogue subtiles per CTA tile
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
        else:
            element_c = None
        element_d = self.dag_ir.get_node_meta("D").element
        # C's smem buffers can be reused for D only when element types match
        if element_c == element_d:
            reuse_smem_c = True
        else:
            reuse_smem_c = False
        # When reused, C must keep at least one stage more than D is writing
        stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
        # Record the epilogue tile
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = element_c is not None

    def sm90_or_sm100_epilogue_smem_size(self, tile_description):
        """
        Sum the epilogue's shared-memory usage (fusion storage + C/D tensor
        stages + pipeline barriers), emulating the C++ struct layout.
        One of ``sm90_epilogue_tile`` / ``sm100_epilogue_tile`` must have run.
        """
        # Get the Fusion Storage
        nodes = self.dag_ir.nodes_topological_order()
        self.smem_types = {}  # node name -> (size in bytes, alignment)
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.smem_types[node] = meta.underlying_impl.get_smem_size(
                    self.cta_tile_mnk, self.epilogue_tile_mn,
                    self.stages_c, self.stages_d, self.epi_tiles)
            if node == "D":
                # D's tensor storage is accounted for separately below
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_smem_type(node)
            else:
                self.get_evt_smem_type(node)
        # Root of the visitor tree is the single producer feeding the D store
        thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
        # Get the Tensor Storage
        tensors = []  # (size in bytes, alignment) members of TensorStorage
        if self.is_source_supported:
            smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
            tensors.append((smem_C, 128))
        else:
            tensors.append((0, 1))
        if self.reuse_smem_c:
            # D reuses C's buffers and contributes no additional storage
            tensors.append((0, 128))
        else:
            smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
            tensors.append((smem_D, 128))
        tensors.append((thread_smem_size, 128))
        tensor_smem_size = self.get_struct_size(tensors)
        # Get pipeline storage size
        # sizeof(uint64_t * stages_c * 2), alignment of uint64_t
        # 2 is for FullBarrier and EmptyBarrier
        pipeline_smem_size = (8 * self.stages_c * 2, 8)
        # get SharedStorage size
        smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
        return smem_size[0]

    def sm90_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of sm90 collective epilogue
        """
        self.sm90_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)

    #
    # Sm100 epilogue specific
    #

    def sm100_epilogue_tile(self, tile_description):
        """
        Compute and record on ``self`` the SM100 epilogue subtile shape and
        C/D stage counts (mirrors the sm90 variant with Blackwell heuristics).
        """
        cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
        mma_tile = cta_tile
        if tile_description.is_2sm:
            # 2SM MMA: each participating CTA owns half of the M mode
            cta_tile = (cta_tile[0] // 2, cta_tile[1])
        if tile_description.is_2sm and mma_tile[0] == 128:
            tmem_warps = (2, 2)
        else:
            tmem_warps = (4, 1)
        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
            element_c_size = DataTypeSize[element_c]
        else:
            element_c = None
            element_c_size = 0
        element_d = self.dag_ir.get_node_meta("D").element
        # The source (C) tensor is disabled when it is absent or declared void
        DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void
        CtaM = cta_tile[0]
        CtaN = cta_tile[1]
        WarpM = tmem_warps[0]
        WarpN = tmem_warps[1]
        MaxBits = max(element_c_size, DataTypeSize[element_d])
        DpFull = 32
        M = min(CtaM, DpFull * WarpM)
        if DisableSource:
            # Epilogues w/o residual load are less sensitive to smem allocation
            # Target a fixed amount of compute per epilogue iteration
            if MaxBits == 4:
                # Make epilogue tile larger to reduce the epilogue iterations.
                # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
                ComputeElts = 8192
                Nperf = ComputeElts // M
            else:
                ComputeElts = 4096
                Nperf = ComputeElts // M
        else:
            # Epilogues w/ residual load are more sensitive to smem allocation
            # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
            if MaxBits == 32:
                Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
            elif MaxBits == 16:
                Nperf = 32 if CtaN <= 128 else 64
            else:
                Nperf = 64
        def is_m_major(layout):
            # Unit stride in the (flattened) first mode means M-major
            return flatten(layout.stride[0]) == 1
        # Minimum N per tensor so each row access is reasonably wide;
        # 128 // element_size targets 128-bit accesses for N-major layouts
        if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
            N_min_C = 8 * WarpN
        elif element_c_size == 6:
            N_min_C = 128 * WarpN
        else:
            N_min_C = (128 // element_c_size) * WarpN
        if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
            N_min_D = 8 * WarpN
        elif DataTypeSize[element_d] == 6:
            N_min_D = 128 * WarpN
        else:
            N_min_D = (128 // DataTypeSize[element_d]) * WarpN
        N = min(CtaN, max(Nperf, N_min_C, N_min_D))
        tile_m = M
        # Round N down to a multiple of the warp count along N
        tile_n_size = N // WarpN * WarpN
        epilogue_tile_mn = (tile_m, tile_n_size)
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
        stages_d = min(epi_tiles, 2)
        reuse_smem_c = (element_c_size > 8)
        if reuse_smem_c:
            stages_c = max(min(epi_tiles, 4), stages_d + 1)
        else:
            stages_c = min(epi_tiles, 4)
        # Record the epilogue tile
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = not DisableSource

    def sm100_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of sm100 collective epilogue
        """
        self.sm100_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)

    def __call__(self, tile_description):
        # Dispatch on compute capability (sm90_... / sm100_...)
        return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)

    #
    # Helper functions
    #

    @staticmethod
    def get_visitor_size(members: list, ebo: bool):
        """
        Get the size of struct in bytes

        :param members: (size in bytes, alignment) pairs of the struct members
        :param ebo: emulate empty-base optimization: zero-size members occupy
                    no byte when True
        :return: (size, alignment) of the aggregate struct
        """
        offset = 0
        max_alignment = 1
        if len(members) > 0:
            # Get alignment
            for _, alignment in members:
                max_alignment = max(max_alignment, alignment)
            for type_size, _ in members:
                if type_size != 0:
                    # Align each non-empty member to the struct alignment
                    offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
                if type_size == 0 and not ebo:
                    # Without EBO an empty member still occupies one byte
                    offset += 1
                else:
                    offset += type_size
            # Pad the tail so the total size is a multiple of the alignment
            offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
            return (offset, max_alignment)
        else:
            # Struct size is at least 1
            return (1, 1)

    def get_struct_size(self, members: list):
        """
        Get the size of struct in bytes
        """
        return self.get_visitor_size(members, False)

    def get_evt_smem_type(self, node):
        """Aggregate a visitor node's smem (its inputs' plus its own) into one struct."""
        # Sort the input nodes by edge weight
        input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
        input_types.append(self.smem_types[node])
        if len(input_types) > 1:
            # NOTE(review): EBO is assumed for visitors with more than 4
            # members, matching the C++ emitter -- confirm the threshold
            ebo = len(input_types) > 4
            self.smem_types[node] = self.get_visitor_size(input_types, ebo)

    def get_dag_smem_type(self, node):
        """Aggregate the smem of a topological-visitor node's entire subgraph."""
        meta = self.dag_ir.get_node_meta(node)
        subgraph = meta.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Visit the unvisited nodes in subgraph
        for n in subgraph_nodes:
            m = subgraph.get_node_meta(n)
            if m.disabled:
                continue
            else:
                self.smem_types[n] = m.underlying_impl.get_smem_size(
                    self.cta_tile_mnk, self.epilogue_tile_mn,
                    self.stages_c, self.stages_d, self.epi_tiles)
        # The last topological node is the subgraph output; the rest are members
        input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
        if len(input_types) > 0:
            ebo = len(input_types) > 4
            self.smem_types[node] = self.get_visitor_size(input_types, ebo)

View File

@ -0,0 +1,46 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for passes
"""
# Map from the CC of the kernel to the EVT implementation that the CC targets
cc_map = {
    80: 80,    # SM80, SM86, and SM89 all lower to the SM80 EVT implementation
    86: 80,
    89: 80,
    90: 90,    # SM90 has a dedicated implementation
    100: 100,  # SM100, SM101, and SM103 all lower to the SM100 implementation
    101: 100,
    103: 100,
}

View File

@ -0,0 +1,109 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
import numpy as np
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
class NumpyFrontend:
    """
    Frontend node for numpy
    """

    @staticmethod
    def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
        """Convert the input numpy tensor to CUDA device pointer

        :param np_tensor: input numpy nd array
        :param is_output: whether the tensor is output

        :return: CUDA device pointer
        """
        # Inputs must have their host contents copied to the device; outputs
        # only need device space of the matching byte size.
        if not is_output:
            return todevice(np_tensor)
        return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
class TorchFrontend:
    """
    Frontend node for torch
    """

    @staticmethod
    def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
        """Convert the input torch tensor to CUDA device pointer

        :param torch_tensor: input torch tensor

        :return: CUDA device pointer
        """
        # Migrate host-resident tensors to the GPU before taking the address
        if not torch_tensor.is_cuda:
            torch_tensor = torch_tensor.to("cuda")
        return cuda.CUdeviceptr(torch_tensor.data_ptr())
class CupyFrontend:
    """
    Frontend node for cupy
    """

    @staticmethod
    def argument(cupy_ndarray: "cp.ndarray"):
        """Convert a cupy ndarray's data pointer to a CUDA device pointer."""
        raw_ptr = int(cupy_ndarray.data.ptr)
        return cuda.CUdeviceptr(raw_ptr)
class TensorFrontend:
    """
    Universal frontend for client-provided tensors
    """

    @staticmethod
    def argument(tensor, is_output=False):
        """Dispatch `tensor` to the matching framework frontend (numpy, torch,
        or cupy) and return a CUDA device pointer for it."""
        if is_numpy_tensor(tensor):
            return NumpyFrontend.argument(tensor, is_output)
        if is_torch_tensor(tensor):
            return TorchFrontend.argument(tensor)
        if is_cupy_tensor(tensor):
            return CupyFrontend.argument(tensor)
        raise NotImplementedError("Unknown Tensor Type")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,509 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Common data types and string names/tags for them
"""
import enum
from cutlass_library import (
    ComplexTransform,
    DataType,
    DataTypeSize,
    EpilogueScheduleType,
    KernelScheduleSuffixes,
    KernelScheduleType,
    MathOperation,
    OpcodeClass,
    OperationKind,
    TileSchedulerType
)
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
try:
    from enum import auto as enum_auto
except ImportError:
    # Module-level counter backing the fallback implementation below
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """Fallback for enum.auto(): returns a monotonically increasing int per call."""
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i
class DataTypeSizeBytes:
    """
    Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
    data type key is less than a full byte or a non-integer number of bytes.
    """

    @staticmethod
    def __class_getitem__(datatype):
        """
        Returns the number of bytes in size the data type is. Raises an exception if the data type
        is either less than a full byte or a non-integer number of bytes in size.

        :param datatype: data type to query

        :return: number of bytes the data type occupies
        :rtype: int
        """
        bits = DataTypeSize[datatype]
        if bits < 8:
            raise Exception(
                f"Data type {datatype} is less than one byte in size."
            )
        elif bits % 8 != 0:
            # Bug fix: the original message was missing the braces around
            # `datatype`, so the offending type was never interpolated.
            raise Exception(
                f"Data type {datatype} is not an integer number of bytes."
            )
        return bits // 8
class SchedulerMode(enum.Enum):
    """Group GEMM problem-visitation modes (mirrors cutlass::gemm::kernel::GroupScheduleMode)."""
    Device = enum_auto()  # schedule computed on device (kDeviceOnly)
    Host = enum_auto()    # schedule precomputed on host (kHostPrecompute)


# C++ type tag emitted for each scheduler mode
SchedulerModeTag = {
    SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
    SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}

# Abbreviated per-mode names
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
class FunctionalOp(enum.Enum):
    """Elementwise/reduction functional operators selectable in epilogues."""
    AtomicAdd = enum_auto()
    AtomicMaximum = enum_auto()
    Divides = enum_auto()
    Maximum = enum_auto()
    Minimum = enum_auto()
    Minus = enum_auto()
    Multiplies = enum_auto()
    MultiplyAdd = enum_auto()
    Plus = enum_auto()
    Exp = enum_auto()


# C++ functor tag emitted for each FunctionalOp
FunctionalOpTag = {
    FunctionalOp.AtomicAdd: "cutlass::atomic_add",
    FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
    FunctionalOp.Divides: "cutlass::divides",
    FunctionalOp.Maximum: "cutlass::maximum",
    FunctionalOp.Minimum: "cutlass::minimum",
    FunctionalOp.Minus: "cutlass::minus",
    FunctionalOp.Multiplies: "cutlass::multiplies",
    FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
    FunctionalOp.Plus: "cutlass::plus",
    FunctionalOp.Exp: "cutlass::fast_exp_op",
}
class ActivationOp(enum.Enum):
    """Activation functions selectable for epilogue fusion."""
    DGelu = enum_auto()
    Gelu = enum_auto()
    GeluTaylor = enum_auto()
    HardSwish = enum_auto()
    Identity = enum_auto()
    LeakyReLU = enum_auto()
    ReLU = enum_auto()
    Sigmoid = enum_auto()
    SiLU = enum_auto()
    Tanh = enum_auto()


# C++ epilogue-thread functor tag emitted for each ActivationOp
ActivationOpTag = {
    ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
    ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
    ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
    ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
    ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
    ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
    ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
    ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
    ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
    ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}
def op_tag(op) -> str:
    """
    Dispatches `op` to the appropriate *Tag dictionary depending on whether
    `op` is an ActivationOp or FunctionalOp. This is useful for cases in which
    either type can be used.

    :param op: operation to emit a tag for
    :type op: ActivationOp | FunctionalOp

    :return: tag corresponding to op
    :rtype: str
    """
    # Table-driven dispatch: first matching enum type wins
    for op_type, tag_map in ((ActivationOp, ActivationOpTag), (FunctionalOp, FunctionalOpTag)):
        if isinstance(op, op_type):
            return tag_map[op]
    raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
class FloatRoundStyle(enum.Enum):
    """Floating-point rounding styles (mirrors cutlass::FloatRoundStyle)."""
    ToNearest = enum_auto()
    ToNearestSatfinite = enum_auto()
    Indeterminate = enum_auto()
    TowardZero = enum_auto()
    TowardInfinity = enum_auto()
    TowardNegInfinity = enum_auto()
    HalfUlpTruncDntz = enum_auto()
    HalfUlpTruncate = enum_auto()


# C++ enumerant emitted for each FloatRoundStyle
FloatRoundStyleTag = {
    FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
    FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
    FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
    FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
    FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
    FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
    FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
    FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
}
class MathInstruction:
    """
    Describes the lowest-level matrix-multiply-accumulate operation a kernel is built from
    """
    def __init__(
        self,
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class=OpcodeClass.Simt,
        math_operation=MathOperation.multiply_add,
    ):
        """
        :param instruction_shape: size of the [M, N, K] dimensions of the instruction
        :type instruction_shape: list or tuple
        :param element_a: data type of operand A
        :param element_b: data type of operand B
        :param element_accumulator: data type used in accumulation
        :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
        :type opcode_class: cutlass_library.library.OpcodeClass
        :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
        :type math_operation: MathOperation
        """
        # Plain record type: store each field as given
        self.math_operation = math_operation
        self.opcode_class = opcode_class
        self.element_accumulator = element_accumulator
        self.element_b = element_b
        self.element_a = element_a
        self.instruction_shape = instruction_shape
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
    """
    Derive the per-CTA (Blackwell) threadblock shape from a tile description.

    :param tile_description: tile description carrying the threadblock shape and math instruction
    :param cluster_shape: [X, Y, Z] cluster dimensions; X <= 0 means no cluster information
    :param kernel_schedule: kernel schedule (or None); a "2sm" suffix marks a two-SM schedule

    :return: (per-CTA threadblock shape, whether the schedule is 2SM)
    """
    is_2sm = kernel_schedule is not None and "2sm" in KernelScheduleSuffixes[kernel_schedule]
    if cluster_shape[0] <= 0:
        # No usable cluster shape: fall back to the raw MMA instruction shape
        return tile_description.math_instruction.instruction_shape, is_2sm
    # Divide the cluster-wide tile evenly among the CTAs of the cluster
    per_cta = [dim // count for dim, count in zip(tile_description.threadblock_shape, cluster_shape)]
    if is_2sm:
        # A 2SM MMA spans two CTAs along M
        per_cta[0] *= 2
    return per_cta, is_2sm
class TileDescription:
    """
    Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
    stage count, and math instruction specification
    """
    def __init__(
        self,
        threadblock_shape,
        stages,
        warp_count,
        math_instruction,
        cluster_shape=None,
        kernel_schedule: KernelScheduleType = None,
        epilogue_schedule: EpilogueScheduleType = None,
        tile_scheduler: TileSchedulerType = None
    ):
        """
        :param threadblock_shape: shape of a threadblock tile
        :type threadblock_shape: list or tuple
        :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
                       number of stages that can be supported for an operation on a given architecture will be computed at a later time
        :type stages: int or None
        :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
        :type warp_count: list, tuple, or None
        :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
        :type math_instruction: MathInstruction
        :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster (defaults to [1, 1, 1])
        :param kernel_schedule: type of kernel schedule to use (only available for SM90+)
        :type kernel_schedule: cutlass_library.KernelScheduleType
        :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
        :type epilogue_schedule: cutlass_library.EpilogueScheduleType
        :param tile_scheduler: type of tile scheduler to use (only available for SM90+)
        :type tile_scheduler: cutlass_library.TileSchedulerType
        """
        # Avoid a shared mutable default argument; None means the default [1, 1, 1] cluster
        if cluster_shape is None:
            cluster_shape = [1, 1, 1]
        # Kernel and epilogue schedules must be specified together (or both left Auto)
        if ((kernel_schedule is None and epilogue_schedule is not None) or
            (kernel_schedule is not None and epilogue_schedule is None)):
            raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")

        self.threadblock_shape = threadblock_shape
        self.cluster_shape = cluster_shape
        self.kernel_schedule = kernel_schedule
        self.epilogue_schedule = epilogue_schedule
        self.tile_scheduler = tile_scheduler
        self.stages = stages

        self.math_instruction = math_instruction
        self.instruction_shape = math_instruction.instruction_shape

        # Number of warps along x, y, z directions
        self.warp_count = warp_count

        # Per-CTA (Blackwell) tile shape and whether the schedule is a 2SM schedule
        self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)

    def clone_and_update(self, td: dict):
        """
        Returns a copy of this tile description with any attributes named in `td` overridden.

        :param td: mapping from attribute name to replacement value
        :type td: dict
        :return: updated copy of this tile description
        :rtype: TileDescription
        """
        keys = (
            "cluster_shape",
            "threadblock_shape",
            "warp_count",
            "stages",
            "instruction_shape",
            "kernel_schedule",
            "epilogue_schedule",
            "tile_scheduler",
        )
        # Fall back to this description's current value for any key not overridden
        attrs = {key: td[key] if key in td else getattr(self, key) for key in keys}
        attrs["math_instruction"] = MathInstruction(
            attrs["instruction_shape"],
            self.math_instruction.element_a,
            self.math_instruction.element_b,
            self.math_instruction.element_accumulator,
            self.math_instruction.opcode_class,
            self.math_instruction.math_operation
        )
        # Remove the instruction shape: it is carried by the math instruction
        del attrs["instruction_shape"]
        return TileDescription(**attrs)

    @property
    def num_threads(self):
        """
        Returns the number of threads in the threadblock

        :return: number of threads in the threadblock
        :rtype: int or None (if warp count is None)
        """
        if self.warp_count is not None:
            threads = 32
            for cnt in self.warp_count:
                threads *= cnt
            return threads
        return None

    def procedural_name(self):
        """
        Returns a name identifying the tile description

        :return: name identifying the tile description
        :rtype: str
        """
        emit_stages = 0 if self.stages is None else self.stages
        name = "%dx%dx%d_%dx%d_%dx%d" % (
            self.cluster_shape[0],
            self.cluster_shape[1],
            self.cluster_shape[2],
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            emit_stages
        )
        return name

    def procedural_name_2x(self):
        """
        Returns a name identifying the tile description (CUTLASS 2.x format)

        :return: name identifying the tile description
        :rtype: str
        """
        return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)

    def __str__(self):
        """
        Returns a string containing each of the tile description's values

        :return: contents of tile description
        :rtype: str
        """
        if self.kernel_schedule is not None:
            kschedule = self.kernel_schedule
        else:
            kschedule = KernelScheduleType.ScheduleAuto

        if self.epilogue_schedule is not None:
            eschedule = self.epilogue_schedule
        else:
            eschedule = EpilogueScheduleType.ScheduleAuto

        if self.tile_scheduler is not None:
            tschedule = self.tile_scheduler.name
        else:
            tschedule = "None"
        # Bug fix: the epilogue-schedule line previously printed kschedule.name
        return f"""
{{
  ClusterShape: {self.cluster_shape}
  ThreadblockShape: {self.threadblock_shape}
  WarpCount: {self.warp_count}
  Stages: {self.stages if self.stages is not None else 'Auto'}
  InstructionShape: {self.math_instruction.instruction_shape}
  Kernel schedule: {kschedule.name}
  Epilogue schedule: {eschedule.name}
  TileScheduler: {tschedule}
}}"""
class TensorDescription:
    """
    Descriptor for a tensor operand: element type, layout, alignment, and complex transform.
    """

    def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
        self.element = element
        self.layout = layout
        self.complex_transform = complex_transform
        if element == DataType.void:
            # void elements have no size; keep the requested alignment as-is
            self.alignment = alignment
        else:
            # Cap alignment so that a single access never exceeds 128 bits
            self.alignment = min(128 // DataTypeSize[self.element], alignment)
def CalculateSmemUsagePerStage(operation):
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param operation: operation for which the per-stage shared memory usage should be computed
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int

    :raises Exception: if the operation is not a GEMM
    """
    if operation.operation_kind != OperationKind.Gemm:
        raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))

    m, n, k = operation.tile_description.threadblock_shape

    # Bytes reserved per stage for the pipeline synchronization barrier
    stage_barrier_bytes = 32
    bytes_a = DataTypeSize[operation.A.element] * m * k // 8
    bytes_b = DataTypeSize[operation.B.element] * k * n // 8
    return bytes_a + bytes_b + stage_barrier_bytes
def CalculateSmemUsage(operation):
    """
    Returns the amount of shared memory in bytes consumed by a kernel.

    :param operation: operation for which the shared memory usage should be computed
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed across all stages
    :rtype: int
    """
    per_stage = CalculateSmemUsagePerStage(operation)
    return operation.tile_description.stages * per_stage
class ApiVersion(enum.Enum):
    """
    Differentiate between CUTLASS 2.x and 3.x API versions
    """
    # CUTLASS 2.x API
    v2x = enum_auto()
    # CUTLASS 3.x API
    v3x = enum_auto()
def api_version(arch, opclass, dtype):
    """
    Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
    or 3.x for code emission.

    :param arch: compute capability of device on which to run
    :type arch: int
    :param opclass: class of the operation being performed
    :type opclass: cutlass_library.OpcodeClass
    :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
    :type dtype: cutlass_library.DataType

    :return: API version to be used in code emission
    :rtype: ApiVersion
    """
    # The 3.x API is used for tensor-op kernels on SM90/SM100-class targets
    # for all data types other than f64
    uses_3x = (
        arch in [90, 100, 101, 103]
        and opclass == OpcodeClass.TensorOp
        and dtype != DataType.f64
    )
    return ApiVersion.v3x if uses_3x else ApiVersion.v2x
class EmissionType(enum.Enum):
    """
    Tags for whether to emit a kernel- or device-level operation
    """
    # Emit a kernel-level operation
    Kernel = enum_auto()
    # Emit a device-level operation
    Device = enum_auto()

View File

@ -0,0 +1,121 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import numpy as np
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.utils.lazy_import import lazy_import
if cutlass_cppgen.use_rmm:
import rmm
else:
cudart = lazy_import("cuda.cudart")
class PoolMemoryManager:
    """
    RMM-backed pooled device-memory manager.

    Builds an RMM pool memory resource on top of the default CUDA memory resource,
    wraps it in a tracking adaptor, and installs it as the current device resource.
    """

    def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
        # Pool allocator backed by the plain CUDA memory resource
        self.pool = rmm.mr.PoolMemoryResource(
            rmm.mr.CudaMemoryResource(),
            initial_pool_size=init_pool_size,
            maximum_pool_size=max_pool_size
        )
        # Tracking adaptor records allocations made from the pool
        self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
        # Make this pool the allocator used for subsequent RMM allocations
        rmm.mr.set_current_device_resource(self.mr)

    def pool_size(self):
        # Current size of the underlying pool in bytes
        return self.pool.pool_size()
class DevicePtrWrapper:
    """
    Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
    (at least in terms of the interface used by the CUTLASS Python interface)
    """

    def __init__(self, dev_ptr):
        # Underlying raw device pointer
        self.dev_ptr = dev_ptr

    @property
    def ptr(self):
        """Returns the wrapped raw device pointer (mirrors DeviceBuffer.ptr)."""
        return self.dev_ptr
def _todevice(host_data):
    """
    Helper for transferring host data to device memory
    """
    raw_bytes = host_data.tobytes()
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer.to_device(raw_bytes)

    # Manual path: allocate device memory, then copy the host bytes over
    nbytes = len(raw_bytes)
    dev_ptr_wrapper = device_mem_alloc(nbytes)
    err, = cudart.cudaMemcpy(
        dev_ptr_wrapper.ptr,
        host_data.__array_interface__['data'][0],
        nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    )
    if err != cudart.cudaError_t.cudaSuccess:
        raise Exception(f"cudaMemcpy failed with error {err}")
    return dev_ptr_wrapper
def todevice(host_data, dtype=np.float32):
    """
    Pass the host_data to device memory
    """
    # Numpy arrays are transferred as-is; lists are converted to a numpy array
    # of the requested dtype first. Other types are ignored (returns None).
    if is_numpy_tensor(host_data):
        return _todevice(host_data)
    if isinstance(host_data, list):
        return _todevice(np.array(host_data, dtype=dtype))
def device_mem_alloc(size):
    """
    Allocates `size` bytes of device memory via RMM when enabled, or cudaMalloc otherwise.
    """
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer(size=size)

    err, ptr = cudart.cudaMalloc(size)
    if err != cudart.cudaError_t.cudaSuccess:
        raise Exception(f"cudaMalloc failed with error {err}")
    return DevicePtrWrapper(ptr)
def align_size(size, alignment=256):
    """
    Rounds `size` up to the nearest multiple of `alignment`.
    """
    # Ceiling division via negation keeps everything in integer arithmetic
    return -(-size // alignment) * alignment
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
    """
    Creates and returns a PoolMemoryManager when RMM is in use; otherwise returns None.
    """
    if not cutlass_cppgen.use_rmm:
        return None
    return PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)

View File

@ -0,0 +1,140 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_cppgen.backend.utils.device import device_cc
_supports_cluster_launch = None
def supports_cluster_launch():
    """
    Returns whether the current device and CUDA Python version support launching kernels
    with cluster dimensions (requires a 90/100/101/103-class compute capability and
    CUDA Python >= 11.8).

    The result is computed once and cached in the module-level `_supports_cluster_launch`.

    :return: whether cluster launches are supported
    :rtype: bool
    """
    global _supports_cluster_launch
    if _supports_cluster_launch is None:
        # Only import and parse the CUDA Python version on the first call;
        # previously this ran on every call despite the cached result.
        from cuda import __version__
        _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
        major, minor = _version_splits[0], _version_splits[1]
        _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
    return _supports_cluster_launch
class LaunchConfiguration:
    """
    Description of the grid/block shape and dynamic shared memory with which to launch a kernel.
    """

    def __init__(self, grid=None, block=None, smem=0):
        """
        :param grid: grid dimensions as [x, y, z] (defaults to [1, 1, 1])
        :param block: block dimensions as [x, y, z] (defaults to [1, 1, 1])
        :param smem: dynamic shared memory capacity in bytes
        """
        # Avoid mutable default arguments: the previous [1, 1, 1] defaults were
        # shared across all instances, so mutating one instance's grid/block
        # would silently change every later default-constructed instance.
        self.grid = [1, 1, 1] if grid is None else grid
        self.block = [1, 1, 1] if block is None else block
        self.shared_memory_capacity = smem
class ExecutableOperation:
    """
    Base class for an operation that can be compiled and launched as a CUDA kernel.

    Subclasses are expected to populate `self.module` and `self.kernel` with the
    compiled CUDA module/function and to implement the workspace and argument
    methods below.
    """

    def __init__(self, operation):
        # Description of the operation this executable implements
        self.operation = operation
        # Compiled CUDA module (populated after compilation)
        self.module = None
        # CUDA kernel handle within the module (populated after compilation)
        self.kernel = None

    def name(self):
        """Returns the procedural name of the underlying operation."""
        return self.operation.procedural_name()

    def emit(self):
        """Returns the C++ source to emit for this operation (none by default)."""
        return ""

    def can_implement(self, configuration, arguments):
        raise NotImplementedError()

    def get_host_workspace_size(self, arguments):
        raise NotImplementedError()

    def get_device_workspace_size(self, arguments):
        raise NotImplementedError()

    def plan(self, arguments):
        raise NotImplementedError()

    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
        raise NotImplementedError()

    def run_with_clusters(self, launch_config, kernel_params, stream=None):
        """
        Launches the kernel via cuLaunchKernelEx, attaching the operation's
        cluster dimensions as a launch attribute when available.

        :return: CUDA result code from the launch (or from setting function attributes)
        """
        if not stream:
            stream = cuda.CUstream(0)

        if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
            attr = cuda.CUlaunchAttribute()
            attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
            attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
            attrs = [attr]

            # Allow for non-portable cluster sizes
            err, = cuda.cuFuncSetAttribute(
                self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
            if err != cuda.CUresult.CUDA_SUCCESS:
                return err
        else:
            attrs = []

        config = cuda.CUlaunchConfig()
        config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
        # (Removed a redundant second assignment of blockDimZ that duplicated
        # the tuple unpacking below.)
        config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
        config.sharedMemBytes = launch_config.shared_memory_capacity
        config.hStream = stream
        config.attrs = attrs
        config.numAttrs = len(attrs)

        err, = cuda.cuLaunchKernelEx(
            config, f=self.kernel, kernelParams=kernel_params, extra=0)
        return err

    def run_without_clusters(self, launch_config, kernel_params, stream=None):
        """
        Launches the kernel via cuLaunchKernel without cluster support.

        :return: CUDA result code from the launch
        """
        if not stream:
            stream = cuda.CUstream(0)

        err, = cuda.cuLaunchKernel(
            self.kernel,
            launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
            launch_config.block[0], launch_config.block[1], launch_config.block[2],
            launch_config.shared_memory_capacity,
            stream,
            kernel_params,
            0)
        return err

    def run(self, host_workspace, device_workspace, launch_config, stream=None):
        """
        Packs the host workspace into kernel parameters and launches the kernel,
        using a cluster launch when the device and CUDA version support it.

        :return: CUDA result code from the launch
        """
        if not stream:
            stream = cuda.CUstream(0)

        # Pack the host workspace bytes into a single kernel parameter pointer
        cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
        packed = (ctypes.c_void_p * 1)()
        packed[0] = ctypes.addressof(cArg)

        if supports_cluster_launch():
            return self.run_with_clusters(launch_config, packed, stream)
        else:
            return self.run_without_clusters(launch_config, packed, stream)

View File

@ -0,0 +1,455 @@
################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from __future__ import annotations
import ctypes
from typing import Union
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
from cutlass_library import (
DataTypeNames,
DataTypeSize,
DataTypeTag,
LayoutType,
SubstituteTemplate
)
import cutlass_cppgen
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.library import TensorDescription
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.shape import MatrixCoord
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
# Forward declaration so annotations below can reference ReductionOperation
# before the full class definition later in this module.
class ReductionOperation:
    pass
class ReductionArguments:
    """
    Arguments of reduction
    """

    def __init__(
        self,
        operation: ReductionOperation,
        problem_size: "list[int]",
        partitions: int,
        workspace: cuda.CUdeviceptr,
        destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        **kwargs,
    ) -> None:
        """
        :param operation: reduction operation these arguments belong to
        :param problem_size: [rows, columns] of the reduction problem
        :param partitions: number of split-k partitions to reduce
        :param workspace: device pointer to the split-k workspace
        :param destination: output tensor (device pointer, numpy array, or torch tensor)
        :param source: source tensor (tensor_C); interpreted as a bias vector if bias=True
        """
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        if "bias" in kwargs.keys():
            self.bias = kwargs["bias"]
        else:
            # by default, tensor_C is not bias
            self.bias = False

        # CUDA stream on which to run (defaults to the null stream)
        if "stream" in kwargs.keys():
            self.stream = kwargs["stream"]
        else:
            self.stream = cuda.CUstream(0)

        self.operation = operation
        self.ptr_workspace = workspace

        # number of split-k partitions
        self.partitions = partitions

        if is_numpy_tensor(destination):
            # Keep a handle to the host array so sync() can copy the result back
            self.host_D = destination
            self.destination_buffer = NumpyFrontend.argument(destination, True)
            self.source_buffer = NumpyFrontend.argument(source, False)
            self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
            self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
        elif is_torch_tensor(destination):
            self.ptr_destination = TorchFrontend.argument(destination)
            self.ptr_source = TorchFrontend.argument(source)
        elif isinstance(destination, cuda.CUdeviceptr):
            self.ptr_destination = destination
            self.ptr_source = source
        else:
            raise TypeError("unknown Type")

        self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])
        # Stride (in bytes) between consecutive split-k partitions in the workspace
        self.partition_stride = (
            problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
        )

        if "output_op" in kwargs.keys():
            self.output_op = kwargs["output_op"]
        else:
            # Default epilogue: alpha = 1, beta = 0
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        self.get_arguments()

    @staticmethod
    def get_tensor_ref(
        extent: "tuple[int]",
        device_ptr: cuda.CUdeviceptr,
        layout: LayoutType,
    ):
        """
        Constructs a TensorRef2D_ for a row-major tensor with the given extent.

        :raises ValueError: if a layout other than row-major is requested
        """
        if layout == LayoutType.RowMajor:
            # Row-major leading dimension is the number of columns
            return TensorRef2D_(int(device_ptr), extent[1])
        else:
            raise ValueError(f"Unknown layout type {layout}")

    def get_arguments(self):
        """
        Builds the ctypes argument structure and the host workspace byte array
        used to launch the reduction kernel.
        """
        ref_workspace = ReductionArguments.get_tensor_ref(
            extent=[
                self.problem_size.row,
                self.problem_size.column,
            ],
            device_ptr=self.ptr_workspace,
            layout=LayoutType.RowMajor,
        )
        if self.bias:
            # A bias source uses a zero extent so its stride is zero (broadcast)
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[0, 0],
                device_ptr=self.ptr_source,
                layout=LayoutType.RowMajor,
            )
        else:
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[
                    self.problem_size.row,
                    self.problem_size.column,
                ],
                device_ptr=self.ptr_source,
                layout=LayoutType.RowMajor,
            )
        ref_destination = ReductionArguments.get_tensor_ref(
            extent=[
                self.problem_size.row,
                self.problem_size.column,
            ],
            device_ptr=self.ptr_destination,
            layout=LayoutType.RowMajor,
        )

        self.c_arguments = self.operation.argument_type(
            self.problem_size,
            self.partitions,
            self.partition_stride,
            ref_workspace,
            ref_destination,
            ref_source,
            self.output_op,
        )
        params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
        self.host_workspace = bytearray(params_.contents)

    def sync(self):
        """
        Synchronizes the device and, when the destination is a host numpy array,
        copies the result back to it before freeing temporary device buffers.

        :raises RuntimeError: if synchronization or the copy-back fails
        """
        (err,) = cudart.cudaDeviceSynchronize()
        # cudaDeviceSynchronize returns a cudart cudaError_t; compare against the
        # matching success code (previously compared against cuda.CUresult,
        # mixing enum families — free() below uses the cudart code correctly)
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"CUDA Error {str(err)}")

        if hasattr(self, "host_D"):
            (err,) = cuda.cuMemcpyDtoH(
                self.host_D,
                self.ptr_destination,
                self.host_D.size * self.host_D.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))
        self.free()

    def free(self):
        """
        Frees allocated device-side memory
        """
        # Free any device memory allocated manually
        if not cutlass_cppgen.use_rmm:
            for attr in ["destination_buffer", "source_buffer"]:
                if hasattr(self, attr):
                    buf = getattr(self, attr)
                    if isinstance(buf, DevicePtrWrapper):
                        err, = cudart.cudaFree(buf.ptr)
                        if err != cudart.cudaError_t.cudaSuccess:
                            raise RuntimeError(f"cudaFree failed with error {err}")
                    del buf
class ReductionRT(ExecutableOperation):
    """
    ReductionRT manages the CUTLASS runtime components for reduction
    """

    # CUDA kernel wrapper emitted around the reduction kernel instance; the
    # ${operation_name}/${operation_suffix} placeholders are filled by
    # SubstituteTemplate at emission time.
    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];
  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
  ${operation_name}${operation_suffix} op;
  op(params, *shared_storage);
}
"""
    # Host-side helpers compiled alongside the kernel for querying parameter and
    # shared-memory sizes and for marshalling Params into a byte array.
    HostTemplate = r"""
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }
  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }
  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
      output[i] = bytes[i];
    return output;
  }
}
"""

    def __init__(self, operation: ReductionOperation):
        super().__init__(operation)

        self.operation: ReductionOperation = operation
        self.emitter = EmitReductionInstance("_type")
        # Number of elements processed per reduction access
        self.elements_per_access = self.operation.count
        # ctypes argument/epilogue structures derived from the epilogue functor
        (
            self.argument_type,
            self.epilogue_type,
        ) = get_reduction_params(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type)]

    def emit(self):
        # Emit the C++ source for this operation's reduction kernel instance
        return self.emitter.emit(self.operation)

    def plan(self, arguments: ReductionArguments):
        # One thread handles `elements_per_access` columns; the grid tiles the
        # problem by the operation's threadblock shape (ceiling division).
        block_shape = [
            self.operation.shape.column // self.elements_per_access,
            self.operation.shape.row,
            1,
        ]
        grid_shape = [
            (arguments.problem_size.row + self.operation.shape.row - 1)
            // self.operation.shape.row,
            (arguments.problem_size.column + self.operation.shape.column - 1)
            // self.operation.shape.column,
            1,
        ]
        return LaunchConfiguration(
            grid_shape,
            block_shape,
            self.shared_memory_capacity,
        )

    def initialize(self):
        # Raise the kernel's dynamic shared memory limit to what the operation
        # requires; raises RuntimeError if the attribute cannot be set.
        (err,) = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            value=self.shared_memory_capacity,
        )
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {err}")
class ReductionOperation:
    """
    CUTLASS reduction Operation
    """

    def __init__(
        self,
        shape: MatrixCoord,
        C: TensorDescription,
        element_accumulator,
        element_workspace=None,
        element_compute=None,
        epilogue_functor=None,
        count: int = 1,
        partitions_per_stage: int = 4,
    ) -> None:
        """
        :param shape: threadblock shape used by the reduction kernel
        :type shape: MatrixCoord
        :param C: description of the output tensor
        :type C: TensorDescription
        :param element_accumulator: data type of the accumulator
        :param element_workspace: data type of the split-k workspace (defaults to element_accumulator)
        :param element_compute: data type used for epilogue computation (defaults to element_accumulator)
        :param epilogue_functor: epilogue functor applied after the reduction
        :param count: number of elements processed per reduction access
        :type count: int
        :param partitions_per_stage: number of split-k partitions reduced per stage
        :type partitions_per_stage: int
        """
        self.shape = shape
        self.epilogue_functor = epilogue_functor
        self.element_accumulator = element_accumulator
        if element_workspace is None:
            self.element_workspace = element_accumulator
        else:
            self.element_workspace = element_workspace
        if element_compute is None:
            self.element_compute = element_accumulator
        else:
            self.element_compute = element_compute
        self.element_output = C.element
        self.C: TensorDescription = C

        # Reduce op processing size
        self.count: int = count

        # Number of partitions to reduce per stage
        self.partitions_per_stage: int = partitions_per_stage

        # Runtime module providing compiled kernel and ctypes argument types
        self.rt_module: ReductionRT = ReductionRT(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def extended_name(self):
        """Returns a name encoding the workspace, accumulator, compute, and output data types."""
        extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
        return SubstituteTemplate(
            extend_name,
            {
                "element_workspace": DataTypeNames[self.element_workspace],
                "element_accumulator": DataTypeNames[self.element_accumulator],
                "element_compute": DataTypeNames[self.element_compute],
                "element_output": DataTypeNames[self.element_output],
            },
        )

    def configuration_name(self):
        """The full procedural name indicates architecture, extended name, tile size"""
        configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
        threadblock = "%dx%d" % (
            self.shape.row,
            self.shape.column,
        )
        return SubstituteTemplate(
            configuration_name,
            {
                "extended_name": self.extended_name(),
                "threadblock": threadblock,
            },
        )

    def procedural_name(self):
        """The full procedural name indicates architecture, extended name, tile size"""
        return self.configuration_name()

    def run(self, arguments: ReductionArguments) -> cuda.CUresult:
        """
        Configure and launch the cuda kernel with input arguments

        :raises RuntimeError: if the kernel launch returns a CUDA error
        """
        launch_config = self.rt_module.plan(arguments)

        host_workspace = arguments.host_workspace
        device_workspace = None

        err = self.rt_module.run(
            host_workspace,
            device_workspace,
            launch_config,
            arguments.stream
        )

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

        return err
class EmitReductionInstance:
    """
    Emits the C++ type definition for a split-k reduction kernel instance.
    """

    def __init__(self, operation_suffix="") -> None:
        # Suffix appended to the operation name in the emitted struct definition
        self.operation_suffix = operation_suffix
        # Headers required by the emitted source
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/device/gemm.h",
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
            "cutlass/reduction/kernel/reduce_split_k.h",
            "cutlass/reduction/thread/reduction_operators.h",
        ]
        # Template for the emitted kernel instance; placeholders are filled by
        # SubstituteTemplate in emit().
        self.template = """
// Reduction kernel instance
using ${operation_name}_base =
  typename cutlass::reduction::kernel::ReduceSplitK<
    cutlass::MatrixShape<${shape_row}, ${shape_column}>,
    ${epilogue_functor},
    cutlass::reduction::thread::ReduceAdd<
      ${element_accumulator},
      ${element_output},
      ${count}>,
    ${partition_per_stage}>;
struct ${operation_name}${operation_suffix}:
  public ${operation_name}_base { };
"""

    def emit(self, operation: ReductionOperation):
        """
        Emits the C++ source defining the reduction kernel instance for `operation`.

        :return: emitted C++ source
        :rtype: str
        """
        # Cap the epilogue vector access width at 128 bits
        vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
        epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
        values = {
            "operation_name": operation.configuration_name(),
            "operation_suffix": self.operation_suffix,
            "shape_row": str(operation.shape.row),
            "shape_column": str(operation.shape.column),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "element_output": DataTypeTag[operation.element_output],
            "epilogue_vector_length": str(epilogue_vector_length),
            "element_accumulator": DataTypeTag[operation.element_accumulator],
            "element_compute": DataTypeTag[operation.element_compute],
            "element_workspace": DataTypeTag[operation.element_workspace],
            "count": str(operation.count),
            "partition_per_stage": str(operation.partitions_per_stage),
        }
        return SubstituteTemplate(self.template, values)

View File

@ -0,0 +1,35 @@
################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
# String type aliases used in annotations elsewhere in the package. They are
# quoted so that the referenced modules (cuda, numpy, torch, cupy, and the
# GEMM operation classes) need not be imported at runtime.
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"
Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"

View File

@ -0,0 +1,33 @@
################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc

View File

@ -0,0 +1,126 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for interacting with the device
"""
from __future__ import annotations
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
def check_cuda_errors(result: list):
    """
    Checks whether `result` contains a CUDA error and raises the error as an exception,
    if so. Otherwise, returns the result contained in the remaining fields of `result`.

    :param result: the results of the `cudart` method, consisting of an error code and any method results
    :type result: list

    :return: non-error-code results from the `results` parameter
    """
    # `result` is of the format : (cudaError_t, result...)
    err = result[0]
    if err.value:
        raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))

    # Strip the error code and return whatever payload remains
    remaining = len(result) - 1
    if remaining == 0:
        return None
    if remaining == 1:
        return result[1]
    return result[1:]
def device_cc(device: int = -1) -> int:
    """
    Returns the compute capability of the device with ID `device`.

    :param device: ID of the device to query
    :type device: int

    :return: compute capability of the queried device (e.g., 80 for SM80)
    :rtype: int
    """
    if device == -1:
        # Query the device currently in use by the CUTLASS Python interface
        device = cutlass_cppgen.device_id()
    props = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
    # Concatenate major and minor versions as digits, e.g. (8, 0) -> 80
    return int(f"{props.major}{props.minor}")
def device_sm_count(device: int = -1):
    """
    Returns the number of streaming multiprocessors (SMs) on the device with ID `device`.

    :param device: ID of the device to query (-1 queries the currently-active device)
    :type device: int

    :return: number of SMs on the queried device
    :rtype: int

    :raises Exception: if querying the device attribute fails
    """
    if device == -1:
        device = cutlass_cppgen.device_id()
    # Local renamed from `device_sm_count`, which shadowed this function's name
    err, sm_count = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
    )
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(
            # Typo fix: "retireve" -> "retrieve"
            "Failed to retrieve SM count. "
            f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
        )

    return sm_count
def to_device_ptr(tensor) -> cuda.CUdeviceptr:
    """
    Convert `tensor` to a ``cuda.CUdeviceptr``.

    :param tensor: tensor to convert
    :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int

    :return: device pointer
    :rtype: cuda.CUdeviceptr

    :raises NotImplementedError: if `tensor` is of an unsupported type
    """
    if is_numpy_tensor(tensor):
        return cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
    if is_torch_tensor(tensor):
        return cuda.CUdeviceptr(tensor.data_ptr())
    if is_cupy_tensor(tensor):
        return cuda.CUdeviceptr(int(tensor.data.ptr))
    if isinstance(tensor, cuda.CUdeviceptr):
        # Already a device pointer; pass through unchanged
        return tensor
    if isinstance(tensor, int):
        # Raw integer address
        return cuda.CUdeviceptr(tensor)
    raise NotImplementedError(tensor)

View File

@ -0,0 +1,33 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.emit.pytorch import pytorch

View File

@ -0,0 +1,267 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Common utilities for emitting CUTLASS kernels
"""
import cutlass_cppgen
# Strings used for printing information about the generation of emitted scripts
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
"""
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
"""
_CUTLASS_KERNEL_ARGS_2x = """
typename DeviceKernel::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K}, // problem size
1,
{alpha, beta},
A, B, C, D,
0, 0, 0, 0, // batch strides
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
};
"""
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
typename DeviceKernel::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K}, // problem size
1,
{alpha, beta},
A, B, C, D,
0, 0, 0, 0, // batch strides
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
-1 // avail_sms
};
"""
_CUTLASS_KERNEL_RUN_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status ${name}_kernel_run(int M, int N, int K,
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta) {
${args}
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.initialize(arguments,
workspace.get(),
nullptr); // CUDA stream
if (status != cutlass::Status::kSuccess) {
return status;
}
status = gemm_op();
return status;
}
"""
_CUTLASS_KERNEL_RUN_GEMM_3x = """
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
using StrideD = typename DeviceKernel::GemmKernel::StrideD;
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status ${name}_kernel_run(
int M, int N, int K, int L,
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
typename DeviceKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, L}, // problem size
{
A, // ptrA
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
B, // ptrB
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
},
{
{alpha, beta},
C, // ptrC
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
D, // ptrD
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
},
hw_info
};
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.run(arguments,
workspace.get(),
nullptr); // CUDA stream
return status;
}
"""
_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
int threadblock_count = DeviceKernel::sufficient();
cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
ElementCompute alpha, ElementCompute beta) {
typename DeviceKernel::Arguments arguments {
problem_sizes,
problem_count,
threadblock_count,
{alpha, beta},
A, B, C, D,
lda, ldb, ldc, ldd
};
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.initialize(arguments,
workspace.get(),
nullptr); // CUDA stream
if (status != cutlass::Status::kSuccess) {
return status;
}
status = gemm_op();
return status;
}
"""
_CUTLASS_KERNEL_RUN_CONV2D_2x = """
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
namespace {
using TensorRefA = typename UnderlyingKernel::TensorRefA;
using TensorRefB = typename UnderlyingKernel::TensorRefB;
using TensorRefC = typename UnderlyingKernel::TensorRefC;
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
}
template<typename TensorRef, typename Element>
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
TensorRef tensor_ref(ptr, layout);
return tensor_ref;
}
cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
cudaStream_t stream, int device_id=0) {
// create the tensor references
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
cutlass::conv::SplitKMode mode;
if (split_k_mode == "serial") {
mode = cutlass::conv::SplitKMode::kSerial;
} else if (split_k_mode == "parallel") {
mode = cutlass::conv::SplitKMode::kParallel;
} else {
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
}
typename DeviceKernel::Arguments arguments{
*problem_size,
tensor_ref_A,
tensor_ref_B,
tensor_ref_C,
tensor_ref_D,
{alpha, beta},
mode
};
DeviceKernel implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return status;
}
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
if (status != cutlass::Status::kSuccess) {
return status;
}
//
// Launch initialized CUTLASS kernel
//
status = implicit_gemm_op(stream);
return status;
}
"""

View File

@ -0,0 +1,936 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel.
If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.
Example usage with JIT compilation:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
op = plan.construct()
mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
# Generate inputs for the GEMM
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
# Run the module
D = mod.run(A, B, C)
Example usage without JIT compilation:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
op = plan.construct()
cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
After this call, the directory ``output`` contains ``setup.py``,
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.
The module can later be used in Python via:
.. highlight:: python
.. code-block:: python
import torch
import cutlass_gemm
# Generate inputs for the GEMM
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
# Run the module
D = cutlass_gemm.run(A, B, C)
"""
import logging
import os
from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.emit import common
from cutlass_cppgen.utils.datatypes import is_torch_available
if is_torch_available():
import torch
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
if (size > 0) {
torch::Device device(torch::kCUDA, device_id);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
at::Tensor device_tensor = torch::empty({(long)size,}, options);
return reinterpret_cast<void*>(device_tensor.data_ptr());
} else {
return nullptr;
}
}
${includes}
${declaration}
${impl}
"""
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);
// C++ interface
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
return ${name}_kernel(A, B, C, alpha, beta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);
// C++ interface
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
return ${name}_kernel(A, B, C, alpha, beta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
at::Tensor ${name}_kernel(
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1);
// C++ interface
at::Tensor ${name}(
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run",
py::overload_cast<
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
at::Tensor ${name}_kernel(
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1);
// C++ interface
at::Tensor ${name}(
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run",
py::overload_cast<
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""
_PYTORCH_GEMM_INCLUDES = {
ApiVersion.v2x: """
#include "cutlass/gemm/device/gemm_universal.h"
""",
ApiVersion.v3x: """
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
""",
}
_PYTORCH_GROUPED_GEMM_INCLUDES = """
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
"""
_PYTORCH_CONV2D_INCLUDES = """
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
"""
_CUTLASS_TYPE_TO_TORCH_TYPE = {
DataType.f16: "torch::kF16",
DataType.f32: "torch::kF32",
DataType.f64: "torch::kF64",
DataType.s8: "torch::kI8",
DataType.s32: "torch::kI32",
DataType.bf16: "torch::kBFloat16",
}
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_GEMM_2x
+ """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
int M = A.size(0);
int N = B.size(1);
int K = A.size(1);
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
cutlass::Status status = ${name}_kernel_run(M, N, K,
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
ptrC,
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
ElementCompute(alpha), ElementCompute(beta));
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
)
# CUDA-side entry point for a 3.x GEMM: infers (M, N, K), queries hardware
# info once and caches it in globals, allocates output D, and calls
# ${name}_kernel_run (defined by the prepended
# common._CUTLASS_KERNEL_RUN_GEMM_3x snippet).
#
# Fix: the emitted C++ previously declared `hw_info_queried` and guarded the
# SM-count query with it, but never set the flag, so the hardware query ran
# on every call. The flag is now set after the first query.
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
    common._CUTLASS_KERNEL_RUN_GEMM_3x
    + """
bool hw_info_queried = false;
cutlass::KernelHardwareInfo hw_info;

at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
  int M = A.size(0);
  int N = B.size(1);
  int K = A.size(1);
  int L = 1;

  // Query hardware info if we haven't already
  if (!hw_info_queried) {
    hw_info.device_id = 0;
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    hw_info_queried = true;
  }

  typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                          nullptr :
                                          reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
  at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

  cutlass::Status status = ${name}_kernel_run(M, N, K, L,
                                              reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                              reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                              ptrC,
                                              reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                              ElementCompute(alpha), ElementCompute(beta),
                                              hw_info);

  TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
  return D;
}
"""
)
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
+ """
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
size_t num = A.size();
// To avoid performing many small cudaMallocs and host-to-device copies,
// we serialize the grouped GEMM arguments on the host, allocate one
// large chunk of device memory, and perform a single cudaMemcpy to
// copy the host data to the device. Allocation overheads could be
// avoided by using a memory pool.
// Calculate the total size of the data to be copied from host to device
size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
sizeof(DeviceKernel::ElementA*) +
sizeof(DeviceKernel::ElementB*) +
sizeof(DeviceKernel::ElementC*) +
sizeof(DeviceKernel::ElementC*) +
sizeof(int64_t) +
sizeof(int64_t) +
sizeof(int64_t);
total_size *= num;
// num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
// of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
// To ensure that we don't end up having misaligned loads in the kernel,
// we pad to the nearest multiple of 8.
//
// Note that, even on a 32-bit system (for which sizeof(X*) will not equal
// sizeof(int64_t)), only padding between the list of GemmCoords and the
// list of ptr_As is sufficient because the set of four equal-length lists of pointers
// (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
// start on a multiple of 8.
int64_t padding = 8 - (total_size % 8);
total_size += padding;
uint8_t* host_data = new uint8_t[total_size];
cutlass::DeviceAllocation<uint8_t> device_data(total_size);
uint8_t* start = host_data;
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);
// Apply the padding after the list of GemmCoords
start += num * sizeof(cutlass::gemm::GemmCoord) + padding;
int64_t ptr_A_offset = start - host_data;
DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
start += num * sizeof(DeviceKernel::ElementA*);
int64_t ptr_B_offset = start - host_data;
DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
start += num * sizeof(DeviceKernel::ElementB*);
int64_t ptr_C_offset = start - host_data;
DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
start += num * sizeof(DeviceKernel::ElementC*);
int64_t ptr_D_offset = start - host_data;
DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
start += num * sizeof(DeviceKernel::ElementC*);
int64_t lda_offset = start - host_data;
int64_t* lda_host = reinterpret_cast<int64_t*>(start);
start += num * sizeof(int64_t);
int64_t ldb_offset = start - host_data;
int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
start += num * sizeof(int64_t);
int64_t ldc_offset = start - host_data;
int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
start += num * sizeof(int64_t);
std::vector<at::Tensor> D(num);
bool need_C = (C != at::nullopt) && (beta != 0.f);
for (size_t i = 0; i < num; ++i) {
int M = A[i].size(0);
int N = B[i].size(1);
int K = A[i].size(1);
*(problem_sizes_host + i) = {M, N, K};
*(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
*(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());
if (need_C) {
*(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
}
else {
*(ptr_C_host + i) = nullptr;
}
D[i] = B[i].new_empty({M, N}, ${torch_type_C});
*(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());
*(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
*(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
*(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
}
device_data.copy_from_host(host_data);
cutlass::Status status = ${name}_kernel_run(
num,
reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
ElementCompute(alpha), ElementCompute(beta));
delete[] host_data;
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
)
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cutlass::Status status = ${name}_kernel_run(
&problem_size,
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
ptrC,
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
alpha, beta,
split_k_mode, stream, B.device().index());
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S, P, Q;
N = A.size(0);
C_ = A.size(1);
H = A.size(2);
W = A.size(3);
K = B.size(0);
R = B.size(2);
S = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
P = problem_size.P;
Q = problem_size.Q;
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::zeros({N, K, P, Q}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S;
N = std::get<0>(input_size);
C_ = std::get<1>(input_size);
H = std::get<2>(input_size);
W = std::get<3>(input_size);
K = B.size(0);
R = B.size(2);
S = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::empty({N, C_, H, W}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S;
K = std::get<0>(weight_size);
C_ = std::get<1>(weight_size);
R = std::get<2>(weight_size);
S = std::get<3>(weight_size);
N = B.size(0);
H = B.size(2);
W = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::empty({K, C_, R, S}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='${name}',
ext_modules=[
CUDAExtension('${name}', [
'${name}.cpp',
'${name}_kernel.cu',
],
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
extra_compile_args={
'cxx': ['-std=c++17'],
'nvcc': ['-std=c++17', ${extra_compile_args}],
},
libraries=['cuda']
),
],
cmdclass={
'build_ext': BuildExtension
})
"""
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str = ""):
    """
    Generates a setup.py file for the extension

    :param name: name of the module to generate
    :type name: str
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str
    :param extra_compile_args: additional nvcc arguments to embed in setup.py
    :type extra_compile_args: str
    """
    contents = SubstituteTemplate(
        _PYTORCH_SETUP_PY,
        {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args},
    )
    with open(os.path.join(sourcedir, "setup.py"), "w") as outfile:
        outfile.write(contents)
class _ArchListSetter:
"""
Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
environment variable when building a PyTorch CUDA module.
``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch
CUDA module should be compiled.
For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
compilation of the module.
This utility wraps the building of a PyTorch CUDA module with a setting of this environment
variable according to the current compute capability being targetted.
Example usage:
.. highlight:: python
.. code-block:: python
# Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
with _ArchListSetter(80):
# Perform JIT compilation and loading of the module
mod = torch.utils.cpp_extension.load(...)
:param cc: compute capability
:type cc: int
"""
_TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"
def __init__(self, cc: int):
self.cc_str = ".".join(list(str(cc)))
def __enter__(self):
"""
Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc``
"""
self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str
return self
def __exit__(self, exc_type, exc_val, traceback):
"""
Restores the old value of TORCH_CUDA_ARCH_LIST
"""
if self.old_arch_list is None:
del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
else:
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
    """
    JIT compiles and loads a PyTorch CUDA extension.

    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param cpp_file: path to file containing extension's C++ interface
    :type cpp_file: str
    :param cuda_file: path to file containing extension's CUDA interface
    :type cuda_file: str

    :return: loaded PyTorch module
    """
    from torch.utils.cpp_extension import load

    extra_cuda_cflags = ["-std=c++17"]
    if cc in [90, 100, 101, 103]:
        # PyTorch does not currently add the architecture-specific (e.g., sm_90a) target
        # when compute capability 9.0 or 10.x is set within TORCH_CUDA_ARCH_LIST.
        # Thus, we manually add the sm_{cc}a target.
        extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")
    # Pin TORCH_CUDA_ARCH_LIST to this cc for the duration of the build
    with _ArchListSetter(cc):
        jitmodule = load(
            name,
            [cpp_file, cuda_file],
            extra_cuda_cflags=extra_cuda_cflags,
            extra_include_paths=[
                os.path.join(CUTLASS_PATH, "include"),
                os.path.join(CUTLASS_PATH, "tools/util/include"),
            ],
            extra_ldflags=["-lcuda"],
            verbose=(logger.level == logging.DEBUG)
        )
    return jitmodule
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    # Select the implementation template based on the kernel's API version.
    # BUGFIX: a second, redundant conditional reassignment of impl_template that
    # duplicated this selection has been removed; the chosen template is unchanged.
    if op.api == ApiVersion.v3x:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
    else:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
        # Stream-K kernels use a different argument structure than other 2.x kernels
        if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
        else:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x

    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GEMM_INCLUDES[op.api],
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        _PYTORCH_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    extra_compile_args = ""
    if cc in [90, 100, 101, 103]:
        # Architecture-specific ("a"-suffixed) targets must be requested explicitly
        extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
    _generate_setup(name, sourcedir, extra_compile_args)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
def _pytorch_grouped_gemm(
    op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if op.api != ApiVersion.v2x:
        raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")

    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    # Emit the CUDA source containing the kernel declaration and launch wrapper
    kernel_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
    kernel_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": kernel_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    with open(cuda_file, "w") as outfile:
        outfile.write(kernel_source)

    # Emit the C++ source exposing the Python-facing entry point
    binding_source = SubstituteTemplate(
        _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"},
    )
    cpp_file = os.path.join(sourcedir, name + ".cpp")
    with open(cpp_file, "w") as outfile:
        outfile.write(binding_source)

    _generate_setup(name, sourcedir)

    return _jit(name, cc, cpp_file, cuda_file) if jit else None
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
    weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
    for H/W/R/S given the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    # Select the implementation and C++ binding templates based on the convolution kind
    if op.conv_kind == ConvKind.Fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    else:
        # BUGFIX: previously an unsupported kind fell through and raised an opaque
        # UnboundLocalError on impl_template below; fail with a clear message instead.
        raise Exception(f"Unsupported conv kind {op.conv_kind} for PyTorch emission")

    extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
    extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_CONV2D_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        cpp_template,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    The result of this method is files within ``sourcedir`` that can be used for building
    a PyTorch module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module (if ``jit=True``) or None
    """
    device_op = op.device_op()

    # Dispatch to the emitter matching the operation's type. Note that the
    # isinstance check is performed on the original op, while the emitter
    # receives the device-level op.
    emitters = (
        (GemmOperationUniversal, _pytorch_gemm),
        (GemmOperationGrouped, _pytorch_grouped_gemm),
        (Conv2dOperation, _pytorch_conv2d),
    )
    for op_type, emitter in emitters:
        if isinstance(op, op_type):
            return emitter(device_op, name, cc, jit, sourcedir)

    raise Exception(
        f"Operation type {type(op)} is not currently supported for PyTorch emission."
    )

View File

@ -0,0 +1,56 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.epilogue.epilogue import (
get_activations,
get_activation_epilogue,
gelu,
hardswish,
identity,
leaky_relu,
relu,
sigmoid,
silu,
tanh,
trace
)
from cutlass_cppgen.epilogue.evt_ops import (
max,
multiply_add,
sum,
permute,
reshape,
maximum,
minimum,
exp
)

View File

@ -0,0 +1,176 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Registry of elementwise epilogues
Elementwise epilogues can be added to many CUTLASS kernels in the CUTLASS Python interface via
code like the following for GEMM:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.activation = cutlass_cppgen.epilogue.relu
"""
from cutlass_cppgen.backend import epilogue, device_cc
# Convenient module-level aliases for the activation functors defined in the backend
gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
leaky_relu = epilogue.leaky_relu
relu = epilogue.relu
sigmoid = epilogue.sigmoid
silu = epilogue.silu
tanh = epilogue.tanh

# Registry of all activation functions exposed by this module
_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]
def get_activations() -> list:
    """
    Returns the registry of available activation functions.

    :return: list of available activation functions
    :rtype: list
    """
    return _activations
def get_activation_epilogue(
    activation,
    element_output,
    elements_per_access,
    element_accumulator,
    element_compute,
):
    """
    Return an epilogue corresponding to the activation function, data types, and alignment
    used in the kernel

    :param activation: elementwise activation function to use
    :param element_output: data type of the output
    :param elements_per_access: alignment of operand C of the kernel
    :type elements_per_access: int
    :param element_accumulator: data type of the accumulated output C
    :param element_compute: data type in which compute operations should be performed

    :return: epilogue functor
    """
    if activation not in _activations:
        raise Exception(
            f"Unsupported activation type {activation}. Available activations are: {_activations}"
        )

    # The identity activation maps to a plain linear combination; every other
    # activation goes through the generic linear-combination epilogue.
    if activation == identity:
        return epilogue.LinearCombination(
            element_output, elements_per_access, element_accumulator, element_compute
        )

    return epilogue.LinearCombinationGeneric(
        activation,
        element_output,
        elements_per_access,
        element_accumulator,
        element_compute,
    )
"""
Frontend for EVT that generates epilogue functor through tracing the input function
"""
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
def trace(fn, example_tensors, **kwargs):
    """
    Trace `fn(**example_tensors)` and generates epilogue visitor

    :param fn: Python callable or string source of the epilogue function
    :param example_tensors: example inputs for fn
    :type example_tensors: dict

    .. highlight:: python
    .. code-block:: python

        import cutlass_cppgen.backend.evt

        # Define epilogue function as Python callable
        def example_fn(accum, C, alpha, beta, gamma):
            D = ((accum + C) * alpha - gamma) / beta
            return D

        # Define the example tensors
        example_inputs = {
            "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "alpha": 1.5,
            "beta": 0.5,
            "gamma": 2.5,
            "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
        }

        # Generate the epilogue functor
        epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)

    :return: traced epilogue functor
    """
    # BUGFIX: `ast` and `textwrap` are used by the string-source branch below but were
    # never imported anywhere in this module, raising a NameError for string inputs.
    import ast
    import textwrap

    if callable(fn):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                # Default to the current device's compute capability
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

        # Attach the traced callable as a static __call__ so the AST frontend can
        # retrieve and parse its source
        setattr(EpilogueFunctor, "__call__", staticmethod(fn))
        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    elif isinstance(fn, str):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                # Normalize indentation of the provided source before parsing
                self.source = textwrap.dedent(fn)
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

            def parse(self, example_inputs) -> None:
                # Parse the stored source directly instead of inspecting __call__
                self.example_inputs = example_inputs
                self.ast = ast.parse(self.source)
                self.visit(self.ast)

        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    else:
        raise NotImplementedError("Expect a callable Python function")

View File

@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Collection of builtin functions used for host reference in EVT
"""
import numpy as np
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
if is_torch_available():
import torch
def multiply_add(x, y, z):
    """Host reference for fused multiply-add: returns ``x * y + z`` elementwise."""
    product = x * y
    return product + z
def sum(x, dim):
    """Host reference for EVT reduction-sum of ``x`` over dimensions ``dim``."""
    if is_torch_tensor(x):
        return torch.sum(x, dim)
    if is_numpy_tensor(x):
        return x.sum(axis=tuple(dim))
def max(x, dim):
    """Host reference for EVT reduction-max of ``x`` over dimensions ``dim``."""
    if is_torch_tensor(x):
        return torch.amax(x, dim)
    if is_numpy_tensor(x):
        return x.max(axis=tuple(dim))
def maximum(x, y):
    """Host reference for elementwise maximum of ``x`` and ``y``."""
    if is_torch_tensor(x):
        return torch.maximum(x, torch.tensor(y))
    if is_numpy_tensor(x):
        return np.maximum(x, y)
def minimum(x, y):
    """Host reference for elementwise minimum of ``x`` and ``y``."""
    if is_torch_tensor(x):
        return torch.minimum(x, torch.tensor(y))
    if is_numpy_tensor(x):
        return np.minimum(x, y)
def exp(x):
    """Host reference for elementwise exponential of ``x``."""
    if is_torch_tensor(x):
        return torch.exp(x)
    if is_numpy_tensor(x):
        return np.exp(x)
##############################################################################
# Layout manipulate nodes
##############################################################################
def permute(x, indices: tuple):
    """Host reference for EVT dimension permutation of ``x`` by ``indices``."""
    if is_torch_tensor(x):
        return x.permute(*indices)
    if is_numpy_tensor(x):
        return np.transpose(x, axes=indices)
def reshape(x, new_shape: tuple):
    """
    Host reference for EVT reshape of ``x`` to ``new_shape``.

    :param x: input tensor (NumPy or PyTorch)
    :param new_shape: target shape
    :type new_shape: tuple
    """
    if is_numpy_tensor(x):
        # Pass the shape positionally: the `newshape` keyword is deprecated as of NumPy 2.1
        return np.reshape(x, new_shape)
    elif is_torch_tensor(x):
        # Use reshape rather than view so that non-contiguous inputs are also handled
        return x.reshape(new_shape)

View File

@ -0,0 +1,569 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Classes containing valid operations for a given compute capability and data types.
"""
from itertools import combinations_with_replacement
import logging
import cutlass_library
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
import cutlass_cppgen
from cutlass_cppgen.utils.check import valid_stage_count
from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]
class KernelsForDataType:
    """
    Container class for keeping track of kernels that correspond to a particular combination
    of data types for operands A, B, and accumulator

    :param datatype_comb: data types for (element_A, element_B, element_accumulator)
    :type datatype_comb: tuple
    :param layout_comb: layouts for (layout_A, layout_B)
    :type layout_comb: tuple
    """

    def __init__(self, datatype_comb: tuple, layout_comb: tuple):
        self.datatype_comb = datatype_comb
        self.layout_comb = layout_comb
        # Math operations (e.g., multiply_add) seen across all registered kernels
        self.math_operations = set()

        # Dictionary mapping from alignment (an "alignA alignB alignC" string) to a list of
        # kernels that fit the alignment constraint for the data type combination
        self.kernels_by_alignment = {}

    def add(self, operation):
        """
        Add an operation to the list of supported kernels
        """
        alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
        if alignment_key not in self.kernels_by_alignment:
            self.kernels_by_alignment[alignment_key] = []
        self.kernels_by_alignment[alignment_key].append(operation)
        self.math_operations.add(operation.tile_description.math_instruction.math_operation)

    def alignments(self, operand: str):
        """
        Returns an unsorted list of alignments supported by this data type combination

        :param operand: identifier of operand in question (e.g., A, B, C)
        :type operand: str

        :return: unsorted list of alignments supported by this data type combination
        :rtype: list
        """
        operand_idx = self._operand_idx(operand)
        return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()]

    @property
    def all_operations(self):
        """
        Returns a list of all operations supported by this data type combination

        :return: list of all operations supported by this data type combination
        :rtype: list
        """
        ops = []
        for _, alignment_ops in self.kernels_by_alignment.items():
            ops.extend(alignment_ops)
        return ops

    def default_operation(self, math_operation: "cutlass_cppgen.MathOperation"):
        """
        Returns the first kernel under the lexicographically-smallest alignment key,
        optionally filtered by ``math_operation``.

        :param math_operation: math operation to filter on (or None for no filtering)
        :type math_operation: cutlass_cppgen.MathOperation

        :return: default operation
        """
        key = sorted(list(self.kernels_by_alignment.keys()))[0]
        kernels = self.kernels_by_alignment[key]
        if math_operation is not None:
            kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
        return kernels[0]

    def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: "cutlass_cppgen.MathOperation"):
        """
        Returns operations satisfying the alignment constraints

        :param alignment_A: alignment constraint of operations to return
        :type alignment_A: int
        :param alignment_B: alignment constraint of operations to return
        :type alignment_B: int
        :param alignment_C: alignment constraint of operations to return
        :type alignment_C: int
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: list of operations
        :rtype: list
        """
        key = f"{alignment_A} {alignment_B} {alignment_C}"
        if key not in self.kernels_by_alignment:
            og_key = key
            # Reconcile A, B, and C alignments by trying to align to the minimum
            min_alignment = min(alignment_A, alignment_B, alignment_C)
            key = f"{min_alignment} {min_alignment} {min_alignment}"
            if key not in self.kernels_by_alignment:
                # Finally, go through all available alignment combinations and find
                # one for which each requested alignment is a multiple of the
                # kernel's alignment.
                key = None
                alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
                for align_A, align_B, align_C in alignments:
                    if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
                        key = f"{align_A} {align_B} {align_C}"
                        break
                if key is None:
                    raise Exception(
                        f"No operations of alignment {og_key} found for data type and layout "
                        f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
                        f"are {self.kernels_by_alignment.keys()}"
                    )
        ops = self.kernels_by_alignment[key]
        if math_operation is not None:
            ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
        return ops

    def _operand_idx(self, key: str) -> int:
        # Map an operand identifier ("A"/"B"/"C") to its index within an alignment key
        operand_list = ["A", "B", "C"]
        if key not in operand_list:
            # BUGFIX: previously referenced an undefined name `operand` here, which
            # raised a NameError instead of this descriptive exception
            raise Exception(f"Unexpected operand {key}")
        return operand_list.index(key)

    def find_alignment(self, shape: tuple, layout: "cutlass_cppgen.LayoutType", operand: str) -> int:
        """
        Returns the most preferable alignment for a given shape and layout

        :param shape: extent of each dimension of the tensor
        :type shape: tuple
        :param layout: layout of the tensor
        :type layout: cutlass_cppgen.LayoutType
        :param operand: descriptor of the operand in question
        :type operand: str

        :return: maximum alignment supported by the data type combination and tensor size
        :rtype: int
        """
        # BUGFIX: the signature previously read `operand=str`, binding the `str` type
        # object as a (broken) default instead of annotating the parameter
        operand_idx = self._operand_idx(operand)

        # Determine the leading dimension of the shape
        if layout == cutlass_cppgen.LayoutType.ColumnMajor:
            ld = shape[-2]
        elif layout == cutlass_cppgen.LayoutType.RowMajor:
            ld = shape[-1]
        elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
            ld = shape[-1]
        else:
            raise Exception(f"Unexpected or unsupported layout {layout}")

        for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True):
            alignment = int(alignments.split(" ")[operand_idx])
            if ld % alignment == 0:
                return alignment

        # Default to alignment of 1 if no others match
        return 1

    def sort(self):
        """
        Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
        """
        key = lambda op: (
            op.tile_description.threadblock_shape[0]
            * op.tile_description.threadblock_shape[1]
            * op.tile_description.threadblock_shape[2]
        )
        for alignment in self.kernels_by_alignment.keys():
            self.kernels_by_alignment[alignment].sort(key=key, reverse=True)

    def supports_math_operation(self, math_operation: "cutlass_cppgen.MathOperation") -> bool:
        """
        Returns whether `math_operation` is supported by at least one operation.

        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether math_operation is supported by at least one operation
        :rtype: bool
        """
        return math_operation is None or math_operation in self.math_operations
class ArchOptions:
    """
    Structure for keeping track of kernels available on a given compute capability

    :param target_cc: compute capability of the device on which kernels will be run
    :type target_cc: int
    :param kernel_cc: compute capability of the kernels to generate
    :type kernel_cc: int
    :param operation_kind: type of operation to register
    :type operation_kind: cutlass_library.OperationKind
    :param gemm_kinds: types of GEMM operations that can be included
    :type gemm_kinds: list
    :param allowed_math_operations: types of primitive math operations allowed, or None to
                                    use the default set of multiply-add variants
    :type allowed_math_operations: list
    """

    def __init__(
        self,
        target_cc: int,
        kernel_cc: int,
        operation_kind: cutlass_library.OperationKind,
        gemm_kinds: list,
        allowed_math_operations: list = None
    ):
        # Use a None sentinel rather than a mutable default list so the default is
        # not a single list object shared across every ArchOptions instance
        if allowed_math_operations is None:
            allowed_math_operations = [
                cutlass_library.MathOperation.multiply_add,
                cutlass_library.MathOperation.multiply_add_saturate,
                cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
                cutlass_library.MathOperation.multiply_add_fast_f32
            ]

        self.cc = kernel_cc

        # Dictionary with following structure:
        #  Key: OpcodeClass
        #  Value: Dictionary with the following structure:
        #     Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType),
        #          representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
        #     Value: KernelsForDataType
        self.operations_by_opclass = {}
        self.op_class = None
        self.allowed_math_operations = allowed_math_operations

        # SM90 kernels are not registered for SM100 targets and vice versa
        if (target_cc == 100 and kernel_cc == 90) or (target_cc == 90 and kernel_cc == 100):
            return

        # Identify the method within CUTLASS generator script that generates kernel
        # descriptions for the target CC
        generate_function_name = "GenerateSM" + str(kernel_cc)
        if not hasattr(cutlass_library.generator, generate_function_name):
            cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
            return
        generate_function = getattr(cutlass_library.generator, generate_function_name)

        # Initialize a default manifest and populate it with valid kernel descriptions
        # for the target CC
        args = [
            "--kernels=all",
            f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
        ]

        manifest_args = cutlass_library.generator.define_parser().parse_args(args)
        manifest = cutlass_library.manifest.Manifest(manifest_args)
        generate_function(manifest, cutlass_cppgen._nvcc_version)

        if operation_kind not in manifest.operations:
            # No kernels generated for this architecture, this could be because the CUDA
            # toolkit is insufficient to support operations in this CC
            cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
            return

        # Only one CC should be returned, given the setup above of calling only the generation scripts
        # for a given CC
        if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
            raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
                            "is sufficient for the architecture in question.")

        # Iterate through the available operations for this operation kind and
        # find available opclasses and data types
        for op_list in manifest.operations[operation_kind][kernel_cc].values():
            for op in op_list:
                if operation_kind == cutlass_library.OperationKind.Gemm:
                    if op.gemm_kind not in gemm_kinds:
                        continue

                mi = op.tile_description.math_instruction
                if mi.math_operation not in self.allowed_math_operations:
                    continue

                # Prune operations that don't fit in shared memory
                td = td_from_profiler_op(op)
                if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
                    continue

                if mi.opcode_class not in self.operations_by_opclass:
                    self.operations_by_opclass[mi.opcode_class] = {}

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
                layout_comb = (op.A.layout, op.B.layout)

                # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
                if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
                    # TF32 kernels only supported on SM80 and beyond
                    if self.cc < 80:
                        continue
                    elif self.cc == 90 or self.cc == 100:
                        if (op.A.element != cutlass_library.DataType.f32
                            or op.B.element != cutlass_library.DataType.f32
                            or op.C.element != cutlass_library.DataType.f32):
                            continue
                    datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)

                opclass_dict = self.operations_by_opclass[mi.opcode_class]
                key = (datatype_comb, layout_comb)
                if key not in opclass_dict:
                    opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
                opclass_dict[key].add(op)

        # Set the default opclass to TensorOp, if available. Otherwise default to SIMT
        if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
            self.op_class = cutlass_library.OpcodeClass.TensorOp
        else:
            self.op_class = cutlass_library.OpcodeClass.Simt

        # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
        # Here, we generate additional versions via a generic TileDescription.
        if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
            self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}

        if operation_kind == cutlass_library.OperationKind.Gemm:
            types = [
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            # Add FP8 A/B/C
            fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
            for type_comb in combinations_with_replacement(fp8_types, 3):
                types.append(type_comb)

            # Add FP8 A/B with FP32 C. Use the cutlass_library enum here, consistent
            # with every other entry in `types` (previously this referenced
            # cutlass_cppgen.DataType.f32).
            for type_comb in combinations_with_replacement(fp8_types, 2):
                types.append(type_comb + (cutlass_library.DataType.f32,))

            layouts = [
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
            ]
        elif operation_kind == cutlass_library.OperationKind.Conv2d:
            types = [
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]
            layouts = [
                (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
            ]
        else:
            raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")

        alignment = 1
        epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
        swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8

        for type_comb in types:
            for layout_comb in layouts:
                comb = (type_comb, layout_comb)
                # Skip combinations the profiler-generated kernels already cover
                if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
                    continue

                A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
                B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
                C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
                math_inst = cutlass_library.MathInstruction(
                    [1, 1, 1],
                    type_comb[0],
                    type_comb[1],
                    type_comb[2],
                    cutlass_library.OpcodeClass.Simt,
                    cutlass_library.MathOperation.multiply_add
                )
                td = cutlass_library.TileDescription(
                    [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)

                # Prune operations that don't fit in shared memory
                if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
                    continue

                new_kernels = KernelsForDataType(type_comb, layout_comb)

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    new_operation = cutlass_library.manifest.GemmOperation(
                        cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
                        td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
                    new_kernels.add(new_operation)
                elif operation_kind == cutlass_library.OperationKind.Conv2d:
                    for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
                        new_operation = cutlass_library.manifest.Conv2dOperation(
                            conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
                            A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
                            group_mode=GroupMode.SingleGroup
                        )
                        new_kernels.add(new_operation)

                self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels

        # Sort all operations
        for opclass_combs in self.operations_by_opclass.values():
            for kernels in opclass_combs.values():
                kernels.sort()

    def opclass_supports_combination(
        self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
    ) -> bool:
        """
        Returns whether the provided operation class supports the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
        :type datatype_comb: tuple[cutlass_library.DataType]
        :param layout_comb: tuple of data types for (layout_A, layout_B)
        :type layout_comb: tuple[cutlass_library.LayoutType]
        :param math_operation: math operation to consider or None if any can be considered
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether the combination is supported by `op_class`
        :rtype: bool
        """
        if op_class not in self.operations_by_opclass:
            raise Exception(f"Unexpected or unsupported operation class {op_class}")

        if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
            if math_operation is not None:
                return operations.supports_math_operation(math_operation)
            else:
                return True
        return False

    def supporting_opclasses(
        self,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> set:
        """
        Returns a set of operation classes that support the provided data type combination

        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: set of operation classes that support the provided data type combination
        :rtype: set
        """
        supporting_op_classes = set()
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)

        for op_class in self.operations_by_opclass.keys():
            if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
                supporting_op_classes.add(op_class)
        return supporting_op_classes

    def operations(
        self,
        op_class: cutlass_library.OpcodeClass,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> KernelsForDataType:
        """
        Returns the kernels supported by the provided operation class for the provided data
        type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: container of kernels by alignment supported by the provided combination of parameters
        :rtype: KernelsForDataType
        """
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)
        if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
            raise Exception(
                f"Data type layout combination {datatype_comb}, {layout_comb} "
                f"is not supported by opcode class {op_class} on CC {self.cc}."
            )
        return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]
class OptionRegistry:
    """
    Container of all architecture-specific options

    :param target_cc: compute capability of the device on which operations will be run
    :type target_cc: int
    """

    def __init__(self, target_cc: int):
        # Maps kernel CC -> {OperationKind -> ArchOptions}
        self.registry = {}

        if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
            raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")

        gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
        operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]

        # Construct options for each CC
        for kernel_cc in _generator_ccs:
            self.registry[kernel_cc] = {}
            for opkind in operation_kinds:
                self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)

    def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
        """
        Returns the registered options for compute capability ``cc`` and operation kind ``op_kind``.

        :param cc: compute capability of the kernels to look up
        :type cc: int
        :param op_kind: type of operation to look up options for
        :type op_kind: cutlass_library.OperationKind

        :return: options registered for the requested compute capability and operation kind
        :rtype: ArchOptions
        """
        cc_options = self.registry.get(cc)
        if cc_options is None:
            # Previously a missing CC fell through to subscripting None, which raised an
            # opaque "'NoneType' object is not subscriptable" TypeError. Raise a
            # descriptive error instead.
            raise Exception(f"No kernel options are registered for CC {cc}")
        return cc_options[op_kind]

View File

@ -0,0 +1,36 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
from cutlass_cppgen.op.op import OperationBase

View File

@ -0,0 +1,997 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running CONVs
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS CONVs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass_cppgen.op.Conv(A, B, C, D)
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:
.. highlight:: python
.. code-block:: python
# The following is shorthand for:
# cutlass_cppgen.op.Conv2d(kind="fprop",
# element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32)
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
D0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
A1 = torch.rand((32, 128), dtype=torch.float32, device='cuda')
B1 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
C1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
D1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
kernel from its execution:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
# Do other work...
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
# Do other work...
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
Elementwise activation functions are easily fused to the GEMM via the interface:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
plan.activation = cutlass_cppgen.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
args = plan.run()
# Do other work...
args.sync()
"""
from __future__ import annotations
from typing import Optional
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
from cutlass_library import (
ConvKind,
ConvMode,
DataTypeSize,
IteratorAlgorithm,
OperationKind,
SplitKMode,
StrideSupport,
)
import cutlass_cppgen
from cutlass_cppgen import epilogue
from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
from cutlass_cppgen.utils import check, datatypes
class Conv2d(OperationBase):
"""
Constructs a ``Conv2d`` object.
    The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
along with the data type of output D and that used for accumulation, are bound to the ``Conv``
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
The constructor has optional parameters for flexibly setting these parameters. The following
constructors are equivalent:
.. highlight:: python
.. code-block:: python
# Use F32 for A, B, C, D, and accumulation in fprop
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
# Explicitly specify the data types to use for A, B, C, and D.
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
# have the same data type as those passed in here).
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
# those passed in via the generic ``element``
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
element=cutlass_cppgen.DataType.f32)
The order of precedence for the setting of the data type for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
3) Otherwise, use the generic values (e.g., ``element``)
:param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
:type kind: str
:param A: tensor representing data type of operand A
:param B: tensor representing data type of operand B
:param C: tensor representing data type of operand C
:param D: tensor representing data type of operand D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass_cppgen.DataType
:param element_A: data type to be used for operand A
:type element_A: cutlass_cppgen.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass_cppgen.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass_cppgen.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass_cppgen.DataType
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass_cppgen.DataType
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
"""
def __init__(
self, kind="fprop",
A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
element=None,
element_A=None, element_B=None, element_C=None, element_D=None,
element_accumulator=None,
cc: int = None, kernel_cc: int = None
):
super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
# Verify the kernel cc
if self.current_cc in [90, 100, 101, 103]:
# The Conv2d kernel on Hopper (SM90) is currently unsupported
# Revert to use SM80-tagged kernels
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self.specified_kernel_cc = 80
self._reset_options(80)
# The arch is used in testing
self.arch = self.current_cc
self.name = "conv2d" + kind
# The convolution kind. (concept: cutlass_library.library.ConvKind)
self.conv_kind = datatypes.getattr_enum(ConvKind, kind)
# The element types (concept: cutlass library types) of A, B, C, and D
elements = []
layouts = []
# Complete the data types based on user-provided arguments
for elt, tens, name in zip([element_A, element_B, element_C, element_D],
[A, B, C, D],
["A", "B", "C", "D"]):
if elt is not None and tens is not None:
raise Exception(f'Must not specify both element_{name} and tensor {name}')
if elt is None and tens is None and element is None:
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
elt_to_set = None
lay_to_set = None
if tens is not None:
elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
else:
elt_to_set = elt if elt is not None else element
assert elt_to_set is not None
# Currently we only support layout TensorNHWC
lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
elements.append(datatypes.library_type(elt_to_set))
layouts.append(lay_to_set)
self._element_a, self._element_b, self._element_c, self._element_d = elements
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta
if element_accumulator is None:
self._element_accumulator = self._element_c
else:
self._element_accumulator = datatypes.library_type(element_accumulator)
# Default inputs if none is supplied in run()
self.A = A
self.B = B
self.C = C
self.D = D
self.alpha = alpha
self.beta = beta
# We only specify the stride of the swizzling functor here
# The actual swizzling functor is determined in run based on conv_kind and stride
self._swizzling_stride = 1
# Arguments that will be set to default value in _reset_operations
# The default tile_description and op_class are fetched from manifest of cutlass library
self._tile_description = None
self.op_class = None
# The default identity epilogue will be created
self.epilogue_functor = None
self._reset_operations()
# Arguments that will be determined online based on arguments of "run"
# based on stride, input/output channels, alignment, and conv_kind
self._iterator_algorithm = None
self._stride_support = None
    def _reset_operations(self, reset_epilogue: bool = True):
        """
        Recompute the set of supported op classes and alignment preferences for the
        currently-configured data types and layouts, and optionally reset the
        epilogue functor to the identity activation.

        :param reset_epilogue: whether to reset the epilogue functor to the identity activation
        :type reset_epilogue: bool

        :raises Exception: if no kernel configuration supports the configured data type
            and layout combination
        """
        # Set the default op class
        datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
        layout_comb = (self._layout_a, self._layout_b)
        self.possible_op_classes = self.options.supporting_opclasses(
            self._element_a, self._element_b, self._element_accumulator,
            self._layout_a, self._layout_b, self._math_operation
        )
        # Prefer TensorOp kernels when available; otherwise fall back to SIMT
        if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
            self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
        elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
            self.opclass = cutlass_cppgen.OpcodeClass.Simt
        else:
            if self._math_operation is not None:
                math_op_str = f' and math operation {self._math_operation}'
            else:
                math_op_str = ''
            raise Exception(f'No kernel configuration found for supported data type and layout '
                            f'combination {datatype_comb}x{layout_comb}{math_op_str}')

        if reset_epilogue:
            self._reset_epilogue_functor_activation(epilogue.identity)

        # Preferred alignment per operand: the largest alignment supported by the available
        # kernels, capped at a 128-bit access for the operand's element type
        self.alignment_pref_A = min(
            128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
        self.alignment_pref_B = min(
            128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
        self.alignment_pref_C = min(
            128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C")))
#
# Tile description Related
#
@property
def tile_description(self) -> TileDescription:
"""
Returns the tile description
"""
return self._tile_description
    @tile_description.setter
    def tile_description(
        self, td=None):
        """
        Set the tile description

        :param td: tile description
        :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
            {
            "threadblock_shape": [int, int, int],
            "warp_count": [int, int, int],
            "stages": int,
            "instruction_shape": [int, int, int] (optional),
            "cluster_shape": [int, int, int] (optional)
            }

        :raises Exception: if the resulting tile description is invalid for the
            current compute capability
        """
        # Setting None is a no-op rather than clearing the current description
        if td is None:
            return
        if isinstance(td, dict):
            # Dict updates are applied on top of an existing description; fetch a
            # default one from the library manifest if none has been set yet
            if self._tile_description is None:
                op = self.possible_operations.default_operation(self._math_operation)
                self._tile_description = datatypes.td_from_profiler_op(op)
            if "cluster_shape" in td.keys():
                # NOTE: this mutates the caller-provided dict to force the supported cluster shape
                if td["cluster_shape"] != [1, 1, 1]:
                    cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
                    td["cluster_shape"] = [1, 1, 1]
            td = self._tile_description.clone_and_update(td)

        # Validate before committing the new description
        valid, msg = self._valid_tile_description(td)
        if valid:
            self._tile_description = td
        else:
            raise Exception(msg)
def _valid_tile_description(self, td: TileDescription) -> tuple:
"""
Checks whether the provided tile description is valid for the given compute capability. At present,
this checks the following:
- Does the tile description use a number of stages supported by the compute capability in question?
- Does the tile size requested fit within shared memory?
- Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
more non-unit cluster dimensions for pre-SM90 architectures)?
- Is the kernel schedule being used supported on the architecture in question?
:param td: tile description to validate
:type td: cutlass_cppgen.backend.TileDescription
:return: tuple in which the first element is a bool indicating that the tile description is valid
and the second element is a string providing an optional error message.
:rtype: tuple
"""
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
if not valid:
return (valid, msg)
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
if not valid:
return (valid, msg)
return valid, msg
def tile_descriptions(self) -> list:
"""
Returns a list of valid tile descriptions for the operations
:returns: list of valid tile descriptions for the operations
:rtype: list
"""
descriptions = []
description_str = []
for op in self.possible_operations.all_operations:
td = datatypes.td_from_profiler_op(op)
if self._math_operation is not None:
if td.math_instruction.math_operation != self._math_operation:
continue
if str(td) not in description_str:
description_str.append(str(td))
descriptions.append(td)
return descriptions
#
# Swizzling functor Related
#
@property
def swizzling_stride(self):
"""
Returns the stride of swizzling currently being used by the Conv2d
:return: swizzing stride
"""
return self._swizzling_stride
@swizzling_stride.setter
def swizzling_stride(self, stride: int):
"""
Sets the swizzling functor to the type specified by `swizzling_functor`
"""
if not isinstance(stride, int):
raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
self._swizzling_stride = stride
def _propose_swizzling_functor(self, stride):
"""
Automatically propose the swizzling functor based on the stride
"""
if self.conv_kind == ConvKind.Dgrad:
if stride[0] != 1 or stride[1] != 1:
return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
#
# Iterator Algorithm Related
#
    @property
    def iterator_algorithm(self) -> IteratorAlgorithm:
        """
        Returns the iterator algorithm currently set, or ``None`` if unset

        :rtype: cutlass_library.library.IteratorAlgorithm
        """
        return self._iterator_algorithm
    @iterator_algorithm.setter
    def iterator_algorithm(self, alg: str):
        """
        Sets the iterator algorithm

        :param alg: name of the iterator algorithm
        :type alg: str, options: "analytic", "optimized", "few_channels", and "fixed_channels"

        :raises Exception: if the algorithm is "few_channels" or "fixed_channels" but the
            convolution kind is not fprop (those iterators are fprop-only)
        """
        iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)
        # Check if the iterator algorithm is valid
        if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
            raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
        self._iterator_algorithm = iterator_alg
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
"""
Propose a valid iterator algorithm based on problem size and alignment
"""
if self.conv_kind == ConvKind.Fprop:
# Check whether the fixed channel is applicable
if problem_size.C == alignment_a:
return IteratorAlgorithm.FixedChannels
elif (problem_size.C % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32):
return IteratorAlgorithm.Optimized
else:
return IteratorAlgorithm.Analytic
elif self.conv_kind == ConvKind.Dgrad:
if (problem_size.K % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32 and
problem_size.C % alignment_b == 0):
return IteratorAlgorithm.Optimized
else:
return IteratorAlgorithm.Analytic
elif self.conv_kind == ConvKind.Wgrad:
if (problem_size.K % alignment_a == 0 and
problem_size.C % alignment_b == 0):
return IteratorAlgorithm.Optimized
else:
return IteratorAlgorithm.Analytic
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
"""
Validate whether the user provide iterator algorithm works for the given problem size
"""
if self.conv_kind == ConvKind.Fprop:
if iterator_algorithm == IteratorAlgorithm.FixedChannels:
return problem_size.C == alignment_a
elif iterator_algorithm == IteratorAlgorithm.Optimized:
return (problem_size.C % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32)
elif iterator_algorithm == IteratorAlgorithm.FewChannels:
return problem_size.C % alignment_a == 0
elif self.conv_kind == ConvKind.Dgrad:
if iterator_algorithm == IteratorAlgorithm.Optimized:
return (problem_size.K % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32 and
problem_size.C % alignment_b == 0)
elif self.conv_kind == ConvKind.Wgrad:
if iterator_algorithm == IteratorAlgorithm.Optimized:
return (problem_size.K % alignment_a == 0 and
problem_size.C % alignment_b == 0)
return True
#
# Stride Support Related
#
    def _propose_stride_support(self, stride):
        """
        Proposes the stride support for the given convolution stride.

        :param stride: (stride_h, stride_w) of the convolution
        :type stride: tuple
        :return: ``StrideSupport.Unity`` for unit-stride dgrad; ``StrideSupport.Strided``
            for every other case (including all fprop/wgrad strides)
        """
        if self.conv_kind == ConvKind.Dgrad:
            if stride[0] == 1 and stride[1] == 1:
                return StrideSupport.Unity
        return StrideSupport.Strided
#
# Construct and Compilation
#
def construct(
self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
iterator_algorithm: IteratorAlgorithm = None,
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
"""
Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
kernel specification of the ``Conv2d`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass_cppgen.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:param alignment_C: alignment of operand C
:type alignment_C: int
:param iterator_algorithm: the iterator algorithm used
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
:param stride_support: the stride support of dgrad
:type stride_support: cutlass_library.library.StrideSupport
:param swizzling_functor: the swizzling functor
:type swizzling_functor: cutlass_cppgen.swizzle
:param epilogue_functor: the epilogue functor
:return: operation that was constructed
:rtype: cutlass_cppgen.backend.Conv2dOperation
"""
# Get alignment
alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)
tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
if tile_description is None:
if self.tile_description is not None:
tile_description = self.tile_description
else:
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
valid, err_str = self._valid_tile_description(tile_description)
if not valid:
raise Exception(f"Invalid tile description. {err_str}")
self.tile_description = tile_description
if iterator_algorithm is None:
# If the iterator algorithm is already set
if self.iterator_algorithm is not None:
iterator_algorithm = self.iterator_algorithm
else:
# Otherwise, we conservatively use the analytic iterator for correctness
iterator_algorithm = IteratorAlgorithm.Analytic
if stride_support is None:
# If the stride support is already set
if self._stride_support is not None:
stride_support = self._stride_support
else:
# Otherwise, we assume strided
stride_support = StrideSupport.Strided
if swizzling_functor is None:
# If the swizzling functor is already set
swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))
if epilogue_functor is None:
if self.epilogue_functor is not None:
epilogue_functor = self.epilogue_functor
else:
epilogue_functor = self._create_epilogue_functor_activation(self._activation)
# Reset the alignment of the epilogue functor
epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)
operation = Conv2dOperation(
conv_kind=self.conv_kind,
iterator_algorithm=iterator_algorithm,
arch=self.current_cc,
tile_description=tile_description,
A=tensor_A, B=tensor_B, C=tensor_C,
stride_support=stride_support,
epilogue_functor=epilogue_functor,
swizzling_functor=swizzling_functor,
)
return operation
    def compile(self, tile_description: TileDescription = None,
                alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
                iterator_algorithm: IteratorAlgorithm = None,
                stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
                epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
        """
        Emits and compiles the kernel currently specified. If ``tile_description`` and any
        of the ``alignment`` parameters are set, the kernel will be chosen using this
        tile description and alignments. Otherwise, a default tile description and alignment
        will be used.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param iterator_algorithm: the iterator algorithm used
        :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
        :param stride_support: the stride support of dgrad
        :type stride_support: cutlass_library.library.StrideSupport
        :param swizzling_functor: the swizzling functor
        :type swizzling_functor: cutlass_cppgen.swizzle
        :param epilogue_functor: the epilogue functor
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool

        :return: operation that was compiled
        :rtype: cutlass_cppgen.backend.Conv2dOperation
        """
        # Construct (or reuse) the operation, then hand it to the JIT compiler
        self.operation = self.construct(
            tile_description, alignment_A, alignment_B, alignment_C,
            iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
        if print_module:
            print(self.operation.rt_module.emit())
        compiler.add_module([self.operation,])
        return self.operation
#
# Run Related
#
    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
        Verifies that ``tensor`` has data type ``ref_type``. An exception is raised
        if it does not.

        NOTE(review): despite the name and the ``ref_layout`` parameter, the layout of
        ``tensor`` is not currently checked -- only the data type is validated.

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_type: data type for the tensor that this object was initialized to
        :param ref_layout: layout the tensor is expected to have (currently unused)
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str
        """
        dtype, _ = datatypes.get_datatype_and_layout(tensor)
        if dtype != ref_type:
            raise Exception(f'Tensor {name} with type and layout {dtype} '
                            f'does not match the expected type of {ref_type}.')
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
if self.conv_kind == ConvKind.Fprop:
input = A
weight = B
output = C
output_tensor = "C"
elif self.conv_kind == ConvKind.Dgrad:
output = A
weight = B
input = C
output_tensor = "A"
elif self.conv_kind == ConvKind.Wgrad:
output = A
input = B
weight = C
output_tensor = "A"
else:
raise Exception(f"Convolution kind {self.conv_kind} is not supported")
N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
_, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")
problem_size = Conv2DProblemSize(
N_, H_, W_, C_,
K_, R_, S_, C_,
padding[0], padding[1],
stride[0], stride[1],
dilation[0], dilation[1],
ConvMode.CrossCorrelation,
1, 1
)
if P_ != problem_size.P or Q_ != problem_size.Q:
raise Exception(
f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")
return problem_size
    def run(self, A=None, B=None, C=None, D=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1),
            alpha=None, beta=None,
            split_k=("serial", 1), sync: bool = True,
            print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
        ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
        parameters provided in the call, or from those
        passed in on the construction of this object -- one of the two must be specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: tensor representing data type and layout of operand A
        :param B: tensor representing data type and layout of operand B
        :param C: tensor representing data type and layout of operand C
        :param D: tensor representing data type and layout of operand D
        :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
        :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
        :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param split_k: a tuple (split_k_mode, split_k_slices)
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass_cppgen.backend.Conv2dArguments
        """
        if not stream:
            stream = cuda.CUstream(0)
        super().run_setup()

        # Resolve each operand/scalar from the call-site value or the constructor value
        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # handle the case when there is no C: substituting D is safe only when beta == 0,
        # since C is then never read
        if C is None:
            if beta != 0:
                raise Exception(f"With beta {beta} != 0, C has to be provided.")
            else:
                C = D

        # Construct problem size based on input
        # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
        problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

        # Propose stride support based on input
        stride_support = self._propose_stride_support(stride)

        # Propose swizzling functor
        swizzling_functor = self._propose_swizzling_functor(stride)

        shape_a = datatypes.get_tensor_shape(A, op="CONV")
        shape_b = datatypes.get_tensor_shape(B, op="CONV")
        shape_c = datatypes.get_tensor_shape(C, op="CONV")

        # Get the alignment
        alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
        alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
        alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")
        alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
        alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
        alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)

        # Propose iterator algorithm based on input
        if self._iterator_algorithm is None:
            # Propose a default iterator algorithm based on the problem size
            iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
        else:
            if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
                iterator_algorithm = self._iterator_algorithm
            else:
                raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")

        # Assemble epilogue arguments: scalars first, then any activation arguments
        epilogue_args = [alpha, beta]
        if hasattr(self, "_activation_args"):
            if isinstance(self._activation_args, list):
                epilogue_args += self._activation_args
            else:
                epilogue_args.append(self._activation_args)

        # Parallel split-k defers the activation to the reduction step, so the main
        # kernel uses an identity epilogue
        if split_k[0] == "parallel" and split_k[1] > 1:
            epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
        else:
            epilogue_functor = self.epilogue_functor

        # The alignment is determined by the iterator function (I believe)
        self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
                     swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)

        # Create reduction operation for parallel split-k
        if split_k[0] == "parallel" and split_k[1] > 1:
            epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
            self.reduction_operation = ReductionOperation(
                shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
                element_accumulator=self._element_accumulator,
                element_compute=self._element_accumulator,
                epilogue_functor=epilogue_functor_reduction,
                count=alignment_c
            )
            if print_module:
                print(self.reduction_operation.rt_module.emit())
            compiler.add_module([self.reduction_operation,])

        arguments = Conv2dArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D,
            output_op=self.operation.epilogue_type(*epilogue_args),
            split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
            split_k_slices=split_k[1],
            stream=stream
        )

        self.operation.run(arguments)

        # For parallel split-k, run the reduction over the partial products written by
        # the main kernel (arguments.ptr_D serves as the workspace)
        if split_k[0] == "parallel" and split_k[1] > 1:
            implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
            reduction_arguments = ReductionArguments(
                self.reduction_operation,
                problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
                partitions=split_k[1],
                workspace=arguments.ptr_D,
                destination=D,
                source=C,
                output_op=self.reduction_operation.epilogue_type(*epilogue_args),
                stream=stream
            )
            self.reduction_operation.run(reduction_arguments)

        if sync:
            if split_k[0] == "parallel" and split_k[1] > 1:
                reduction_arguments.sync()
                # Free memory allocated by args because we are not
                # calling `arguments.sync()` in this case (which will free memory)
                arguments.free()
            else:
                arguments.sync()

        return arguments
#
# Helper functions
#
@staticmethod
def output_size(input_size, weight_size, padding, stride, dilation):
problem_size = Conv2DProblemSize(
*input_size,
*weight_size,
padding[0], padding[1],
stride[0], stride[1],
dilation[0], dilation[1],
ConvMode.CrossCorrelation,
1, 1
)
return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
#
# Easy to use interfaces for fprop, wgrad, and dgrad
#
class Conv2dFprop(Conv2d):
    """
    Ease-of-use interface for a forward-propagation (fprop) Conv2d. Equivalent to
    constructing ``Conv2d("fprop", ...)`` directly, with operand names matching
    their fprop roles (input -> A, weight -> B, output -> D).
    """
    def __init__(
        self,
        input=None, weight=None, C=None, output=None, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        super().__init__(
            "fprop", input, weight, C, output, alpha, beta, element,
            element_input, element_weight, element_C, element_output,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False,
        stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """Runs the fprop convolution; see ``Conv2d.run`` for parameter details."""
        if not stream:
            stream = cuda.CUstream(0)
        return super().run(
            input, weight, C, output, alpha, beta, stride, padding, dilation,
            split_k, sync, print_module, stream)
class Conv2dDgrad(Conv2d):
    """
    Ease-of-use interface for a data-gradient (dgrad) Conv2d. Equivalent to
    constructing ``Conv2d("dgrad", ...)`` directly, with operand names matching
    their dgrad roles (grad_output -> A, weight -> B, grad_input -> D).
    """
    def __init__(
        self,
        grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        super().__init__(
            "dgrad", grad_output, weight, C, grad_input, alpha, beta, element,
            element_grad_output, element_weight, element_C, element_grad_input,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """Runs the dgrad convolution; see ``Conv2d.run`` for parameter details."""
        if not stream:
            stream = cuda.CUstream(0)
        return super().run(
            grad_output, weight, C, grad_input, alpha, beta, stride, padding,
            dilation, split_k, sync, print_module, stream)
class Conv2dWgrad(Conv2d):
    """
    Ease-of-use interface for a weight-gradient (wgrad) Conv2d. Equivalent to
    constructing ``Conv2d("wgrad", ...)`` directly, with operand names matching
    their wgrad roles (grad_output -> A, input -> B, grad_weight -> D).
    """
    def __init__(
        self,
        grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        super().__init__(
            "wgrad", grad_output, input, C, grad_weight, alpha, beta, element,
            element_grad_output, element_input, element_C, element_grad_weight,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """Runs the wgrad convolution; see ``Conv2d.run`` for parameter details."""
        if not stream:
            stream = cuda.CUstream(0)
        return super().run(
            grad_output, input, C, grad_weight, alpha, beta, stride, padding,
            dilation, split_k, sync, print_module, stream)

View File

@ -0,0 +1,725 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running GEMMs.
The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS GEMMs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass_cppgen.op.Gemm(A, B, C, D)
plan.run()
One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:
.. highlight:: python
.. code-block:: python
# The following is shorthand for:
# cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32,
# layout=cutlass_cppgen.LayoutType.RowMajor)
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
A0 = torch.rand((128, 256), device='cuda')
B0 = torch.rand((256, 64), device='cuda')
C0 = torch.zeros((128, 64), device='cuda')
    D0 = torch.zeros((128, 64), device='cuda')
plan.run(A0, B0, C0, D0)
    A1 = torch.rand((32, 128), device='cuda')
    B1 = torch.rand((128, 256), device='cuda')
    C1 = torch.zeros((32, 256), device='cuda')
    D1 = torch.zeros((32, 256), device='cuda')
    plan.run(A1, B1, C1, D1)
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
kernel from its execution:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.compile()
# Do other work...
plan.run(A0, B0, C0, D0)
# Do other work...
plan.run(A1, B1, C1, D1)
Elementwise activation functions are easily fused to the GEMM via the interface:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.activation = cutlass_cppgen.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
args = plan.run()
# Do other work...
args.sync()
"""
from __future__ import annotations
from typing import Optional
from math import prod
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import (
DataType,
DataTypeSize,
GemmUniversalMode,
KernelScheduleSuffixes,
)
import cutlass_cppgen
from cutlass_cppgen import epilogue, swizzle
from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils import check, datatypes
class Gemm(OperationBase):
"""
Constructs a ``Gemm`` object.
The data types and layouts of operands A, B, and C, along with the data type of output D
and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
these are not to be changed after a ``Gemm`` has been constructed.
The constructor has optional parameters for flexibly setting these parameters. The following
constructors are equivalent:
.. highlight:: python
.. code-block:: python
# Use F32 for A, B, C, D, and accumulation. All operands are row major.
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
# for operands to the same values.
Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
# have the same data type and layout as those passed in here).
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
Gemm(A=A, B=B, C=C, D=D)
        # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
        # the same as that for C, at present)
Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
:param A: tensor representing data type and layout of operand A
:param B: tensor representing data type and layout of operand B
:param C: tensor representing data type and layout of operand C
:param D: tensor representing data type and layout of operand D
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass_cppgen.DataType
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass_cppgen.DataType
:param layout: generic layout type to be used for operands A, B, C, and D
:type layout: cutlass_cppgen.LayoutType
:param element_A: data type to be used for operand A
:type element_A: cutlass_cppgen.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass_cppgen.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass_cppgen.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass_cppgen.DataType
:param layout_A: layout of operand A
:type layout_A: cutlass_cppgen.LayoutType
:param layout_B: layout of operand B
:type layout_B: cutlass_cppgen.LayoutType
:param layout_C: layout of operand C
:type layout_C: cutlass_cppgen.LayoutType
    :param layout_D: layout of operand D (not currently settable directly; D takes the same layout as C)
    :type layout_D: cutlass_cppgen.LayoutType
"""
    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None, kernel_cc: int = None
    ):
        """
        Resolves the data type and layout of each operand (from tensors, per-operand
        parameters, or the generic ``element``/``layout``) and initializes the plan.
        See the class-level documentation for parameter descriptions.
        """
        super().__init__(cc=cc, kernel_cc=kernel_cc)
        self.name = "gemm"
        self.compiled = False
        elements = []
        layouts = []
        # Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
        # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
        # NOTE: ``layout_C`` is reused for operand D -- D's layout always matches C's.
        for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
                                        [layout_A, layout_B, layout_C, layout_C],
                                        [A, B, C, D],
                                        ["A", "B", "C", "D"]):
            # A tensor fully determines both element and layout; specifying either
            # alongside the tensor is ambiguous and therefore rejected
            if elt is not None and tens is not None:
                raise Exception(f'Must not specify both element_{name} and tensor {name}')
            if lay is not None and tens is not None:
                raise Exception(f'Must not specify both layout_{name} and tensor {name}')
            if elt is None and tens is None and element is None:
                raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
            if lay is None and tens is None and layout is None:
                raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
            elt_to_set = None
            lay_to_set = None
            if tens is not None:
                # Precedence 1: infer element/layout from the provided tensor
                elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
            else:
                # Precedence 2: per-operand parameter; precedence 3: generic fallback
                elt_to_set = elt if elt is not None else element
                lay_to_set = lay if lay is not None else layout
            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(lay_to_set)
        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
        # Accumulator defaults to C's element type when unspecified
        if element_accumulator is None:
            self._element_accumulator = self._element_c
        else:
            self._element_accumulator = datatypes.library_type(element_accumulator)
        self.A = A
        self.B = B
        self.C = C
        self.D = D
        self.alpha = alpha
        self.beta = beta
        self.epilogue_functor = None
        self.op_class = None
        self._tile_description = None
        self._reset_operations()
        self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
def _reset_operations(self, reset_epilogue: bool = True):
    """
    Re-derives the set of opcode classes capable of supporting the currently-bound
    data types, layouts, and math operation, and selects a default opcode class.

    :param reset_epilogue: whether to also reset the epilogue functor to the identity activation
    :type reset_epilogue: bool
    """
    elem_combo = (self._element_a, self._element_b, self._element_accumulator)
    lay_combo = (self._layout_a, self._layout_b)
    self.possible_op_classes = self.options.supporting_opclasses(
        self._element_a, self._element_b, self._element_accumulator,
        self._layout_a, self._layout_b, self._math_operation)

    # Prefer Tensor Core kernels; fall back to SIMT if no Tensor Core kernel applies.
    for candidate in (cutlass_cppgen.OpcodeClass.TensorOp, cutlass_cppgen.OpcodeClass.Simt):
        if candidate in self.possible_op_classes:
            self.opclass = candidate
            break
    else:
        # No opcode class supports this combination: report it.
        math_op_str = f' and math operation {self._math_operation}' if self._math_operation is not None else ''
        raise Exception(f'No kernel configuration found for supported data type and layout '
                        f'combination {elem_combo}x{lay_combo}{math_op_str}')

    if reset_epilogue:
        self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
@property
def swizzling_functor(self):
    """
    The swizzling functor type currently bound to this GEMM plan.

    :return: swizzling functor type
    """
    return self._swizzling_functor
@swizzling_functor.setter
def swizzling_functor(self, swizzling_functor):
    """
    Binds the swizzling functor type ``swizzling_functor`` to this GEMM plan,
    rejecting stream-K swizzling on configurations where it is unsupported.
    """
    requested_stream_k = swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK
    if requested_stream_k:
        # Stream-K is restricted to Tensor Core kernels on pre-SM90 architectures.
        if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
            raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
        if self.current_cc in [90, 100, 101, 103]:
            raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')
    self._swizzling_functor = swizzling_functor
#
# Tile description Related
#
@property
def tile_description(self) -> TileDescription:
    """
    The tile description currently bound to this plan (``None`` until one is set).

    :return: current tile description
    :rtype: TileDescription
    """
    return self._tile_description
@tile_description.setter
def tile_description(
    self, td=None):
    """
    Sets the tile description after validating it against the current compute capability.

    :param td: tile description
    :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
        {
        "threadblock_shape": [int, int, int],
        "warp_count": [int, int, int],
        "stages": int,
        "instruction_shape": [int, int, int] (optional),
        "cluster_shape": [int, int, int] (optional)
        }
    """
    if td is None:
        return

    if isinstance(td, dict):
        # Dict updates are applied on top of an existing tile description;
        # seed one from a default operation if none has been set yet.
        if self._tile_description is None:
            default_op = self.possible_operations.default_operation(self._math_operation)
            self._tile_description = datatypes.td_from_profiler_op(default_op)
        td = self._tile_description.clone_and_update(td)

    ok, err = self._valid_tile_description(td)
    if not ok:
        raise Exception(err)
    self._tile_description = td
def _valid_tile_description(self, td: TileDescription) -> tuple:
    """
    Checks whether the provided tile description is valid for the given compute capability. At present,
    this checks the following:

    - Does the tile description use a number of stages supported by the compute capability in question?
    - Does the tile size requested fit within shared memory?
    - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
      more non-unit cluster dimensions for pre-SM90 architectures)?
    - Is the kernel schedule being used supported on the architecture in question?

    :param td: tile description to validate
    :type td: cutlass_cppgen.backend.TileDescription
    :return: tuple in which the first element is a bool indicating that the tile description is valid
             and the second element is a string providing an optional error message.
    :rtype: tuple
    """
    # Each check short-circuits: return the first failure encountered.
    valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
    if not valid:
        return (valid, msg)
    valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
    if not valid:
        return (valid, msg)
    valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
    # NOTE: the schedule-check result is intentionally not returned early here; the
    # SM100/101/103 2SM cluster constraint below may overwrite (valid, msg) with a
    # more specific failure before the combined result is returned.
    if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0:
        valid = False
        msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"
    return valid, msg
def tile_descriptions(self) -> list:
    """
    Returns a list of valid tile descriptions for the operations

    :returns: list of valid tile descriptions for the operations
    :rtype: list
    """
    candidates = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
    if self._math_operation is None:
        return candidates
    # Keep only tile descriptions whose math instruction matches the requested math operation.
    return [
        candidate for candidate in candidates
        if candidate.math_instruction.math_operation == self._math_operation
    ]
def construct(
    self, tile_description: TileDescription = None,
    alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
    """
    Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
    kernel specification of the ``Gemm`` object.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass_cppgen.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int

    :return: operation that was constructed
    :rtype: cutlass_cppgen.backend.GemmOperationUniversal
    """
    # Preferred alignment: the largest kernel-supported alignment, capped at a
    # 128-bit access for the operand's element type.
    alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
    alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
    alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
    alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
    tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
    tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
    if alignment_C is None:
        alignment_C = max(self.possible_operations.alignments("C"))
        # Only cap C's alignment by its element size when C is non-void;
        # DataTypeSize lookup is meaningless for a void C.
        if self._element_c != DataType.void:
            alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
    if tile_description is None:
        if self._tile_description is None:
            # No tile description available: pick the first kernel matching the
            # requested alignments and math operation.
            op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
            tile_description = datatypes.td_from_profiler_op(op)
            # The selected op may have lower alignment than that determined above, so we must
            # reset alignment here.
            alignment_C = op.C.alignment
        else:
            tile_description = self._tile_description
    else:
        # A caller-provided tile description must be validated before use.
        valid, err_str = self._valid_tile_description(tile_description)
        if not valid:
            raise Exception(f"Invalid tile description. {err_str}")
        self._tile_description = tile_description
    tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
    # Epilogue alignment must track the final C alignment chosen above.
    self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
    operation = GemmOperationUniversal(
        arch=self.current_cc,
        tile_description=tile_description,
        A=tensor_A, B=tensor_B, C=tensor_C,
        epilogue_functor=self.epilogue_functor,
        swizzling_functor=self._swizzling_functor,
    )
    return operation
def compile(self, tile_description: TileDescription = None,
            alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
            print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
    """
    Emits and compiles the kernel currently specified. If ``tile_description`` and any
    of the ``alignment`` parameters are set, the kernel will be chosen using this
    tile description and alignments. Otherwise, a default tile description and alignment
    will be used.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass_cppgen.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool

    :return: operation that was compiled
    :rtype: cutlass_cppgen.backend.GemmOperationUniversal
    """
    op = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
    self.operation = op
    if print_module:
        print(op.rt_module.emit())
    compiler.add_module([op])
    return op
def _verify_rank(self, tensor):
    """
    Verifies that ``tensor`` has rank greater than 1

    :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
    :type tensor: numpy/cupy/torch array/tensor object
    """
    rank = len(tensor.shape)
    if rank < 2:
        raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
def _get_batch_count(self, A, B, C, D) -> int:
    """
    Returns the batch count specified by the tensors A, B, C, and D and verifies that these
    tensors match in batch size. Presence of a batch dimension is detected by one of the
    tensors being rank 3. If a batch dimension is present, it must be present in one of
    operands A, B, or C (but need not be in all), and must be present in D.

    :param A: tensor A
    :type A: numpy/cupy/torch array/tensor object
    :param B: tensor B
    :type B: numpy/cupy/torch array/tensor object
    :param C: tensor C
    :type C: numpy/cupy/torch array/tensor object
    :param D: tensor D
    :type D: numpy/cupy/torch array/tensor object

    :return: tuple of batch count dimensions
    :rtype: tuple
    """
    def batch_of(tensor):
        # Fold all leading (non-matrix) modes into a single batch dimension.
        return prod(tensor.shape[:-2]) if len(tensor.shape) > 2 else 1

    A_batch = batch_of(A)
    B_batch = batch_of(B)
    # When both operands are batched, their batch counts must agree.
    if A_batch != 1 and B_batch != 1 and A_batch != B_batch:
        raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
    return max(A_batch, B_batch)
def _get_batch_stride(self, tensor) -> int:
    """
    Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.

    :param tensor: tensor object to process
    :type tensor: numpy/cupy/torch array/tensor object

    :return: stride between each matrix in the batch
    :rtype: int
    """
    # Non-batched (or absent) tensors have no stride between matrices.
    if tensor is None or len(tensor.shape) <= 2:
        return 0
    return tensor.shape[-2] * tensor.shape[-1]
def _get_problem_args(self, A, B, C, D) -> tuple:
    """
    Returns the problem size and GEMM universal mode to use for the
    given operands.

    :param A: tensor A
    :type A: numpy/cupy/torch array/tensor object
    :param B: tensor B
    :type B: numpy/cupy/torch array/tensor object
    :param C: tensor C
    :type C: numpy/cupy/torch array/tensor object
    :param D: tensor D
    :type D: numpy/cupy/torch array/tensor object

    :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
    :rtype: tuple
    """
    # Problem dimensions come from the trailing two modes of A and the last mode of B.
    M, K = A.shape[-2:]
    N = B.shape[-1]
    mode = GemmUniversalMode.Gemm
    batch_count = self._get_batch_count(A, B, C, D)
    returned_batch_count = batch_count

    # If we are running a batched GEMM in which there is a nonzero batch stride
    # only for A, then we can fold the batched dimension of A into the M dimension
    # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
    # and C are row major. A similar operation can be performed if only B has a nonzero
    # batch dimension
    if batch_count > 1:
        A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
        B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
        C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor

        # Consider a Tensor to be batched if its rank is > 2 and
        # the product of the modes beyond rank 2 equals our pre-determined batch size.
        batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)

        # Fold A's batch into M (row-major A and C), or B's batch into N
        # (column-major B and C); otherwise fall back to a true batched GEMM.
        if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
            M *= batch_count
            returned_batch_count = 1
        elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
            N *= batch_count
            returned_batch_count = 1
        else:
            mode = GemmUniversalMode.Batched

    return GemmCoord(M, N, K), mode, returned_batch_count
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
    """
    Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
    is raised if it does not.

    :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_dtype: data type for the tensor that this object was initialized to
    :param ref_layout: layout for the tensor that this object was initialized to
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str
    """
    dtype, layout = datatypes.get_datatype_and_layout(tensor)
    if dtype != ref_type or layout != ref_layout:
        try:
            # Attempt to transpose the tensor to fit the desired layout
            # NOTE(review): the transposed result is bound only to the local name
            # ``tensor`` and never returned, so a successful transpose leaves the
            # caller's tensor unchanged — confirm whether callers rely on this
            # check-only behavior or expect the transposed view back.
            tensor = tensor.transpose(-1, -2)
        # NOTE(review): bare ``except`` also swallows KeyboardInterrupt/SystemExit;
        # a narrower ``except Exception`` is likely intended here.
        except:
            raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
                            f'does not match the expected type and '
                            f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
def run(self, A=None, B=None, C=None, D=None,
        alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
        stream: Optional[cuda.CUstream] = None) -> GemmArguments:
    """
    Runs the kernel currently specified. If it has not already been, the kernel is emitted and
    compiled. Tensors holding operands and outputs of the kernel are sourced either from the
    ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
    parameters provided in this call, or from those
    passed in on the construction of this object -- one of the two must be specified.

    By default, this call returns only once the kernel has completed. To launch the kernel
    and immediately return, set ``sync=False``. In this case, it is the responsibility of the
    caller to synchronize the results of the kernel before attempting to access outputs
    by calling ``sync()`` on the arguments returned from this call.

    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param sync: whether the call should wait for the kernel to complete before returning
    :type sync: bool
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool
    :param visitor_args: arguments to pass to an epilogue visitor, used only when the epilogue
        functor is an ``EpilogueFunctorVisitor``
    :type visitor_args: dict
    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`

    :return: arguments passed in to the kernel
    :rtype: cutlass_cppgen.backend.GemmArguments
    """
    if not stream:
        stream = cuda.CUstream(0)
    super().run_setup()

    # Fall back to tensors bound at construction when call-site tensors are absent.
    A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
    B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
    C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
    D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
    alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
    beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

    is_void_c = self._element_c == DataType.void

    self._verify_rank(A)
    self._verify_rank(B)
    # C is absent (None) for void-C kernels, so only check its rank otherwise.
    if not is_void_c:
        self._verify_rank(C)
    self._verify_rank(D)

    alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
    alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")

    # Set C alignment based on D.shape so as to correctly get an alignment with void-C
    # kernels, for which `C` is None.
    alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
    self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                 alignment_C=alignment_c, print_module=print_module)

    problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

    if mode == GemmUniversalMode.Gemm or batch_count == 1:
        kwargs = {'split_k_slices': 1}
    else:
        # Batched mode requires per-operand strides between matrices in the batch.
        kwargs = {
            'batch': batch_count,
            'batch_strides': {
                'A': self._get_batch_stride(A),
                'B': self._get_batch_stride(B),
                'C': self._get_batch_stride(C),
                'D': self._get_batch_stride(D)
            }
        }
    kwargs['stream'] = stream

    # Visitor epilogues consume user-provided visitor arguments; the default
    # epilogue consumes the alpha/beta scalars.
    if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
        output_op = self.operation.epilogue_type(visitor_args)
    else:
        output_op = self.operation.epilogue_type(alpha, beta)

    arguments = GemmArguments(
        operation=self.operation, problem_size=problem_size,
        A=A, B=B, C=C, D=D,
        output_op=output_op,
        gemm_mode=mode,
        **kwargs
    )

    self.operation.run(arguments)

    if sync:
        arguments.sync()

    return arguments

View File

@ -0,0 +1,269 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running GEMMs.
The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS grouped GEMMs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""
from __future__ import annotations
from typing import Optional
from cutlass_library import DataTypeSize
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_cppgen.backend.gemm_operation import (
GemmGroupedArguments,
GemmOperationGrouped,
)
from cutlass_cppgen.backend.library import (
SchedulerMode,
TensorDescription,
TileDescription,
)
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils import check, datatypes
class GroupedGemm(Gemm):
    """
    Constructs a ``GroupedGemm`` object.

    The data types and layouts of operands A, B, and C, along with the data type of output D
    and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
    these are not to be changed after a ``GroupedGemm`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
    for ``Gemm`` for examples of these.

    :param cc: compute capability of device to generate kernels for
    :type cc: int
    :param A: tensor representing data type and layout of operands A
    :param B: tensor representing data type and layout of operands B
    :param C: tensor representing data type and layout of operands C
    :param D: tensor representing data type and layout of operands D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
    :type element_accumulator: cutlass_cppgen.DataType
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass_cppgen.DataType
    :param layout: generic layout type to be used for operands A, B, C, and D
    :type layout: cutlass_cppgen.LayoutType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass_cppgen.DataType
    :param layout_A: layout of operand A
    :type layout_A: cutlass_cppgen.LayoutType
    :param layout_B: layout of operand B
    :type layout_B: cutlass_cppgen.LayoutType
    :param layout_C: layout of operand C
    :type layout_C: cutlass_cppgen.LayoutType
    :param layout_D: layout of operand D
    :type layout_D: cutlass_cppgen.LayoutType
    """

    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None,
    ):
        super().__init__(
            A=A, B=B, C=C, D=D,
            alpha=alpha, beta=beta,
            element_accumulator=element_accumulator,
            element=element, layout=layout,
            element_A=element_A, element_B=element_B,
            element_C=element_C, element_D=element_D,
            layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
            cc=cc
        )

        # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
        if self.current_cc in [90, 100, 101, 103]:
            self._reset_options(80)
            self._reset_operations(reset_epilogue=False)

        self.name = "grouped_gemm"

    @Gemm.swizzling_functor.setter
    def swizzling_functor(self, swizzling_functor):
        """
        Sets the swizzling functor to the type specified by `swizzling_functor`
        """
        raise Exception('Grouped GEMM does not currently support different swizzling functors')

    def construct(self, tile_description: TileDescription = None,
                  alignment_A: int = None,
                  alignment_B: int = None,
                  alignment_C: int = None) -> GemmOperationGrouped:
        """
        Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
        kernel specification of the ``Gemm`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int

        :return: operation that was constructed
        :rtype: cutlass_cppgen.backend.GemmOperationGrouped
        """
        alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
        alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
        alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))

        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        # Bug fix: operand A must be described with its own layout (``self._layout_a``);
        # previously ``self._layout_b`` was used here.
        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            # Choose the first kernel matching the requested alignments and math operation.
            op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
            tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description

        operation = GemmOperationGrouped(
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A, B=tensor_B, C=tensor_C,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self._swizzling_functor,
            precompute_mode=SchedulerMode.Device)

        return operation

    def run(self, A, B, C, D,
            alpha=None, beta=None, sync: bool = True,
            print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
        """
        Runs the kernel currently specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: list of tensors representing data type and layout of operand A
        :type A: list
        :param B: list of tensors representing data type and layout of operand B
        :type B: list
        :param C: list of tensors representing data type and layout of operand C
        :type C: list
        :param D: list of tensors representing data type and layout of operand D
        :type D: list
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass_cppgen.backend.GemmGroupedArguments
        """
        if not stream:
            stream = cuda.CUstream(0)
        super().run_setup()

        if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
            raise Exception("Lengths of A, B, C, and D lists must be equal")

        problem_sizes = []
        As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
        for i in range(len(A)):
            # Fall back to tensors bound at construction when per-group tensors are absent.
            As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
            Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
            Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
            Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
            problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))

        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # Use the smallest alignment across all groups so one kernel serves every problem.
        alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
        alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
        alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        arguments = GemmGroupedArguments(
            operation=self.operation,
            problem_sizes=problem_sizes,
            A=As, B=Bs, C=Cs, D=Ds,
            output_op=self.operation.epilogue_type(alpha, beta),
            stream=stream
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()

        return arguments

View File

@ -0,0 +1,431 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
from bisect import bisect_left
from cutlass_library import (
DataType,
DataTypeSize,
MathOperation,
OperationKind,
SharedMemPerCC
)
import cutlass_cppgen
from cutlass_cppgen import get_option_registry
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
from cutlass_cppgen.backend.evt.passes.util import cc_map
from cutlass_cppgen.backend.utils.device import device_cc
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
from cutlass_cppgen.swizzle import get_swizzling_functors
from cutlass_cppgen.utils import datatypes, check
class OperationBase:
"""
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
    """
    :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
    :type cc: int
    :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
    :type kernel_cc: int
    :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
    :type operation_kind: cutlass_library.OperationKind
    """
    self.operation_kind = operation_kind
    # Detect the device's compute capability if none was given.
    self.cc = device_cc() if cc is None else cc
    self.specified_kernel_cc = kernel_cc is not None
    # Kernels target the requested CC, or the nearest supported CC at or below the device's.
    if kernel_cc is not None:
        self.current_cc = kernel_cc
    else:
        self.current_cc = self._find_closest_cc(self.cc)
    self.tile_description = None

    self._math_operation = None

    self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)

    if self.options is None:
        raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")

    # Default activation function: identity
    self._activation = identity
def _find_closest_cc(self, cc: int) -> int:
    """
    Returns the closest CC in _generator_ccs less than or equal to `cc`

    :param cc: compute capability to query
    :type cc: int

    :returns: closest CC in _generator_ccs less than or equal to `cc`
    :rtype: int
    """
    # Exact matches need no fallback.
    if cc in _generator_ccs:
        return cc

    # Otherwise, fall back to the largest generator CC strictly below `cc`.
    pos = bisect_left(_generator_ccs, cc)
    if pos == 0:
        raise Exception(f'No valid CC to fall back to for {cc}')
    return _generator_ccs[pos - 1]
def activations(self) -> list:
"""
Returns possible activation functions that can be used
:return: list of activation functions that can be used
:rtype: list
"""
return get_activations()
def swizzling_functors(self) -> list:
"""
Returns possible swizzling functions that can be used
:return: list of swizzling functions that can be used
:rtype: list
"""
return get_swizzling_functors()
def _reset_options(self, cc: int):
"""
Resets the kernel options based on cc
:param cc: compute capability to reset to
:type cc: int
"""
if cc != self.current_cc:
if cc not in _generator_ccs:
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
self.current_cc = cc
self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
"""
Verifies the following properties:
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
set by the plan (i.e., those in ``ref_dtype``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type scalar: numpy/cupy/torch scalar
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_scalar: numpy/cupy/torch scalar
:param ref_dtype: data type for the scalar that this object was initialized to
:param name: identifier of the scalar to verify. Used in raising exceptions
:type name: str
:return: valid scalar to use
:rtype: numpy/cupy/torch scalar
"""
if scalar is None:
if ref_scalar is None:
raise Exception(f"Scalar {name} must be set.")
return ref_scalar
if hasattr(scalar, "dtype"):
dtype = datatypes.library_type(scalar.dtype)
if dtype != ref_dtype:
raise Exception(
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
)
return scalar
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
"""
Verifies the following properties:
If ref_dtype is not void:
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
If ref_dtype is void:
Neither ``tensor`` nor ``ref_tensor`` are set
If either of these properties does not hold, an exception is raised. If these properties hold and
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_tensor: numpy/cupy/torch array/tensor object
:param ref_dtype: data type for the tensor that this object was initialized to
:param ref_layout: layout for the tensor that this object was initialized to
:param name: identifier of the tensor to verify. Used in raising exceptions
:type name: str
:return: valid tensor object to use
:rtype: numpy/cupy/torch array/tensor object
"""
if ref_dtype == DataType.void:
if tensor is not None or ref_tensor is not None:
raise Exception("Operands with element DataType.void must not be provided a tensor")
return None
if tensor is None:
if ref_tensor is None:
raise Exception(f"Tensor {name} must be set.")
return ref_tensor
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
return tensor
@property
def opclass(self) -> cutlass_cppgen.OpcodeClass:
"""
Returns the opcode class currently in use
:return: opcode class currently in use
:rtype: cutlass_cppgen.OpcodeClass
"""
return self.op_class
@opclass.setter
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
if isinstance(oc, str):
oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
if oc in self.possible_op_classes:
self.op_class = oc
else:
raise Exception(
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
f'layout combination ({self._layout_a}, {self._layout_b}).')
# Changing the op class also changes the possible operations available. Reset these.
self.possible_operations = self.options.operations(
self.op_class, self._element_a, self._element_b,
self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
# Changing the op class changes the elements per access in the epilogue. Reset this.
if self.epilogue_functor is not None:
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
@property
def math_operation(self) -> cutlass_cppgen.MathOperation:
"""
Returns the math operation currently in use
:return: math operation currently in use
:rtype: cutlass_cppgen.MathOperation
"""
return self._math_operation
@math_operation.setter
def math_operation(self, mo: cutlass_cppgen.MathOperation):
if isinstance(mo, str):
mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
if not self.specified_kernel_cc:
if self.current_cc in [90, 100, 101, 103]:
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif self.current_cc in [90, 100, 101, 103]:
raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
"To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
"parameter when constructing the plan.")
self._math_operation = mo
self._reset_operations()
def _elements_per_access(self):
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
return 1
elif self._element_c != DataType.void:
return 128 // DataTypeSize[self._element_c]
else:
return 128 // max(self.possible_operations.alignments("C"))
def _create_epilogue_functor_activation(self, activation):
"""
Returns the epilogue functor with given activation function
"""
if self.epilogue_functor is None:
elements_per_access = self._elements_per_access()
else:
elements_per_access = self.epilogue_functor.epilogue_vector_length
if not self.specified_kernel_cc:
if self.current_cc in [90, 100, 101, 103] and activation != identity:
# CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
if self._element_c != self._element_d:
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None):
# SM80 fallback kernels are currently used. Since an identity activation is requested,
# we can switch back to using SM90 kernels.
self._reset_options(self.cc)
self._reset_operations(reset_epilogue=False)
else:
if self.current_cc in [90, 100, 101, 103] and activation != identity:
raise Exception("Epilogues with elementwise fusion are not currently supported "
"in the Python interface for 3.x kernels. To use 2.x kernels "
"with fused elementwise epilogues, do not set the `kernel_cc` "
"parameter when constructing the plan.")
return get_activation_epilogue(
activation,
self._element_d,
elements_per_access,
self._element_accumulator,
self._element_accumulator,
)
def _reset_epilogue_functor_activation(self, activation):
"""
Set the epilogue functor based on the provided activation function
"""
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
"""
Reset the alignment of the current epilogue functor based on alignment C
"""
if isinstance(epilogue_functor, EpilogueFunctorVisitor):
return epilogue_functor
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
# Identity epilogue does not have 'activation_functor'
activation = identity
else:
activation = epilogue_functor.activation_functor
epilogue_functor = get_activation_epilogue(
activation,
self._element_d,
alignment,
self._element_accumulator,
self._element_accumulator,
)
return epilogue_functor
@property
def activation(self):
"""
Returns the type of the current activation function used
"""
if hasattr(self.epilogue_functor, "activation_functor"):
return self.epilogue_functor.activation_functor
else:
return identity
@activation.setter
def activation(self, act):
"""
Sets the type of the activation function to use
Activation can come with a set of arguments
:param act: type of activation function to use
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
"""
if isinstance(act, tuple):
if isinstance(act[0], str):
act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
else:
act_fn = act[0]
self._reset_epilogue_functor_activation(act_fn)
self._activation_args = act[1]
self._activation = act[0]
else:
if isinstance(act, str):
act = getattr(cutlass_cppgen.backend.epilogue, act)
self._reset_epilogue_functor_activation(act)
self._activation = act
@property
def epilogue_visitor(self):
"""
Return the epilogue functor
"""
return self.epilogue_functor
@epilogue_visitor.setter
def epilogue_visitor(self, visitor):
"""
Create the epilogue visitor
"""
self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)
# The epilogue_functor may consume too much shared memory
# Reset the possible operations
if self.cc not in [90, 100, 101, 103]:
# The shared memory is only a concern for sm90+ epilogue
# In sm80, the epilogue and mainloop share the shared memory
return
datatype_comb = self.possible_operations.datatype_comb
layout_comb = self.possible_operations.layout_comb
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
for operation in self.possible_operations.all_operations:
td = datatypes.td_from_profiler_op(operation)
# Filter invalid epilogue schedules
if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
continue
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
# Verify the maximum number of mainloop stages
mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
if mainloop_stages < 2:
# Mainloop stages must >= 2
continue
new_possible_operations.add(operation)
if len(new_possible_operations.all_operations) == 0:
raise RuntimeError(
"The epilogue consumes too much shared memory. "
"No valid tile description is found in the generator.")
self.possible_operations = new_possible_operations
def run_setup(self):
"""
Steps that must be taken before caling `plan.run()`
"""
# Initialize the memory pool if, if not already done
cutlass_cppgen.get_memory_pool()

View File

@ -0,0 +1,184 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for expressing shapes
"""
from cutlass_library import (
ConvMode,
ConvKind,
LayoutType
)
from cutlass_cppgen.backend.c_types import (
Conv2DProblemSize_,
GemmCoord_,
GemmCoordBatched_
)
class MatrixCoord:
    """
    Shape of a rank-2 matrix, expressed as a (row, column) pair.
    """
    def __init__(self, row, col):
        # Extents are stored privately and exposed through read-only properties.
        self._row = row
        self._col = col

    @property
    def row(self):
        """Number of rows."""
        return self._row

    @property
    def column(self):
        """Number of columns."""
        return self._col

    def leading_dimension(self, layout: LayoutType) -> int:
        """
        Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.
        :param layout: layout of matrix
        :type layout: cutlass_library.LayoutType
        :returns: leading dimension
        :rtype: int
        """
        # Row-major storage strides by columns; column-major storage strides by rows.
        if layout == LayoutType.RowMajor:
            return self._col
        if layout == LayoutType.ColumnMajor:
            return self._row
        raise Exception(f'Unsupported layout for leading dimension calculation: {layout}')
class GemmCoord:
    """
    Problem shape of a GEMM, expressed as an (M, N, K) triple.
    """
    def __init__(self, m: int, n: int, k: int):
        self._m = m
        self._n = n
        self._k = k

    @property
    def m(self) -> int:
        """GEMM M extent."""
        return self._m

    @property
    def n(self) -> int:
        """GEMM N extent."""
        return self._n

    @property
    def k(self) -> int:
        """GEMM K extent."""
        return self._k

    @property
    def mk(self) -> MatrixCoord:
        """Shape of the A operand (M x K)."""
        return MatrixCoord(self._m, self._k)

    @property
    def mn(self) -> MatrixCoord:
        """Shape of the C/D operands (M x N)."""
        return MatrixCoord(self._m, self._n)

    @property
    def kn(self) -> MatrixCoord:
        """Shape of the B operand (K x N)."""
        return MatrixCoord(self._k, self._n)

    @property
    def ctype(self) -> GemmCoord_:
        """ctypes-compatible representation of this coordinate."""
        return GemmCoord_(self._m, self._n, self._k)

    def batched_ctype(self, batch_count: int) -> GemmCoordBatched_:
        """
        ctypes-compatible representation of this coordinate along with a batch count.
        :param batch_count: number of batches
        :type batch_count: int
        """
        return GemmCoordBatched_(self._m, self._n, self._k, batch_count)
class Conv2DProblemSize:
    """
    Description of a 2D convolution problem: activation, filter, padding, stride,
    dilation, and derived output extents.
    """
    def __init__(
        self, n: int, h: int, w: int, c: int,
        k: int, r: int, s: int, c_: int,
        pad_h: int, pad_w: int, stride_h: int, stride_w: int,
        dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation,
        split_k_slices: int=1, groups: int=1):
        # Activation extents (NHWC)
        self.N = n
        self.H = h
        self.W = w
        self.C = c
        # Filter extents (KRS). NOTE(review): ``c_`` (the filter's channel count) is not
        # stored -- it is presumably equal to ``c``; confirm with callers.
        self.K = k
        self.R = r
        self.S = s
        # Padding, stride, and dilation of the convolution
        self.pad_h = pad_h
        self.pad_w = pad_w
        self.stride_h = stride_h
        self.stride_w = stride_w
        self.dilation_h = dilation_h
        self.dilation_w = dilation_w
        self.mode = int(mode)
        self.split_k_slices = split_k_slices
        self.groups = groups
        # Derived output extents. NOTE(review): with dilation > 1 this differs from the
        # usual (r - 1) * dilation + 1 effective window extent -- confirm against the
        # C++ Conv2dProblemSize definition.
        self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1
        self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1

    @property
    def ctype(self) -> Conv2DProblemSize_:
        """ctypes-compatible representation of this problem size."""
        return Conv2DProblemSize_(self)

    def implicit_gemm_size(self, kind: ConvKind):
        """
        Returns the GEMM problem shape (as a GemmCoord) to which this convolution
        maps under the implicit-GEMM formulation for the given convolution kind.
        """
        if kind == ConvKind.Fprop:
            return GemmCoord(
                self.N * self.P * self.Q,
                self.K,
                self.R * self.S * self.C // self.groups
            )
        if kind == ConvKind.Dgrad:
            return GemmCoord(
                self.N * self.H * self.W,
                self.C,
                self.R * self.S * self.K
            )
        if kind == ConvKind.Wgrad:
            return GemmCoord(
                self.K,
                self.R * self.S * self.C,
                self.N * self.P * self.Q
            )

    @staticmethod
    def from_sizes(input_size, weight_size):
        """
        Builds a Conv2DProblemSize from activation and filter extents alone, using
        half-window padding, unit stride, and unit dilation.
        """
        _, r, s, _ = weight_size
        return Conv2DProblemSize(
            *input_size,
            *weight_size,
            r // 2, s // 2,   # pad_h, pad_w
            1, 1,             # stride_h, stride_w
            1, 1              # dilation_h, dilation_w
        )

View File

@ -0,0 +1,65 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Registry of swizzling functions
"""
from cutlass_library import SwizzlingFunctor
# Aliases mapping the names used by this interface onto cutlass_library's
# SwizzlingFunctor enumerants.
IdentitySwizzle1 = SwizzlingFunctor.Identity1
IdentitySwizzle2 = SwizzlingFunctor.Identity2
IdentitySwizzle4 = SwizzlingFunctor.Identity4
IdentitySwizzle8 = SwizzlingFunctor.Identity8
HorizontalSwizzle = SwizzlingFunctor.Horizontal
ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK
StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1
StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4
StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal

# Registry of every swizzling functor exposed by the interface (same order as above).
_swizzling_functors = [
    IdentitySwizzle1, IdentitySwizzle2, IdentitySwizzle4, IdentitySwizzle8,
    HorizontalSwizzle, ThreadblockSwizzleStreamK,
    StridedDgradIdentitySwizzle1, StridedDgradIdentitySwizzle4,
    StridedDgradHorizontalSwizzle,
]


def get_swizzling_functors():
    """
    Returns the list of available swizzling functors.
    :return: list of swizzling functors
    :rtype: list
    """
    return _swizzling_functors

View File

@ -0,0 +1,41 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.utils.check import (
alignment_or_default,
calculate_smem_usage,
calculate_smem_usage_per_stage,
valid_cluster_shape,
valid_schedule,
valid_stage_count,
update_alignment,
)

View File

@ -0,0 +1,262 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for checking constraints on kernels and calculating kernel attributes
"""
import ctypes
from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC
import cutlass_cppgen
from cutlass_cppgen.backend.library import TileDescription
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
    :param td: tile description to compute shared memory of
    :type td: TileDescription
    :param operation_kind: identifier for the type of operation being performed
    :type operation_kind: cutlass_library.OperationKind
    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int
    :raises Exception: if ``operation_kind`` is not a supported operation kind
    """
    m, n, k = td.blackwell_threadblock_shape
    # For 2-SM (cluster-paired) tile descriptions, each SM holds half of the M extent.
    # NOTE(review): assumes ``is_2sm`` halves only M -- confirm against the generator.
    if td.is_2sm:
        m //= 2
    if operation_kind == OperationKind.Gemm:
        # Per-stage smem: the A tile (m x k) plus the B tile (k x n), in bytes,
        # plus a fixed allowance for the stage's synchronization barriers.
        stage_barrier_bytes = 32
        return (
            (DataTypeSize[td.math_instruction.element_a] * m * k // 8)
            + (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
            + stage_barrier_bytes
        )
    else:
        # Bug fix: the original referenced the undefined name ``operation``
        # (``operation.operation_kind``), which raised a NameError instead of
        # the intended diagnostic.
        raise Exception(f"No available shared memory calculation for operation kind {operation_kind}")
def calculate_smem_usage(operation) -> int:
    """
    Returns the amount of shared memory in bytes consumed by a kernel.
    :param operation: operation whose tile description and kind determine the usage
    :return: number of bytes of shared memory consumed by the operation
    :rtype: int
    """
    td = operation.tile_description
    # Total usage is the per-stage footprint multiplied by the stage count.
    return td.stages * calculate_smem_usage_per_stage(td, operation.operation_kind)
def valid_stage_count(
        cc: int,
        kernel_cc: int,
        td: TileDescription,
        element_C: cutlass_cppgen.DataType = None,
        element_D: cutlass_cppgen.DataType = None,
        verbose: bool = True) -> tuple:
    """
    Checks whether a device with `cc` supports the number of stages within `tile_description`, both
    based on raw limits on the number of stages and based on shared memory capacity
    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
    :type kernel_cc: int
    :param td: tile description to check
    :type td: TileDescription
    :param element_C: data type of operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type of operand D
    :type element_D: cutlass_cppgen.DataType
    :param verbose: whether to log warnings
    :type verbose: bool
    :return: tuple with the first element indicating whether the provided tile description is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    if kernel_cc in [90, 100, 101, 103]:
        if (td.stages is None or td.stages == 0):
            # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
            # determines the stage count to use. Thus, all settings are valid in these scenarios.
            return (True, "")
        elif verbose:
            cutlass_cppgen.logger.warning(
                "Setting an explicit stage count for SM90 kernels currently may "
                "result in compilation errors if the combination of tile shape, "
                "stage count, and shared memory requirement of the epilogue exceeds "
                "the available shared memory per SM.")
    if td.stages <= 0:
        return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
    # Pre-SM80 architectures only support double-buffered (2-stage) pipelines
    if cc < 80 and td.stages != 2:
        return (False, f"Tile description has stage count of {td.stages}, "
                       f"but only 2 stages are supported on SM{cc}.")
    # The calculation below does not consider shared memory used by the epilogue and, thus,
    # only catches cases in which the mainloop exceeds the device's shared memory capacity.
    # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
    # mainloop and epilogue is shared.
    smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
    smem_usage_mainloop = (smem_per_stage * td.stages)
    # SharedMemPerCC is expressed in KiB; convert to bytes
    smem_arch = SharedMemPerCC[cc] << 10
    if smem_usage_mainloop > smem_arch:
        # Bug fix: corrected the misspelling "maxmium" in the diagnostic below.
        return (False,
                "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
                f"Details:\n"
                f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
                f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
                f"The maximum amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
    return (True, "")
def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
    """
    Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.
    :param cc: compute capability of device in question
    :type cc: int
    :param cluster_shape: dimensions of thread block cluster shape to check
    :type cluster_shape: list
    :return: tuple with the first element indicating whether the provided cluster shape is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    # Architectures without thread block cluster support only admit the trivial cluster.
    if cc < 90 or cc in [120, 121]:
        if cluster_shape == [1, 1, 1]:
            return (True, "")
        return (False,
                f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of "
                f"{cluster_shape} for SM{cc}.")
    # Cluster-capable architectures: the shape must be rank-3 with a unit third dimension.
    if len(cluster_shape) != 3:
        return (False,
                f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}")
    if cluster_shape[2] != 1:
        return (False,
                "CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
                f"Received cluster shape of {cluster_shape}.")
    return (True, "")
def valid_schedule(
        cc: int,
        kernel_schedule: cutlass_cppgen.KernelScheduleType,
        epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
        tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
    """
    Checks that the kernel and epilogue schedules passed in are a valid combination for
    a device of compute capability ``cc``.
    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_schedule: kernel schedule type
    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
    :param epilogue_schedule: epilogue schedule type
    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
    :param tile_scheduler: tile scheduler type
    :type tile_scheduler: cutlass_cppgen.TileSchedulerType
    :return: tuple with the first element indicating whether the provided schedules are
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    # Flags describing whether each schedule is left at its automatic/default setting
    kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
    epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
    tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
    # Architectures without schedule support require everything to be auto/default
    if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default):
        return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)")
    # On SM90, kernel and epilogue schedules must be auto together or explicit together
    if cc == 90 and (kernel_auto != epilogue_auto):
        return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
    if not tile_scheduler_default:
        cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
                               cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
        # Stream-K on SM90 is only compatible with cooperative kernel schedules
        if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
            return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
    return (True, "")
def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
    """
    Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
    that `alignment_provided` does not exceed `default_alignment`.
    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None
    :type default_alignment: int
    :return: alignment to use
    :rtype: int
    :raises Exception: if `alignment_provided` exceeds `default_alignment`
    """
    if alignment_provided is None:
        return default_alignment
    if alignment_provided > default_alignment:
        raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
    return alignment_provided
def update_alignment(alignment_provided: int, default_alignment: int) -> int:
    """
    Returns the alignment to use given a preferred alignment and the maximum supported one.
    If `alignment_provided` is None, `default_alignment` is returned. If `alignment_provided`
    exceeds `default_alignment` but is an exact multiple of it, the provided value is clamped
    down to `default_alignment`; a non-multiple excess raises an exception.
    (Documentation fix: the previous docstring was copied from ``alignment_or_default`` and
    did not describe the multiple-of-default clamping behavior.)
    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: maximum alignment supported; also the fallback value
    :type default_alignment: int
    :return: alignment to use
    :rtype: int
    :raises Exception: if `alignment_provided` exceeds `default_alignment` and is not a
                       multiple of it
    """
    if alignment_provided is not None:
        if alignment_provided > default_alignment:
            # A larger alignment that is a multiple of the supported one is still
            # satisfiable by the supported alignment; clamp rather than fail.
            if alignment_provided % default_alignment == 0:
                return default_alignment
            raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
        return alignment_provided
    return default_alignment

View File

@ -0,0 +1,362 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""
import cutlass_cppgen
from cutlass_library import (
DataTypeSize,
MathOperation,
MathInstruction
)
from cutlass_cppgen.backend.library import (
TileDescription,
)
bfloat16_available = None
cupy_available = None
numpy_available = None
torch_available = None
_library_to_cupy_dict = None
_library_to_numpy_dict = None
_library_to_torch_dict = None
_torch_to_library_dict = None
def is_numpy_available():
    """
    Returns whether numpy can be imported. On the first successful import, the
    module-level CUTLASS-to-numpy data type map is populated as a side effect.
    """
    global numpy_available, _library_to_numpy_dict
    if numpy_available is not None:
        # Result is cached after the first call
        return numpy_available

    try:
        import numpy as np
    except ImportError:
        numpy_available = False
        _library_to_numpy_dict = {}
        return numpy_available

    numpy_available = True
    _library_to_numpy_dict = {
        cutlass_cppgen.DataType.f16: np.float16,
        cutlass_cppgen.DataType.f32: np.float32,
        cutlass_cppgen.DataType.f64: np.float64,
        cutlass_cppgen.DataType.s8: np.int8,
        cutlass_cppgen.DataType.s32: np.int32,
    }
    return numpy_available
def is_numpy_tensor(inp) -> bool:
    """Returns whether ``inp`` is a numpy ndarray."""
    if not is_numpy_available():
        return False
    import numpy as np
    return isinstance(inp, np.ndarray)
def numpy_library_type(inp) -> cutlass_cppgen.DataType:
    """
    Converts a numpy data type into the corresponding CUTLASS library data
    type. Returns None if numpy is unavailable or no conversion exists.
    """
    if not is_numpy_available():
        return None

    import numpy as np
    # Checked in order via equality, matching numpy's dtype comparison rules
    conversions = (
        (np.float16, cutlass_cppgen.DataType.f16),
        (np.float32, cutlass_cppgen.DataType.f32),
        (np.float64, cutlass_cppgen.DataType.f64),
        (np.int8, cutlass_cppgen.DataType.s8),
        (np.int32, cutlass_cppgen.DataType.s32),
    )
    for np_type, library_dtype in conversions:
        if inp == np_type:
            return library_dtype
    return None
def numpy_type(inp):
    """
    Returns the numpy data type corresponding to CUTLASS library data type
    ``inp``, or None if numpy is unavailable or no conversion exists.
    """
    # Ensure the lazily-initialized map has been populated. Without this
    # guard, calling numpy_type before any other numpy helper attempted
    # .get on the uninitialized (None) module-level map and crashed.
    if not is_numpy_available():
        return None
    return _library_to_numpy_dict.get(inp, None)
def is_cupy_available():
    """
    Returns whether cupy can be imported. On the first successful import, the
    module-level CUTLASS-to-cupy data type map is populated as a side effect.
    """
    # Bug fix: _library_to_cupy_dict must also be declared global. Previously
    # only cupy_available was declared, so the assignments below created a
    # function-local dict and the module-level map consulted by cupy_type()
    # was never populated (remaining None).
    global cupy_available, _library_to_cupy_dict
    if cupy_available is None:
        try:
            import cupy as cp

            cupy_available = True
            _library_to_cupy_dict = {
                cutlass_cppgen.DataType.f16: cp.float16,
                cutlass_cppgen.DataType.f32: cp.float32,
                cutlass_cppgen.DataType.f64: cp.float64,
                cutlass_cppgen.DataType.s8: cp.int8,
                cutlass_cppgen.DataType.s32: cp.int32,
            }
        except ImportError:
            cupy_available = False
            _library_to_cupy_dict = {}
    return cupy_available
def is_cupy_tensor(inp) -> bool:
    """Returns whether ``inp`` is a cupy ndarray."""
    if not is_cupy_available():
        return False
    import cupy as cp
    return isinstance(inp, cp.ndarray)
def cupy_library_type(inp) -> cutlass_cppgen.DataType:
    """
    Converts a cupy data type into the corresponding CUTLASS library data
    type. Returns None if cupy is unavailable or no conversion exists.
    """
    if not is_cupy_available():
        return None

    import cupy as cp
    # Only floating-point conversions are supported for cupy inputs
    conversions = (
        (cp.float16, cutlass_cppgen.DataType.f16),
        (cp.float32, cutlass_cppgen.DataType.f32),
        (cp.float64, cutlass_cppgen.DataType.f64),
    )
    for cp_type, library_dtype in conversions:
        if inp == cp_type:
            return library_dtype
    return None
def cupy_type(inp):
    """
    Returns the cupy data type corresponding to CUTLASS library data type
    ``inp``, or None if cupy is unavailable or no conversion exists.
    """
    # Ensure the lazily-initialized map has been populated before lookup;
    # previously a call before is_cupy_available() crashed on None.get.
    if not is_cupy_available():
        return None
    return _library_to_cupy_dict.get(inp, None)
def is_torch_available():
    """
    Returns whether torch can be imported. On the first successful import, the
    torch <-> CUTLASS library data type maps are populated as a side effect.
    """
    global torch_available, _library_to_torch_dict, _torch_to_library_dict
    if torch_available is None:
        try:
            import torch

            torch_available = True
            # torch.half, torch.float, and torch.double are aliases of
            # torch.float16, torch.float32, and torch.float64, respectively,
            # so a single entry per type suffices. (The previous duplicate
            # dict keys were redundant: in a dict literal the last assignment
            # to a key wins, and the aliased values are identical objects.)
            _torch_to_library_dict = {
                torch.float16: cutlass_cppgen.DataType.f16,
                torch.bfloat16: cutlass_cppgen.DataType.bf16,
                torch.float32: cutlass_cppgen.DataType.f32,
                torch.float64: cutlass_cppgen.DataType.f64,
                torch.int8: cutlass_cppgen.DataType.s8,
                torch.int32: cutlass_cppgen.DataType.s32,
                torch.uint8: cutlass_cppgen.DataType.u8,
            }

            _library_to_torch_dict = {
                cutlass_cppgen.DataType.f16: torch.float16,
                cutlass_cppgen.DataType.bf16: torch.bfloat16,
                cutlass_cppgen.DataType.f32: torch.float32,
                cutlass_cppgen.DataType.f64: torch.float64,
                cutlass_cppgen.DataType.s8: torch.int8,
                cutlass_cppgen.DataType.s32: torch.int32,
                cutlass_cppgen.DataType.u8: torch.uint8,
            }

            def possibly_add_type(torch_type_name, cutlass_type):
                # Only try adding the type if the version of torch being used supports it
                if hasattr(torch, torch_type_name):
                    torch_type = getattr(torch, torch_type_name)
                    _torch_to_library_dict[torch_type] = cutlass_type
                    _library_to_torch_dict[cutlass_type] = torch_type

            possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
            possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
        except ImportError:
            torch_available = False
            _torch_to_library_dict = {}
            _library_to_torch_dict = {}
    return torch_available
def is_torch_tensor(inp) -> bool:
    """Returns whether ``inp`` is a torch Tensor."""
    if not is_torch_available():
        return False
    import torch
    return isinstance(inp, torch.Tensor)
def torch_library_type(inp) -> cutlass_cppgen.DataType:
    """
    Converts a torch data type into the corresponding CUTLASS library data
    type. Returns None if torch is unavailable or no conversion exists.
    """
    # Ensure the lazily-initialized map has been populated before lookup.
    # This mirrors numpy_library_type/cupy_library_type, and avoids calling
    # .get on the uninitialized (None) map if torch helpers were not used yet.
    if not is_torch_available():
        return None
    return _torch_to_library_dict.get(inp, None)
def torch_type(inp):
    """
    Returns the torch data type corresponding to CUTLASS library data type
    ``inp``, or None if torch is unavailable or no conversion exists.
    """
    # Ensure the lazily-initialized map has been populated before lookup;
    # previously a call before is_torch_available() crashed on None.get.
    if not is_torch_available():
        return None
    return _library_to_torch_dict.get(inp, None)
def is_bfloat16_available():
    """Returns whether the ``bfloat16`` package can be imported."""
    global bfloat16_available
    if bfloat16_available is not None:
        # Result is cached after the first call
        return bfloat16_available

    try:
        import bfloat16  # noqa: F401  (import check only)
        bfloat16_available = True
    except ImportError:
        bfloat16_available = False
    return bfloat16_available
def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
    """
    Converts a ``bfloat16``-package type into the CUTLASS library bf16 data
    type. Returns None if the package is unavailable or ``inp`` is not bf16.
    """
    if not is_bfloat16_available():
        return None

    import bfloat16
    if inp == bfloat16.bfloat16:
        return cutlass_cppgen.DataType.bf16
    return None
def bfloat16_type(inp):
    """
    Returns the ``bfloat16`` package's bf16 type when ``inp`` is the CUTLASS
    library bf16 type; None otherwise (or when the package is unavailable).
    """
    if not is_bfloat16_available():
        return None

    import bfloat16
    if inp == cutlass_cppgen.DataType.bf16:
        return bfloat16.bfloat16
    return None
def library_type(inp):
    """
    Converts ``inp`` into a CUTLASS library data type. ``inp`` may already be
    a library type, or a bfloat16/cupy/numpy/torch frontend type.

    :raises Exception: if no conversion to a library type exists
    """
    # Already a library data type: pass through unchanged
    if inp in DataTypeSize:
        return inp

    converters = (
        bfloat16_library_type,
        cupy_library_type,
        numpy_library_type,
        torch_library_type,
    )
    for convert in converters:
        library_dtype = convert(inp)
        if library_dtype is not None:
            return library_dtype

    raise Exception(f"No available conversion from type {inp} to a library type.")
def _tensor_from_numpy(np_tensor):
    """
    Returns a (CUTLASS data type, layout) pair describing a numpy/cupy tensor.

    :raises Exception: if the tensor is neither C- nor Fortran-contiguous
    """
    dtype = library_type(np_tensor.dtype)
    if np_tensor.flags.c_contiguous:
        layout = cutlass_cppgen.LayoutType.RowMajor
    elif np_tensor.flags.f_contiguous:
        layout = cutlass_cppgen.LayoutType.ColumnMajor
    else:
        # Previously, a tensor that was neither C- nor F-contiguous fell
        # through with `layout` unbound and raised an opaque
        # UnboundLocalError. Raise an explicit, descriptive error instead.
        raise Exception("Unable to determine layout of a tensor that is neither C- nor F-contiguous.")
    return (dtype, layout)
def _tensor_from_torch(pt_tensor):
    """
    Returns a (CUTLASS data type, layout) pair describing a torch tensor.

    NOTE(review): torch tensors are always reported as RowMajor here —
    presumably because torch tensors are row-major by default; confirm that
    callers never pass tensors with a transposed/strided layout.
    """
    library_dtype = library_type(pt_tensor.dtype)
    return (library_dtype, cutlass_cppgen.LayoutType.RowMajor)
def get_datatype_and_layout(tensor):
    """
    Returns a (CUTLASS data type, layout) pair describing ``tensor``, which
    may be a numpy/cupy/torch tensor or a Python scalar (treated as an f32
    row-major value).

    :raises Exception: if the tensor type is not supported
    """
    if is_numpy_tensor(tensor) or is_cupy_tensor(tensor):
        # cupy arrays expose the same flags/dtype interface as numpy arrays
        return _tensor_from_numpy(tensor)
    if is_torch_tensor(tensor):
        return _tensor_from_torch(tensor)
    if isinstance(tensor, (float, int)):
        return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
def get_tensor_shape(tensor, op="GEMM"):
    """
    Returns the shape of ``tensor`` as a tuple. For torch tensors with
    op="CONV", the NCHW torch shape is reported in NHWC order. Python scalars
    have shape (1,).

    :raises Exception: if the tensor type is not supported
    """
    if is_numpy_tensor(tensor) or is_cupy_tensor(tensor):
        return tensor.shape
    if is_torch_tensor(tensor):
        size = tensor.size()
        if op == "CONV":
            # PyTorch Tensors have shape NCHW; reorder to NHWC.
            # assumes a 4-D tensor for the CONV case — TODO confirm
            return (size[0], size[2], size[3], size[1])
        return tuple(size)
    if isinstance(tensor, (float, int)):
        return (1,)
    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
# Lookup table from MathOperation enum values to their enum members; used to
# translate math operations into the backend's representation.
_math_operation_value_map = {x.value: x for x in MathOperation}


def backend_math_operation(math_op: MathOperation):
    """
    Converts ``math_op`` into the equivalent backend math operation.

    :raises Exception: if no equivalent backend operation exists
    """
    if math_op.value not in _math_operation_value_map:
        raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
    return _math_operation_value_map[math_op.value]
def construct_backend_td(td: cutlass_cppgen.TileDescription,
                         kernel_schedule: cutlass_cppgen.KernelScheduleType,
                         epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
                         tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
    """
    Constructs a backend TileDescription from the profiler TileDescription
    ``td`` and the given kernel/epilogue schedules and tile scheduler.
    """
    profiler_mi = td.math_instruction
    backend_mi = MathInstruction(
        profiler_mi.instruction_shape,
        profiler_mi.element_a,
        profiler_mi.element_b,
        profiler_mi.element_accumulator,
        profiler_mi.opcode_class,
        backend_math_operation(profiler_mi.math_operation)
    )
    # Tile descriptions without an explicit cluster shape default to 1x1x1
    cluster_shape = getattr(td, "cluster_shape", [1, 1, 1])
    return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
                           backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
def td_from_profiler_op(op) -> TileDescription:
    """
    Converts the profiler's TileDescription in ``op`` into the backend TileDescription

    :param op: profiler Operation

    :returns: backend TileDescription
    :rtype: cutlass_cppgen.backend.TileDescription
    """
    # Schedules are optional on profiler operations; default to None when absent
    kschedule = getattr(op, 'kernel_schedule', None)
    eschedule = getattr(op, 'epilogue_schedule', None)
    tschedule = getattr(op, 'tile_scheduler', None)
    return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
def td_from_profiler_td(td: TileDescription) -> TileDescription:
    """
    Converts the profiler's TileDescription into the backend TileDescription

    :param td: profiler TileDescription
    :type td: cutlass_cppgen.TileDescription

    :returns: backend TileDescription
    :rtype: cutlass_cppgen.backend.TileDescription
    """
    # A bare tile description carries no schedules or tile scheduler
    return construct_backend_td(
        td,
        kernel_schedule=None,
        epilogue_schedule=None,
        tile_scheduler=None,
    )
def to_camel_case(snake_str):
    """Converts a snake_case string to CamelCase (e.g. "tile_scheduler" -> "TileScheduler")."""
    words = snake_str.lower().split("_")
    return "".join(word.capitalize() for word in words)
def getattr_enum(obj, attr_name):
    """
    Fetches attribute ``attr_name`` (given in snake_case) from ``obj``, whose
    attributes are named in CamelCase.

    :raises Exception: if ``obj`` has no matching attribute
    """
    # Translate the snake_case option name to the CamelCase attribute name
    camel_attr = to_camel_case(attr_name)
    if not hasattr(obj, camel_attr):
        raise Exception(f"Invalid option: {attr_name}")
    return getattr(obj, camel_attr)

View File

@ -0,0 +1,41 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import importlib
from typing import Any
def lazy_import(mod_name: str) -> Any:
    """
    Returns a proxy object that defers importing ``mod_name`` until one of its
    attributes is first accessed.
    """
    class _LazyModule:
        def __getattr__(self, attr: str) -> Any:
            # Import on first attribute access; importlib caches the module
            # in sys.modules, so repeated accesses are cheap.
            target = importlib.import_module(mod_name)
            return getattr(target, attr)

    return _LazyModule()

View File

@ -0,0 +1,196 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Profiler based on the cuda events
"""
import re
import subprocess
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
from cutlass_cppgen import CUTLASS_PATH
from cutlass_cppgen.backend.library import DataTypeSize
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
class GpuTimer:
    """
    Timer for GPU work, implemented with a pair of CUDA events recorded at the
    start and end of the timed region.
    """
    def __init__(self) -> None:
        # events[0] marks the start of the timed region, events[1] the end.
        # cuEventCreate returns (err, event); index [1] extracts the event.
        self.events = [
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
        ]

    def start(self, stream=None):
        """Records the start event on ``stream`` (default stream if None)."""
        if not stream:
            stream = cuda.CUstream(0)

        (err,) = cuda.cuEventRecord(self.events[0], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

    def stop(self, stream=None):
        """Records the stop event on ``stream`` (default stream if None)."""
        if not stream:
            stream = cuda.CUstream(0)

        (err,) = cuda.cuEventRecord(self.events[1], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")
        # (a stray dead `pass` statement that followed here has been removed)

    def stop_and_wait(self, stream=None):
        """Records the stop event and blocks until it has completed."""
        if not stream:
            stream = cuda.CUstream(0)

        self.stop(stream)
        # NOTE(review): after the default assignment above, `stream` is always
        # set; the cudaDeviceSynchronize fallback is only reachable if
        # CUstream(0) is falsy — confirm the intended behavior.
        if stream:
            (err,) = cuda.cuStreamSynchronize(stream)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error {str(err)}")
        else:
            (err,) = cudart.cudaDeviceSynchronize()
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error {str(err)}")

    def duration(self, iterations=1):
        """
        Returns the average elapsed time per iteration, in milliseconds
        (cuEventElapsedTime reports milliseconds).
        """
        err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")
        return duration / float(iterations)
class CUDAEventProfiler:
    """
    Profiles a CUTLASS Python-interface operation with CUDA events, and can
    additionally run the standalone CUTLASS profiler binary on the equivalent
    kernel and cross-check the reported problem size.
    """
    def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None:
        # op.run(...) launches the operation once and returns the arguments
        # object, which is then replayed for warmup and timed iterations
        self.arguments = op.run(*args, **kwargs)
        self.operation = op.operation
        self.warmup_iterations = warmup_iterations
        self.iterations = iterations
        self.timer = GpuTimer()
    #
    # Cutlass Python Interface Profiler
    #
    def __call__(self):
        """Returns the average runtime over ``self.iterations`` timed runs."""
        for _ in range(self.warmup_iterations):
            self.operation.run(self.arguments)
        self.timer.start()
        for _ in range(self.iterations):
            self.operation.run(self.arguments)
        self.timer.stop_and_wait()
        runtime = self.timer.duration(self.iterations)
        return runtime
    #
    # CUTLASS Profiler
    #
    def run_cutlass_profiler(self):
        """
        Runs the standalone cutlass_profiler binary on the same kernel and
        returns its reported runtime, asserting that the profiler's byte and
        FLOP counts match this class's own estimates.
        """
        alpha = 1.0
        beta = 1.0
        profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
        kernel_name = self.operation.procedural_name()
        verification_providers = "device"
        provider = "cutlass"
        problem_size = self.arguments.problem_size
        if "cutlass3x" in kernel_name:
            # cutlass3x generator only have column-major output
            layout_name = self.operation.layout_name_3x()
            if layout_name[-1] == "t":
                # NOTE(review): `l == "t" or "t"` is always truthy, so this
                # replaces EVERY character of layout_name with "n". If the
                # intent was to map only "t" -> "n" (keeping other characters),
                # this should be: "".join("n" if l == "t" else l for l in layout_name)
                # — confirm against the profiler's kernel naming scheme.
                new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"])
                # Swap m and n: profiling the transposed problem instead
                problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
                kernel_name = kernel_name.replace(layout_name, new_layout_name)
        batch_count = self.arguments.batch_count
        cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
            f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
            f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\
            f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"
        result = subprocess.getoutput(cmd)
        # Parse the profiler's text output for runtime, bytes, and FLOPs
        m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
        runtime = float(m.group("runtime"))
        m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
        bytes = int(m.group("bytes"))
        m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
        flops = int(m.group("flops"))
        # check if the problem size matches
        assert bytes == self.bytes(problem_size, batch_count, beta)
        assert flops == self.flops(problem_size, batch_count, beta)
        return runtime
    def bytes(self, problem_size, batch_count=1, beta=0.0):
        """
        Returns the number of bytes read/written by a GEMM of the given
        problem size: A (m x k), B (n x k), and C (m x n) output, plus an
        extra read of C when beta is nonzero. Scaled by batch_count.
        """
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()
        bytes = (
            (DataTypeSize[self.operation.A.element] * m // 8) * k
            + (DataTypeSize[self.operation.B.element] * n // 8) * k
            + (DataTypeSize[self.operation.C.element] * m // 8) * n
        )
        if beta != 0:
            # Nonzero beta requires reading the source C operand as well
            bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
        bytes *= batch_count
        return bytes
    def flops(self, problem_size, batch_count=1, beta=0.0):
        """
        Returns the FLOP count of a GEMM of the given problem size:
        2*m*n*k multiply-accumulates, plus 2*m*n for the beta*C update when
        beta is nonzero. Scaled by batch_count.
        """
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()
        flops_ = (m * n * k) * 2 * batch_count
        if beta != 0:
            flops_ += m * n * batch_count * 2
        return flops_