Rename python/cutlass to python/cutlass_cppgen (#2652)

Jack Kosaian
2025-09-18 13:26:57 -05:00
committed by GitHub
parent 74825181f2
commit b234a8c024
71 changed files with 1 additions and 1 deletions


@@ -0,0 +1,48 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.arguments import *
from cutlass_cppgen.backend.c_types import *
from cutlass_cppgen.backend.compiler import ArtifactManager
from cutlass_cppgen.backend.conv2d_operation import *
from cutlass_cppgen.backend.epilogue import *
from cutlass_cppgen.backend.frontend import *
from cutlass_cppgen.backend.gemm_operation import *
from cutlass_cppgen.backend.library import *
from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool
from cutlass_cppgen.backend.operation import *
from cutlass_cppgen.backend.reduction_operation import *
from cutlass_cppgen.backend.type_hint import *
from cutlass_cppgen.backend.utils import *
from cutlass_cppgen.backend.utils.device import device_cc
compiler = ArtifactManager()
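The star imports above re-export the backend submodules under the renamed cutlass_cppgen package, and a single module-level ArtifactManager is constructed. A minimal sketch of how downstream code might reuse that shared instance (illustrative only; it assumes the file shown here is the backend package's __init__.py and that the package is installed):

    from cutlass_cppgen.backend import compiler               # the shared ArtifactManager created above
    from cutlass_cppgen.backend.utils.device import device_cc

    print(device_cc())   # compute capability of the active device, e.g. 90
    compiler.nvcc()      # keep the default nvcc backend for subsequent add_module() calls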


@@ -0,0 +1,136 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from math import prod
from typing import Union
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
import cutlass_cppgen
from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
class ArgumentBase:
"""
Base class for operation arguments
"""
def __init__(
self,
A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
**kwargs,
) -> None:
# tensor_C can be interpreted as the bias with bias=True in keyword args
self.bias = kwargs.get("bias", False)
self.stream = kwargs.get("stream", cuda.CUstream(0))
# RMM buffers used to track tensor lifetime
self.buffers = {}
# Host tensor to copy the computed result back
self.host_tensors = {}
self.ptr_A = self.tensor_to_ptr(A, "A")
self.ptr_B = self.tensor_to_ptr(B, "B")
self.ptr_C = self.tensor_to_ptr(C, "C")
self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True)
if C is not None:
if not isinstance(C, cuda.CUdeviceptr):
self.tensor_c_numel = prod(C.shape)
def tensor_to_ptr(self, tensor, name, is_output=False):
"""
        Convert the input tensor to a cuda.CUdeviceptr usable by CUDA Python and remember
        the backing device buffer. For numpy.ndarray outputs, the host tensor is also
        retained so the result can be copied back during sync()
"""
if tensor is None:
return cuda.CUdeviceptr(0)
if is_numpy_tensor(tensor):
if is_output:
assert name
self.buffers[name] = NumpyFrontend.argument(tensor, is_output)
if is_output:
self.host_tensors[name] = tensor
return self.buffers[name].ptr
elif is_torch_tensor(tensor):
return TorchFrontend.argument(tensor)
elif isinstance(tensor, cuda.CUdeviceptr):
return tensor
elif is_cupy_tensor(tensor):
return CupyFrontend.argument(tensor)
else:
raise TypeError("Unsupported Frontend. Only support numpy and torch")
def sync(self, stream_sync=True):
if stream_sync:
(err,) = cudart.cudaDeviceSynchronize()
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
for key in self.host_tensors.keys():
host_tensor = self.host_tensors[key]
(err,) = cuda.cuMemcpyDtoH(
host_tensor,
self.buffers[key].ptr,
host_tensor.size * host_tensor.itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
self.free()
def free(self):
"""
Frees allocated device-side memory
"""
# Free any device memory allocated manually
if not cutlass_cppgen.use_rmm:
for name, buf in self.buffers.items():
if isinstance(buf, DevicePtrWrapper):
err, = cudart.cudaFree(buf.ptr)
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(f"cudaFree failed with error {err}")
if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper):
err, = cudart.cudaFree(self.workspace_buffer.ptr)
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(f"cudaFree failed with error {err}")
del self.workspace_buffer
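ArgumentBase dispatches on the tensor frontend (numpy, torch, cupy, or a raw cuda.CUdeviceptr), tracks device buffers, and copies host-resident outputs back on sync(). A rough usage sketch, assuming a concrete subclass such as Conv2dArguments from later in this commit; operation, problem_size, and the operand arrays A, B, C are placeholders:

    import numpy as np

    D = np.zeros_like(C)                                              # host output; filled in by sync()
    arguments = Conv2dArguments(operation, problem_size, A, B, C, D)
    operation.run(arguments)                                          # launch the kernel
    arguments.sync()                                                  # copy device results into D, then free buffers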


@@ -0,0 +1,625 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from cutlass_library import (
DataType,
KernelScheduleType,
TileSchedulerType
)
from cutlass_cppgen.backend.library import DataTypeSizeBytes
class GemmCoord_(ctypes.Structure):
_fields_ = [
("m", ctypes.c_int),
("n", ctypes.c_int),
("k", ctypes.c_int)
]
def __init__(self, m, n, k) -> None:
self.m = m
self.n = n
self.k = k
class GemmCoordBatched_(ctypes.Structure):
"""
Wrapper around a GemmCoord that also contains batch count. This is used for encoding
batched GEMM inputs to CUTLASS 3 GEMMs.
"""
_fields_ = [
("m", ctypes.c_int),
("n", ctypes.c_int),
("k", ctypes.c_int),
("batch_count", ctypes.c_int)
]
def __init__(self, gemm_coord, batch_count) -> None:
self.m = gemm_coord.m
self.n = gemm_coord.n
self.k = gemm_coord.k
self.batch_count = batch_count
class MatrixCoord_(ctypes.Structure):
_fields_ = [
("row", ctypes.c_int),
("column", ctypes.c_int)
]
class dim3_(ctypes.Structure):
_fields_ = [
("x", ctypes.c_int),
("y", ctypes.c_int),
("z", ctypes.c_int)
]
class StrideBatched_(ctypes.Structure):
"""
    CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The
    variable dimensions represent the stride along the non-unit-stride dimension of the row/column-major
    layout and the batch stride. This structure encodes the two variable dimensions.
"""
_fields_ = [
("major_stride", ctypes.c_int64),
("batch_stride", ctypes.c_int64)
]
class GenericMainloopArguments3x_(ctypes.Structure):
"""
Structure representing the superset of possible mainloop arguments.
This structure should not be passed to kernels directly, but, rather,
be used as an input to one of the more specific schedule arguments, which
will each select those arguments relevant to the particular schedule.
"""
_fields_ = [
("ptr_A", ctypes.c_void_p),
("stride_A", StrideBatched_),
("ptr_B", ctypes.c_void_p),
("stride_B", StrideBatched_),
("mma_promotion_interval", ctypes.c_int)
]
class _PersistentTileSchedulerArguments(ctypes.Structure):
_fields_ = [
("max_swizzle_size", ctypes.c_int),
("raster_order_option", ctypes.c_int),
]
class _PersistentTileSchedulerStreamKArguments(ctypes.Structure):
_fields_ = [
("splits", ctypes.c_int),
("max_swizzle_size", ctypes.c_int),
("raster_order_option", ctypes.c_int),
("reduction_mode", ctypes.c_int),
("decomposition_mode", ctypes.c_int),
]
def get_tile_scheduler_arguments_3x(
tile_scheduler: TileSchedulerType,
splits: int = 1):
max_swizzle_size = 1
raster_order_option = 0 # Heuristic
if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]:
return _PersistentTileSchedulerArguments(
max_swizzle_size,
raster_order_option,
)
elif tile_scheduler == TileSchedulerType.StreamK:
reduction_mode = 0 # Deterministic
decomposition_mode = 0 # Heuristic
return _PersistentTileSchedulerStreamKArguments(
splits,
max_swizzle_size,
raster_order_option,
reduction_mode,
decomposition_mode,
)
def get_mainloop_arguments_3x(
kernel_schedule: KernelScheduleType,
element_A,
element_B,
alignment_A: int,
alignment_B: int) -> ctypes.Structure:
"""
Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters.
:param kernel_schedule: type of kernel schedule to be used in the mainloop
:type kernel_schedule: cutlass_library.KernelScheduleType
:param element_A: data type of operand A
:param element_B: data type of operand B
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:returns: ctypes structure to be used for the 3.x kernel's mainloop parameters
:rtype: ctypes.Structure
"""
class _MainloopArgumentsTma(ctypes.Structure):
_fields_ = [
("ptr_A", ctypes.c_void_p),
("stride_A", StrideBatched_),
("ptr_B", ctypes.c_void_p),
("stride_B", StrideBatched_),
("mma_promotion_interval", ctypes.c_int)
]
@staticmethod
def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
return _MainloopArgumentsTma(
args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
args.mma_promotion_interval
)
class _MainloopArgumentsMultistage(ctypes.Structure):
_fields_ = [
("ptr_A", ctypes.c_void_p),
("stride_A", StrideBatched_),
("ptr_B", ctypes.c_void_p),
("stride_B", StrideBatched_),
]
@staticmethod
def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
return _MainloopArgumentsMultistage(
args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
)
    # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure.
    # Should that change, this is the place to return custom ctypes
    # structures based on the selected kernel schedule.
return _MainloopArgumentsTma
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
else:
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
if hasattr(epilogue_functor, "visitor"):
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
("arg_C", epilogue_functor.arg_c_type),
("arg_D", epilogue_functor.arg_d_type)
]
def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None:
self.epilogue = output_op
self.arg_C = epilogue_functor.arg_c_type(ptr_c)
self.arg_D = epilogue_functor.arg_d_type(ptr_d)
else:
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
("ptr_C", ctypes.c_void_p),
("stride_C", StrideBatched_),
("ptr_D", ctypes.c_void_p),
("stride_D", StrideBatched_),
]
class _HardwareInfo(ctypes.Structure):
_fields_ = [
("device_id", ctypes.c_int),
("sm_count", ctypes.c_int),
("max_active_clusters", ctypes.c_int),
("cluster_shape", dim3_),
("cluster_shape_fallback", dim3_),
]
class _GemmArguments(ctypes.Structure):
_fields_ = [
("mode", ctypes.c_int),
("problem_size", GemmCoordBatched_),
("mainloop", mainloop_arguments),
("epilogue", _EpilogueArguments),
("hw_info", _HardwareInfo),
("scheduler", type(scheduler_args)),
]
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo
def get_gemm_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GemmArguments(ctypes.Structure):
_fields_ = [
# Arguments from UniversalArgumentsBase
("mode", ctypes.c_int),
("problem_size", GemmCoord_),
("batch_count", ctypes.c_int),
("batch_stride_D", ctypes.c_longlong),
# Remaining arguments
("epilogue", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("batch_stride_A", ctypes.c_longlong),
("batch_stride_B", ctypes.c_longlong),
("batch_stride_C", ctypes.c_longlong),
("stride_a", ctypes.c_longlong),
("stride_b", ctypes.c_longlong),
("stride_c", ctypes.c_longlong),
("stride_d", ctypes.c_longlong),
("lda", ctypes.c_longlong),
("ldb", ctypes.c_longlong),
("ldc", ctypes.c_longlong),
("ldd", ctypes.c_longlong),
("ptr_gather_A_indices", ctypes.c_void_p),
("ptr_gather_B_indices", ctypes.c_void_p),
("ptr_scatter_D_indices", ctypes.c_void_p)
]
return _GemmArguments, _EpilogueOutputOpParams
def get_gemm_arguments_streamk(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GemmArguments(ctypes.Structure):
_fields_ = [
("mode", ctypes.c_int),
("problem_size", GemmCoord_),
("batch_count", ctypes.c_int),
("epilogue", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("batch_stride_A", ctypes.c_longlong),
("batch_stride_B", ctypes.c_longlong),
("batch_stride_C", ctypes.c_longlong),
("batch_stride_D", ctypes.c_longlong),
("stride_a", ctypes.c_longlong),
("stride_b", ctypes.c_longlong),
("stride_c", ctypes.c_longlong),
("stride_d", ctypes.c_longlong),
("lda", ctypes.c_longlong),
("ldb", ctypes.c_longlong),
("ldc", ctypes.c_longlong),
("ldd", ctypes.c_longlong),
("avail_sms", ctypes.c_int)
]
return _GemmArguments, _EpilogueOutputOpParams
###########################################################################################
# GEMM Grouped
###########################################################################################
def get_gemm_grouped_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GEMMGroupedArguments(ctypes.Structure):
_fields_ = [
("problem_sizes", ctypes.c_void_p),
("problem_count", ctypes.c_int),
("threadblock_count", ctypes.c_int),
("output_op", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("lda", ctypes.c_void_p),
("ldb", ctypes.c_void_p),
("ldc", ctypes.c_void_p),
("ldd", ctypes.c_void_p),
("host_problem_sizes", ctypes.c_void_p)
]
return _GEMMGroupedArguments, _EpilogueOutputOpParams
############################################################################################
# Convolution2D
############################################################################################
class Conv2DProblemSize_(ctypes.Structure):
_fields_ = [
("N", ctypes.c_int),
("H", ctypes.c_int),
("W", ctypes.c_int),
("C", ctypes.c_int),
("P", ctypes.c_int),
("Q", ctypes.c_int),
("K", ctypes.c_int),
("R", ctypes.c_int),
("S", ctypes.c_int),
("pad_h", ctypes.c_int),
("pad_w", ctypes.c_int),
("stride_h", ctypes.c_int),
("stride_w", ctypes.c_int),
("dilation_h", ctypes.c_int),
("dilation_w", ctypes.c_int),
("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1
("split_k_slices", ctypes.c_int),
("groups", ctypes.c_int)
]
def __init__(self, problem_size) -> None:
for field_name, _ in self._fields_:
setattr(self, field_name, getattr(problem_size, field_name))
class Layout4D(ctypes.Structure):
_fields_ = [("stride", ctypes.c_int * 3)]
def __init__(self, tensor_ref):
stride = tensor_ref.stride()
setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2)))
class TensorRef_(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("layout", Layout4D)
]
def __init__(self, tensor_ref):
setattr(self, "ptr", tensor_ref.data())
setattr(self, "layout", Layout4D(tensor_ref.layout()))
class TensorRef2D_(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("stride", ctypes.c_int)
]
def get_conv2d_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _Conv2dArguments(ctypes.Structure):
_fields_ = [
("conv_kind", ctypes.c_int),
("problem_size", Conv2DProblemSize_),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("tensor_C_numel", ctypes.c_int),
("output_op", _EpilogueOutputOpParams),
("split_k_mode", ctypes.c_int)
]
return _Conv2dArguments, _EpilogueOutputOpParams
############################################################################################
# Reduction
############################################################################################
def get_reduction_params(epilogue_functor):
_EpilogueOutputParams = epilogue_functor.epilogue_type
class _ReductionParams(ctypes.Structure):
_fields_ = [
("problem_size", MatrixCoord_),
("partitions", ctypes.c_int),
("partition_stride", ctypes.c_longlong),
("workspace", TensorRef2D_),
("destination", TensorRef2D_),
("source", TensorRef2D_),
("output_op", _EpilogueOutputParams),
]
return _ReductionParams, _EpilogueOutputParams
###########################################################################################
# Epilogue Visitor Type Factory
###########################################################################################
class Empty(ctypes.Structure):
_fields_ = []
def __init__(self, *arg) -> None:
pass
class EmptyByte(ctypes.Structure):
_fields_ = [
("byte", ctypes.c_byte)
]
def __init__(self, *arg) -> None:
pass
class EBO:
def __init__(self, index: int, type) -> None:
self.index = index
self.type = type
def __eq__(self, other) -> bool:
if isinstance(other, EBO):
return self.index == other.index and self.type == other.type
return False
def __hash__(self) -> int:
return hash((self.index, self.type))
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self) -> str:
return f"<{self.index}, {self.type}>"
def tuple_factory_(input_tuple, dtype, constants=[0,1]):
"""
The factory function generating cute::Tuple with input tuple
:param input_tuple: the input tuple
:type input_tuple: tuple
:param dtype: the data type for non-constant values
:type dtype: str, "int32_t", "int", "int64_t"
:param constant: the values that will be treated as constants
:type constant: list[int]
:return: ctype structure representing the cute::Tuple
:return: the empty base classes of the tuple
"""
# The empty base classes of the current tuple
empty_bases = []
# The first non empty base class
first_non_empty_base = None
# The ctype fields of the current tuple
ctype_fields = []
for idx, entry in enumerate(input_tuple):
# For nested tuples
if isinstance(entry, tuple):
sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
if ctypes.sizeof(sub_tuple_ctype) == 0:
# The empty tuple base class is also an empty EBO
empty_bases.append(EBO(idx, entry))
else:
if first_non_empty_base is None:
first_non_empty_base = sub_empty_bases
ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
else:
if entry in constants:
empty_bases.append(EBO(idx, entry))
ctype_fields.append((f"entry_{idx}", Empty))
else:
ctype_fields.append((f"entry_{idx}", dtype))
if first_non_empty_base is None:
first_non_empty_base = []
# Create the ctype tuple
class TupleType(ctypes.Structure):
_fields_ = ctype_fields
def __init__(self, args) -> None:
fields = self._fields_
assert len(fields) == len(args)
for field, arg in zip(fields, args):
name = field[0]
field_type = field[1]
setattr(self, name, field_type(arg))
return TupleType, empty_bases
def tuple_factory(input_tuple, dtype: str, constants=[0,1]):
"""
The factory function generating cute::Tuple with input tuple
:param input_tuple: the input tuple
:type input_tuple: tuple
:param dtype: the data type for non-constant values
:type dtype: str, "int32_t", "int", "int64_t"
:param constant: the values that will be treated as constants
:type constant: list[int]
:return: ctype structure representing the cute::Tuple
:return: the empty base classes of the tuple
"""
# Step 1: convert the dtype
if dtype == "int64_t":
dtype = ctypes.c_longlong
elif dtype in ["int", "int32_t"]:
dtype = ctypes.c_int32
else:
raise NotImplementedError(f"Type {dtype} is not supported")
tuple_type, _ = tuple_factory_(input_tuple, dtype, constants)
if ctypes.sizeof(tuple_type) == 0:
return EmptyByte
return tuple_type
def visitor_factory(node_types, node_names):
"""
Creates the argument type of epilogue visitor type
:param node_types: list of argument types under ctypes
:param node_names: list of argument names under str
:return: tuple type in ctypes.Structure
"""
ctypes_field = []
    # A struct is used when the number of nodes is <= 4,
    # because Sm90VisitorImplBase has specializations for up to 4 nodes
    # in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
if len(node_types) <= 4:
for idx, node_type in enumerate(node_types):
if ctypes.sizeof(node_type) == 0:
# Special case for empty struct
# 1 byte placeholder is used for correct alignment
ctypes_field.append((node_names[idx], ctypes.c_byte))
else:
ctypes_field.append((node_names[idx], node_type))
class VisitorType(ctypes.Structure):
_fields_ = ctypes_field
def __init__(self, kwargs) -> None:
for field in self._fields_:
fname, ftype = field
if ftype != ctypes.c_byte:
setattr(self, fname, ftype(kwargs))
# For cases with more than 4 nodes, tuple is used
else:
for idx, node_type in enumerate(node_types):
ctypes_field.append((node_names[idx], node_type))
class VisitorType(ctypes.Structure):
_fields_ = ctypes_field
def __init__(self, kwargs) -> None:
for field in self._fields_:
fname, ftype = field
setattr(self, fname, ftype(kwargs))
return VisitorType
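As a sanity check on the tuple factory above, a short sketch (the import path follows this commit's layout) showing how constant entries collapse into zero-size Empty fields so that only the dynamic strides occupy space:

    import ctypes
    from cutlass_cppgen.backend.c_types import tuple_factory

    # (leading stride, unit stride, batch stride): the constant 1 becomes an Empty field
    StrideTuple = tuple_factory((4096, 1, 1048576), "int64_t", constants=[0, 1])
    values = StrideTuple((4096, 1, 1048576))
    print(ctypes.sizeof(StrideTuple))   # 16 bytes: two c_longlong entries, zero bytes for the constant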


@@ -0,0 +1,462 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
import json
import os
import sqlite3
import subprocess
import tempfile
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
nvrtc = lazy_import("cuda.nvrtc")
from cutlass_library import SubstituteTemplate
import cutlass_cppgen
from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.backend.utils.device import device_cc
IncludeTemplate = r"""#include "${include}"
"""
def compile_with_nvcc(cmd, source, error_file):
succeed = True
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
error_message = e.output.decode()
with open(error_file, "w") as error_out:
error_log = "Compilation error for the following kernel: \n"
error_log += source
error_log += "\nError Message:\n"
error_log += error_message
error_out.write(error_log)
succeed = False
if not succeed:
# Print the error log to stdout if log level is set to warning or higher
# verbosity. Otherwise, simply point to the error log file.
logger.warning(error_log)
raise Exception(f"Invalid Kernel. See '{error_file}' for details.")
class CompilationOptions:
"""
Compilation options.
"""
def __init__(self, flags, arch, include_paths=[]):
self.includes = []
self.include_paths = include_paths
self.flags = flags
self.arch = arch
def get_str(self):
opts = []
for flag in self.flags:
opts.append(flag)
for incl in self.include_paths:
opts.append(f"--include-path={incl}")
arch_flag = f"-arch=sm_{self.arch}"
if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12:
arch_flag += "a"
opts.append(arch_flag)
return " ".join(opts)
def get(self):
options = []
for flag in self.flags:
options.append(bytes(str.encode(flag)))
for incl in self.include_paths:
options.append(bytes(str.encode(f" --include-path={incl}")))
arch_flag = f" -arch=sm_{self.arch}"
if self.arch in [90, 100, 101, 103, 120, 121]:
arch_flag += "a"
options.append(bytes(str.encode(arch_flag)))
return options
def convertToBinaryData(filename):
with open(filename, "rb") as file:
blobData = file.read()
return blobData
def CDLLBin(host_binary):
tempfile.tempdir = "./"
temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True)
with open(temp_so.name, "wb") as file:
file.write(host_binary)
host_lib = ctypes.CDLL(temp_so.name)
return host_lib
class ArtifactManager:
"""
    Artifact manager that compiles kernels and caches the resulting cubins and host binaries
"""
def __init__(self) -> None:
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
# Create the table if it does not already exist
sqlite_create_table_query = """
CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE,
cubin BLOB NOT NULL,
hostbin BLOB NOT NULL,
op_name TEXT NOT NULL,
op_attrs TEXT NOT NULL)
"""
cursor.execute(sqlite_create_table_query)
connection.commit()
cursor.close()
self._nvrtc_compile_options = ["-std=c++17", "-default-device"]
self._nvcc_compile_options = [
"-std=c++17",
"--expt-relaxed-constexpr",
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
self.nvcc()
self.compiled_cache_device = {}
self.compiled_cache_host = {}
def nvrtc(self):
self.backend = "nvrtc"
self.default_compile_options = self._nvrtc_compile_options
def nvcc(self):
self.backend = "nvcc"
self.default_compile_options = self._nvcc_compile_options
def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""
hostbin = convertToBinaryData(hostfile)
data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))
cursor.execute(sqlite_insert_blob_query, data_tuple)
connection.commit()
cursor.close()
def load_operation(self, op_key, extra_funcs):
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
cursor.execute(sqlite_fetch_blob_query, (op_key,))
record = cursor.fetchall()
if len(record) == 0:
return False
for row in record:
key, cubin_image, host_binary, operation_name, op_attr = row
op_attr = json.loads(op_attr)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
self.compiled_cache_device[key] = kernel
compiled_host_fns = {}
host_lib = CDLLBin(host_binary)
func_name = operation_name + "_get_params"
func = getattr(host_lib, func_name)
func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
compiled_host_fns["get_args"] = func
func_name = operation_name + "_shared_memory_size"
func = getattr(host_lib, func_name)
compiled_host_fns["shared_memory_capacity"] = func()
for attr in op_attr:
if isinstance(attr, str):
func_name = operation_name + "_" + attr
func = getattr(host_lib, func_name)
# Set the return type of the function
if attr in extra_funcs and extra_funcs[attr] != None:
func.restype = extra_funcs[attr]
compiled_host_fns[attr] = func
self.compiled_cache_host[key] = compiled_host_fns
return True
def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
"""
Compile a list of kernels and store them into database
"""
source_buffer_device = ""
source_buffer_host = ""
# 1. include
includes = []
for operation in operation_list:
for incl in operation.emitter.includes:
if incl not in includes:
includes.append(incl)
includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes
for incl in includes:
source_buffer_device += SubstituteTemplate(
IncludeTemplate,
{"include": incl},
)
for incl in includes_host:
source_buffer_host += SubstituteTemplate(
IncludeTemplate,
{"include": incl},
)
# 2. Operations
for operation in operation_list:
source_buffer_device += operation.emit()
source_buffer_host += operation.emit()
values = {
"operation_name": operation.name(),
"operation_suffix": operation.emitter.operation_suffix,
}
source_buffer_device += SubstituteTemplate(
operation.KernelTemplate,
values,
)
source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)
if self.backend == "nvrtc":
# 3. compile
err, program = nvrtc.nvrtcCreateProgram(
str.encode(source_buffer_device),
bytes(str.encode("module.cu")),
0, [], [])
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
# Compile program
options = compilation_options.get()
err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
error_string = "NVRTC Error: {}\n".format(err)
# Get log from compilation
err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
log = b" " * logSize
err, = nvrtc.nvrtcGetProgramLog(program, log)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
raise RuntimeError(error_string + log.decode() + source_buffer_device)
# Get data from compilation
err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
cubin_image = b" " * dataSize
(err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
else: # with nvcc backend
# emit code
tempfile.tempdir = "./"
temp_cu = tempfile.NamedTemporaryFile(
prefix="kernel", suffix=".cu", delete=True)
temp_cubin = tempfile.NamedTemporaryFile(
prefix="kernel", suffix=".cubin", delete=True)
with open(temp_cu.name, "w") as file:
file.write(source_buffer_device)
# compile with nvcc
cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}"
values = {
"cuda_install_path": cuda_install_path(),
"options": compilation_options.get_str(),
"srcfile": temp_cu.name,
"tarfile": temp_cubin.name,
}
cmd = SubstituteTemplate(cmd_template, values)
compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt")
# load the cubin image
with open(temp_cubin.name, "rb") as file:
cubin_image = file.read()
tempfile.tempdir = "./"
temp_src = tempfile.NamedTemporaryFile(
prefix="host_src", suffix=".cu", delete=True)
# Write the host source
with open(temp_src.name, "w") as outfile:
outfile.write(source_buffer_host)
temp_dst = tempfile.NamedTemporaryFile(
prefix="host_func", suffix=".so", delete=True)
# Set up host compilation arguments
cmd = []
cmd.append(f"{cuda_install_path()}/bin/nvcc")
cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"])
cmd.extend(host_compilation_options.get_str().split(" "))
cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"])
        # Compile and load the library
        compile_with_nvcc(cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt")
host_lib = ctypes.CDLL(temp_dst.name)
return cubin_image, host_lib, temp_dst
def add_module(self, operations, compile_options=None, bypass_cache=False):
"""
Insert a new compiled device module
"""
include_paths = [
cuda_install_path() + "/include",
CUTLASS_PATH + "/include",
CUTLASS_PATH + "/tools/util/include",
CUTLASS_PATH + "/python/cutlass/cpp/include",
]
cutlass_cppgen.initialize_cuda_context()
arch = device_cc()
host_compile_options = CompilationOptions(
self._nvcc_compile_options, arch, include_paths)
if compile_options is None:
compile_options = CompilationOptions(
self.default_compile_options, arch, include_paths)
# save the cubin
operation_key = []
operation_list = []
for operation in operations:
            # Step 1: use the emitted kernel source as the cache key
            key = operation.rt_module.emit() + operation.procedural_name() + self.backend
            # Step 2: check whether the operation is already cached
            compiled_kernel = self.compiled_cache_device.get(key)
            if compiled_kernel is None and not bypass_cache:
                hit = self.load_operation(key, getattr(operation.rt_module, "extra_funcs", {}))
if hit:
compiled_kernel = self.compiled_cache_device.get(key)
assert compiled_kernel is not None
if compiled_kernel is not None:
operation.rt_module.kernel = compiled_kernel
compiled_host_fns = self.compiled_cache_host.get(key)
assert compiled_host_fns is not None
for key in compiled_host_fns.keys():
setattr(operation.rt_module, key, compiled_host_fns[key])
operation.rt_module.initialize()
else:
operation_list.append(operation.rt_module)
operation_key.append(key)
if len(operation_list) > 0:
cubin_image, host_lib, host_file = self.emit_compile_(
operation_list, compile_options, host_compile_options)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
operation_name = []
operation_attr = []
for operation, key in zip(operation_list, operation_key):
# get device kernels
err, operation.kernel = cuda.cuModuleGetFunction(
module,
bytes(str.encode(operation.name()))
)
operation_name.append(operation.name())
self.compiled_cache_device[key] = operation.kernel
# get host functions
compiled_host_fns = {}
op_attr = []
# get param size
func_name = operation.name() + "_get_param_size"
func = getattr(host_lib, func_name)
param_size = func()
func_name = operation.name() + "_get_params"
func = getattr(host_lib, func_name)
func.argtype = operation.argtype
func.restype = ctypes.POINTER(ctypes.c_char * param_size)
setattr(operation, "get_args", func)
compiled_host_fns["get_args"] = func
# set shared memory size
func_name = operation.name() + "_shared_memory_size"
func = getattr(host_lib, func_name)
setattr(operation, "shared_memory_capacity", func())
compiled_host_fns["shared_memory_capacity"] = func()
# set the maximum dynamic shared size
operation.initialize()
# get extra functions
op_attr.append(param_size)
if hasattr(operation, "extra_funcs"):
for suffix, ret_type in operation.extra_funcs.items():
func_name = operation.name() + "_" + suffix
func = getattr(host_lib, func_name)
if ret_type is not None:
func.restype = ret_type
setattr(operation, suffix, func)
compiled_host_fns[suffix] = func
op_attr.append(suffix)
operation_attr.append(op_attr)
self.compiled_cache_host[key] = compiled_host_fns
for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr):
self.insert_operation(
key, cubin_image, host_file.name, operation_name, operation_attr)
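Taken together, ArtifactManager emits source for a batch of operations, compiles it with nvcc or NVRTC, and caches the cubin and host library in the sqlite cache file. A hedged sketch of the compile path (illustrative; operation stands for an already-constructed operation such as a GemmOperationUniversal, and a CUDA context must be initialized):

    from cutlass_cppgen.backend import compiler   # module-level ArtifactManager

    compiler.nvcc()                    # default backend; compiler.nvrtc() selects in-process compilation
    compiler.add_module([operation])   # emit, compile, and cache the kernel for `operation`
    # A later call with the same emitted kernel source is served from the cache instead of recompiling.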


@@ -0,0 +1,700 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
import ctypes
from typing import Union
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import SubstituteTemplate
import numpy as np
from cutlass_library import (
ConvKindNames,
ConvKindTag,
DataTypeNames,
DataTypeSize,
DataTypeTag,
IteratorAlgorithmNames,
IteratorAlgorithmTag,
LayoutTag,
LayoutType,
MathOperation,
MathOperationTag,
OpcodeClass,
OpcodeClassNames,
OpcodeClassTag,
OperationKind,
ShortDataTypeNames,
ShortLayoutTypeNames,
SplitKMode,
StrideSupport,
StrideSupportTag,
SwizzlingFunctor,
SwizzlingFunctorTag,
get_complex_from_real,
)
from cutlass_cppgen.backend.arguments import ArgumentBase
from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments
from cutlass_cppgen.backend.library import (
EmissionType,
TensorDescription,
TileDescription,
)
from cutlass_cppgen.backend.memory_manager import device_mem_alloc
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.backend.utils.device import to_device_ptr
from cutlass_cppgen.shape import GemmCoord
class Conv2dArguments(ArgumentBase):
"""
Argument wrapper for Conv2d. It encodes problem information and
    user-provided tensors into the kernel's arguments.
:param operation: the Conv2d operation to take the argument
:type operation: :class:`cutlass_cppgen.backend.Conv2dOperation`
:param problem_size: the Conv2d problem size
:type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param B: tensor B
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param C: tensor C
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param D: tensor D
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
:type split_k_mode: cutlass_library.library.SplitKMode, optional
:param output_op: output operator, optional
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
:type stream: :class:`cuda.cuda.CUstream`
"""
def __init__(self, operation, problem_size, A, B, C, D,
split_k_mode=SplitKMode.Serial, **kwargs, ) -> None:
self.operation = operation
self.conv_kind = operation.conv_kind
self.layout_A = operation.A.layout
self.layout_B = operation.B.layout
self.layout_C = operation.C.layout
self.element_A = operation.A.element
self.element_B = operation.B.element
self.element_C = operation.C.element
if self.layout_C == LayoutType.TensorNC32HW32:
raise Exception("Layout type TensorNC32HW32 is not currently supported")
super().__init__(A, B, C, D, **kwargs)
if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
self.split_k_mode = split_k_mode
self.split_k_slices = kwargs["split_k_slices"]
else:
self.split_k_mode = SplitKMode.Serial
self.split_k_slices = 1
if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel:
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
self.problem_size = problem_size
self.problem_size.split_k_slices = self.split_k_slices
self.initialize()
def get_arguments(self):
tc_numel = -1
if hasattr(self, "tensor_c_numel"):
tc_numel = self.tensor_c_numel
self.c_arguments = self.operation.argument_type(
int(self.conv_kind),
self.problem_size.ctype,
int(to_device_ptr(self.ptr_A)),
int(to_device_ptr(self.ptr_B)),
int(to_device_ptr(self.ptr_C)),
int(to_device_ptr(self.ptr_D)),
tc_numel,
self.output_op,
int(self.split_k_mode)
)
def initialize(self):
self.launch_config = self.operation.rt_module.plan(self)
self.get_arguments()
# Allocate and initialize device workspace
device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments)
if device_workspace_size > 0:
self.workspace_buffer = device_mem_alloc(device_workspace_size)
workspace_ptr = self.workspace_buffer.ptr
err, = cuda.cuMemsetD32(
workspace_ptr, 0, device_workspace_size // 4)
else:
workspace_ptr = None
self.semaphore = 0
if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel:
self.ptr_D = workspace_ptr
# Reset arguments now that ptr_D has been updated
self.get_arguments()
elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial:
self.semaphore = workspace_ptr
params_ = self.operation.rt_module.get_args(
self.c_arguments, ctypes.c_void_p(int(self.semaphore)))
self.host_workspace = bytearray(params_.contents)
self.device_workspace = None
def sync(self):
"""
        Synchronize the arguments. If an output tensor resides on the host,
        copy the computed result from device back to the host tensor.
"""
return super().sync()
class Conv2dRT(ExecutableOperation):
"""
Conv2dRT manages the CUTLASS runtime components
"""
KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
${operation_name}${operation_suffix} op;
op(params, *shared_storage);
}
"""
HostTemplate = r"""
extern "C" {
// Get the size of params in bytes
int ${operation_name}_get_param_size(){
return sizeof(${operation_name}${operation_suffix}::Params);
}
// Get the size of dynamic shared memory in bytes
int ${operation_name}_shared_memory_size() {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}
using ElementA = typename ${operation_name}_base::ElementA;
using ElementB = typename ${operation_name}_base::ElementB;
using ElementC = typename ${operation_name}_base::ElementC;
using LayoutA = typename ${operation_name}_base::LayoutA;
using LayoutB = typename ${operation_name}_base::LayoutB;
using LayoutC = typename ${operation_name}_base::LayoutC;
using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp;
struct ${operation_name}_TemporaryArgs {
int conv_kind;
cutlass::conv::Conv2dProblemSize problem_size;
ElementA* ptr_A;
ElementB* ptr_B;
ElementC* ptr_C;
ElementC* ptr_D;
int tensor_c_numel;
typename EpilogueOutputOp::Params epilogue_params;
int split_k_mode;
};
typename ${operation_name}${operation_suffix}::Arguments
construct_arguments(${operation_name}_TemporaryArgs args) {
cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
// C is interpreted as bias
tc_C = {0, 0, 0, 0};
}
cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
cutlass::TensorRef<ElementB, LayoutA> tref_B(args.ptr_B, LayoutB::packed(tc_B));
cutlass::TensorRef<ElementC, LayoutA> tref_C(args.ptr_C, LayoutC::packed(tc_C));
cutlass::TensorRef<ElementC, LayoutA> tref_D(args.ptr_D, LayoutC::packed(tc_D));
return {
args.problem_size,
tref_A,
tref_B,
tref_C,
tref_D,
args.epilogue_params,
static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
};
}
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) {
auto arguments = construct_arguments(args);
typename ${operation_name}${operation_suffix}::Params* params;
params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore);
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];
return output;
}
dim3 ${operation_name}_get_grid_shape(
int conv_kind,
cutlass::conv::Conv2dProblemSize problem_size,
cutlass::gemm::GemmCoord tile_size,
int split_k_slices
) {
using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle;
auto tiled_shape = Swizzle::get_tiled_shape(
static_cast<cutlass::conv::Operator>(conv_kind),
problem_size,
tile_size,
split_k_slices);
return Swizzle::get_grid_shape(tiled_shape);
}
size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) {
auto arguments = construct_arguments(args);
// Temporarily define device::-level Conv2d so that we can call get_workspace_size
using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>;
return DeviceConv::get_workspace_size(arguments);
}
}
"""
def __init__(self, operation: "Conv2dOperation"):
super().__init__(operation)
self.extra_funcs = {
"get_grid_shape": dim3_,
"get_workspace_size": ctypes.c_uint64
}
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
self.conv_kind = operation.conv_kind
self.operation: Conv2dOperation = operation
self.emitter = EmitConv2dInstance("_type")
self.threads = operation.tile_description.num_threads
self.swizzle_functor = operation.swizzling_functor
def emit(self):
return self.emitter.emit(self.operation)
def plan(self, arguments: Conv2dArguments):
tile_size = GemmCoord(
self.operation.tile_description.threadblock_shape[0],
self.operation.tile_description.threadblock_shape[1],
self.operation.tile_description.threadblock_shape[2],
)
grid = self.get_grid_shape(
int(self.conv_kind),
arguments.problem_size.ctype,
tile_size.ctype,
arguments.split_k_slices
)
return LaunchConfiguration(
[grid.x, grid.y, grid.z], [self.threads, 1, 1],
self.shared_memory_capacity)
def initialize(self):
err, = cuda.cuFuncSetAttribute(
self.kernel,
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
value=self.shared_memory_capacity)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"CUDA Error: {err}")
class Conv2dOperation:
"""
CUTLASS Conv2d operation description.
:param conv_kind: convolution operator
:type conv_kind: :class:`cutlass_library.library.ConvKind`
:param iterator_algorithm: Selects among several implementation
variants trading off performance with simplicity
:type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm`
:param arch: GPU compute capability (sm_xx)
:type arch: int
:param tile_description: tile description
:type tile_description: :class:`cutlass_cppgen.backend.TileDescription`
:param A: tensor A description
:type A: :class:`cutlass_cppgen.backend.TensorDescription`
:param B: tensor B description
:type B: :class:`cutlass_cppgen.backend.TensorDescription`
:param C: tensor C description
:type C: :class:`cutlass_cppgen.backend.TensorDescription`
:param D: tensor D description
:type D: :class:`cutlass_cppgen.backend.TensorDescription`
:param element_epilogue: element type for computation in epilogue \
:type element_epilogue: cutlass_library.library.DataType
:param stride_support: distinguish among partial specializations that \
accelerate certain problems where convolution stride is unit \
:type stride_support: :class:`cutlass_library.library.StrideSupport`
:param epilogue_functor: convolution epilogue functor
:type epilogue_functor: :class:`EpilogueFunctor`
:param swizzling_functor: threadblock swizzling functor
"""
def __init__(
self,
conv_kind,
iterator_algorithm,
arch: int,
tile_description: TileDescription,
A: TensorDescription,
B: TensorDescription,
C: TensorDescription,
stride_support,
epilogue_functor,
swizzling_functor=SwizzlingFunctor.Identity1,
emission_type=EmissionType.Kernel,
**kwargs
):
self.operation_kind: OperationKind = OperationKind.Conv2d
self.arch: int = arch
self.tile_description: TileDescription = tile_description
self.conv_kind = conv_kind
self.A: TensorDescription = A
self.B: TensorDescription = B
self.C: TensorDescription = C
self.epilogue_functor = epilogue_functor
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.emission_type = emission_type
self.rt_module: Conv2dRT = Conv2dRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
"""
Launch the cuda kernel with input arguments
:param arguments: conv2d arguments
:type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments`
"""
# launch the kernel
err = self.rt_module.run(
arguments.host_workspace,
arguments.device_workspace,
arguments.launch_config,
arguments.stream
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"CUDA Error {err}")
return err
#
# Get function name
#
def procedural_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
return self.configuration_name()
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
opcode_class_name = OpcodeClassNames[
self.tile_description.math_instruction.opcode_class
]
threadblock = "%dx%d_%dx%d" % (
self.tile_description.threadblock_shape[0],
self.tile_description.threadblock_shape[1],
self.tile_description.threadblock_shape[2],
self.tile_description.stages,
)
if self.stride_support == StrideSupport.Unity:
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
else:
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
return SubstituteTemplate(
configuration_name,
{
"arch": str(self.arch),
"opcode_class": opcode_class_name,
"extended_name": self.extended_name(),
"threadblock": threadblock,
"layout": self.layout_name(),
"alignment": "%d" % self.A.alignment
},
)
def extended_name(self):
"""Append data types if they differ from compute type."""
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
else:
extended_name = "${core_name}"
extended_name = SubstituteTemplate(extended_name, {
"element_a": DataTypeNames[self.A.element],
"element_c": DataTypeNames[self.C.element],
"core_name": self.core_name(),
})
return extended_name
def layout_name(self):
return "%s" % (ShortLayoutTypeNames[self.A.layout])
def core_name(self):
"""The basic operation kind is prefixed with a letter indicating the accumulation type."""
intermediate_type = ""
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
inst_shape = "%dx%dx%d" % tuple(
self.tile_description.math_instruction.instruction_shape)
if self.tile_description.math_instruction.element_a != self.A.element and \
self.tile_description.math_instruction.element_a != self.accumulator_type():
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
else:
inst_shape = ""
return "%s%s%s%s_%s" % (
ShortDataTypeNames[self.accumulator_type()],
inst_shape,
intermediate_type,
ConvKindNames[self.conv_kind],
IteratorAlgorithmNames[self.iterator_algorithm]
)
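    # Hedged example: with f32 accumulation, a 16x8x16 tensor-op instruction, fprop, and the
    # optimized iterator algorithm, this composes to something like "s16x8x16fprop_optimized"
    # (the leading letter comes from ShortDataTypeNames[self.accumulator_type()]).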
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex_gaussian,
]
return self.tile_description.math_instruction.math_operation in complex_operators
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator
if self.is_complex():
return get_complex_from_real(accum)
return accum
def device_op(self):
"""
Returns a new Conv2dOperation object that is constructed with emission type
``EmissionType.Device``.
:return: operation ready for device-level code emission
:rtype: Conv2dOperation
"""
return Conv2dOperation(
self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description,
self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor,
emission_type=EmissionType.Device)
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################
class EmitConv2dInstance:
def __init__(self, operation_suffix=""):
self.operation_suffix = operation_suffix
self.includes = [
"cutlass/cutlass.h",
"cutlass/conv/kernel/default_conv2d_fprop.h",
"cutlass/conv/kernel/default_conv2d_dgrad.h",
"cutlass/conv/kernel/default_conv2d_wgrad.h",
"cutlass/conv/device/implicit_gemm_convolution.h"
]
self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""
self.template_device = """
// Conv2d operation ${operation_name}
using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
using DeviceKernel =
typename cutlass::conv::device::ImplicitGemmConvolution<Conv2d${conv_kind_name}Kernel>;
"""
def emit(self, operation):
warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
operation.tile_description.warp_count[idx]) for idx in range(3)]
epilogue_vector_length = int(min(
operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
values = {
"operation_name": operation.procedural_name(),
"operation_suffix": self.operation_suffix,
"conv_kind": ConvKindTag[operation.conv_kind],
"conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
"element_a": DataTypeTag[operation.A.element],
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
"element_c": DataTypeTag[operation.C.element],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
"arch": "cutlass::arch::Sm%d" % operation.arch,
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
"warp_shape_m": str(warp_shape[0]),
"warp_shape_n": str(warp_shape[1]),
"warp_shape_k": str(warp_shape[2]),
"instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
"instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
"epilogue_vector_length": str(epilogue_vector_length),
"epilogue_functor": operation.epilogue_functor.emit(),
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
"stages": str(operation.tile_description.stages),
"iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
"iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
"stride_support": StrideSupportTag[operation.stride_support],
"math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation],
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
}
if operation.emission_type == EmissionType.Kernel:
conv2d_template = self.template
else:
conv2d_template = self.template_device
return SubstituteTemplate(conv2d_template, values)
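    # Worked example (hedged): a 128x128x64 threadblock with warp_count [2, 2, 1] gives
    # warp_shape [64, 64, 64]; for an f16 output (16-bit elements) with alignment 8,
    # epilogue_vector_length = min(8 * 16, 128) / 16 = 8 elements per vectorized access.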


@ -0,0 +1,541 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from cutlass_library import SubstituteTemplate
import numpy as np
from cutlass_library import DataType, DataTypeTag
from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory
from cutlass_cppgen.backend.frontend import NumpyFrontend
from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
dtype2ctype = {
DataType.f16: ctypes.c_uint16,
DataType.bf16: ctypes.c_uint16,
DataType.f32: ctypes.c_float,
DataType.f64: ctypes.c_double,
DataType.s8: ctypes.c_int8,
DataType.s32: ctypes.c_int32
}
if is_torch_available():
import torch
import torch.nn.functional as F
def get_scalar(value):
"""
Returns a scalar value from a container (e.g., np.ndarray)
"""
if is_numpy_tensor(value):
if value.size != 1:
raise Exception("Scalars used in epilogue must be of size 1")
return value.reshape(-1)[0]
elif is_torch_tensor(value):
        if value.numel() != 1:
raise Exception("Scalars used in epilogue must be of size 1")
return value.reshape(-1)[0]
else:
return value
def to_ctype_value(value, dtype):
"""
Converts ``value`` to the corresponding storage needed for the ctype that
will store ``value``.
"""
scalar = get_scalar(value)
if dtype == DataType.f16:
# Convert f16 value into an integer
return int.from_bytes(np.float16(scalar).tobytes(), "little")
else:
return scalar
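# Hedged, doctest-style illustration of the conversion above: the half-precision bit
# pattern is reinterpreted as a 16-bit integer so it fits the ctypes.c_uint16 slot.
#   >>> to_ctype_value(0.5, DataType.f16)
#   14336    # 0x3800, the IEEE binary16 encoding of 0.5
#   >>> to_ctype_value(2.0, DataType.f32)
#   2.0      # non-f16 dtypes pass the scalar through unchanged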
#################################################################################################
#
# Epilogue Functors
#
#################################################################################################
class EpilogueFunctorBase:
"""
Base class for thread-level epilogue functors
"""
def __init__(self) -> None:
pass
def emit(self, tag, template_argument):
template = """${tag}<${arguments}>"""
arguments = ""
for idx, arg in enumerate(template_argument):
arguments += arg
if idx < len(template_argument) - 1:
arguments += ", "
values = {
"tag": tag,
"arguments": arguments,
}
return SubstituteTemplate(template, values)
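# Illustrative output (hedged): emit simply joins the tag and its template arguments into a
# C++ template instantiation, e.g.
#   EpilogueFunctorBase().emit(
#       "cutlass::epilogue::thread::LinearCombination", ["float", "4", "float", "float"])
# returns "cutlass::epilogue::thread::LinearCombination<float, 4, float, float>".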
class LinearCombination(EpilogueFunctorBase):
"""
Apply a linear combination operator to an array of elements
D = alpha * accumulator + beta * source
:param element_output: data type used to load and store tensors
:param epilogue_vector_length: number of elements computed per operation.
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
when there are not enough data to store
:param element_accumulator: Accumulator data type
:param element_epilogue: data type used to compute linear combination
"""
tag = "cutlass::epilogue::thread::LinearCombination"
def __init__(
self, element_output, epilogue_vector_length,
element_accumulator=None, element_epilogue=None) -> None:
super().__init__()
if element_accumulator is None:
element_accumulator = element_output
if element_epilogue is None:
element_epilogue = element_output
self.element_output = element_output
self.element_accumulator = element_accumulator
self.element_epilogue = element_epilogue
self.epilogue_vector_length = epilogue_vector_length
self.template_arguments = [
DataTypeTag[element_output],
str(epilogue_vector_length),
DataTypeTag[element_accumulator],
DataTypeTag[element_epilogue],
]
c_element_epilogue = dtype2ctype[self.element_epilogue]
element_epilogue = self.element_epilogue
class _EpilogueOutputOpParamsEVT(ctypes.Structure):
"""
Epilogue params when using the default linear combination of EVT, which
does not currently use {alpha,beta}_ptr_array
"""
stride_type = tuple_factory((0,0,1), "int64_t", [0])
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
("dalpha", stride_type),
("dbeta", stride_type),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
("alpha_ptr_array", ctypes.c_void_p),
("beta_ptr_array", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)
self.epilogue_type = _EpilogueOutputOpParams
self.epilogue_type_evt = _EpilogueOutputOpParamsEVT
def emit(self):
return super().emit(self.tag, self.template_arguments)
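# Hedged usage sketch: an f16-output linear combination with vector length 8 and f32
# accumulation/compute emits the corresponding thread-level functor type, e.g.
#   LinearCombination(DataType.f16, 8, DataType.f32, DataType.f32).emit()
# is expected to return
#   "cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>"
# assuming the usual cutlass_library.DataTypeTag spellings for f16 and f32.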
class LinearCombinationClamp(LinearCombination):
"""
Applies a linear combination operator to an array of elements then clamps
the output before converting to the output element type.
D = alpha * accumulator + beta * source + uniform
:param element_output: data type used to load and store tensors
:param epilogue_vector_length: number of elements computed per operation.
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
when there are not enough data to store
:param element_accumulator: Accumulator data type
:param element_epilogue: data type used to compute linear combination
"""
tag = "cutlass::epilogue::thread::LinearCombinationClamp"
def __init__(
self, element_output, epilogue_vector_length,
element_accumulator=None, element_epilogue=None) -> None:
# Base constructor
super().__init__(
element_output,
epilogue_vector_length,
element_accumulator,
element_epilogue,
)
c_element_epilogue = dtype2ctype[self.element_epilogue]
element_epilogue = self.element_epilogue
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
self.epilogue_type = _EpilogueOutputOpParams
class FastLinearCombinationClamp(EpilogueFunctorBase):
"""
Applies a linear combination operator to an array of elements then clamps
the output before converting to the output element type.
D = alpha * accumulator + beta * source
    Note: this functor applies only when problem_size_K <= 256 for signed int8 GEMM
    or problem_size_K <= 128 for unsigned int8 GEMM; otherwise, use the default
    LinearCombinationClamp above.
:param element_output: data type used to load and store tensors
:param epilogue_vector_length: number of elements computed per operation.
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
when there are not enough data to store
"""
tag = "cutlass::epilogue::thread::FastLinearCombinationClamp"
def __init__(self, element_output, epilogue_vector_length, *args) -> None:
super().__init__()
self.template_arguments = [
DataTypeTag[element_output], str(epilogue_vector_length)
]
self.element_accumulator = DataType.s32
self.element_epilogue = DataType.f32
# get epilogue output op
c_element_epilogue = dtype2ctype[self.element_epilogue]
element_epilogue = self.element_epilogue
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
self.epilogue_type = _EpilogueOutputOpParams
def emit(self):
return super().emit(self.tag, self.template_arguments)
class LinearCombinationGeneric(LinearCombination):
"""
Applies a linear combination operator followed by an activation function
to an array of elements.
D = activation(alpha * accumulator + beta * source)
:param activation_functor: input activation functor
:param element_output: data type used to load and store tensors
:param epilogue_vector_length: number of elements computed per operation.
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
when there are not enough data to store
:param element_accumulator: Accumulator data type
:param element_epilogue: data type used to compute linear combination
"""
tag = "cutlass::epilogue::thread::LinearCombinationGeneric"
def __init__(
self, activation_functor,
element_output, epilogue_vector_length,
element_accumulator=None, element_epilogue=None) -> None:
super().__init__(
element_output,
epilogue_vector_length,
element_accumulator,
element_epilogue,
)
self.template_arguments = [
activation_functor.emit()] + self.template_arguments
self.activation_functor = activation_functor
self.element_epilogue = element_epilogue
# get epilogue output op
self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue)
class ActivationFunctor:
"""
Base class for frequently used activation functions
"""
@staticmethod
def numpy(x: np.ndarray):
raise NotImplementedError()
@classmethod
def emit(cls):
return ActivationOpTag[cls.binding_type]
@staticmethod
def epilogue_output_op(element_epilogue):
c_element_epilogue = dtype2ctype[element_epilogue]
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
return _EpilogueOutputOpParams
class ActivationMeta(type):
@classmethod
def __call__(cls, x, *args):
if is_numpy_tensor(x):
return cls.numpy(x, *args)
elif is_torch_tensor(x):
return cls.torch(x, *args)
else:
raise NotImplementedError("Unsupported tensor type")
@classmethod
def numpy(cls, *args):
raise NotImplementedError(f"Numpy reference for {cls.__name__[:-4]} is not implemented.")
@classmethod
def torch(cls, *args):
raise NotImplementedError(f"PyTorch reference for {cls.__name__[:-4]} is not implemented.")
##############################################################################
# identity operator
class identityMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return x
@classmethod
def torch(cls, x):
return x
class identity(ActivationFunctor, metaclass=identityMeta):
binding_type = ActivationOp.Identity
##############################################################################
# ReLu operator
class reluMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return np.where(x > 0, x, 0)
@classmethod
def torch(cls, x):
return F.relu(x)
class relu(ActivationFunctor, metaclass=reluMeta):
binding_type = ActivationOp.ReLU
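# Hedged illustration of the metaclass dispatch above: activation classes are called like
# functions and route on the argument's tensor type, e.g.
#   relu(np.array([-1.0, 2.0]))       # -> array([0., 2.])  via reluMeta.numpy
#   relu(torch.tensor([-1.0, 2.0]))   # -> tensor([0., 2.]) via reluMeta.torch (requires torch)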
##############################################################################
# Leaky ReLu operator
class leakyReLUMeta(ActivationMeta):
@classmethod
def numpy(cls, x, leaky_alpha):
return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha
@classmethod
def torch(cls, x, leaky_alpha):
return F.leaky_relu(x, leaky_alpha)
class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta):
binding_type = ActivationOp.LeakyReLU
@staticmethod
def epilogue_output_op(element_epilogue):
c_element_epilogue = dtype2ctype[element_epilogue]
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
("leaky_alpha", c_element_epilogue)
]
def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
self.alpha_ptr = 0
self.beta_ptr = 0
self.leaky_alpha = to_ctype_value(leaky_alpha, element_epilogue)
return _EpilogueOutputOpParams
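# Hedged usage sketch: the generated ctypes struct carries the extra leaky_alpha slope, e.g.
#   Params = leaky_relu.epilogue_output_op(DataType.f32)
#   params = Params(1.0, 0.0, 0.2)    # alpha=1.0, beta=0.0, leaky_alpha=0.2, stored as c_float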
##############################################################################
# Tanh operator
class tanhMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return np.tanh(x)
@classmethod
def torch(cls, x):
return torch.tanh(x)
class tanh(ActivationFunctor, metaclass=tanhMeta):
binding_type = ActivationOp.Tanh
##############################################################################
# Sigmoid operator
class sigmoidMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return 1.0 / (1.0 + np.exp(-x))
@classmethod
def torch(cls, x):
return F.sigmoid(x)
class sigmoid(ActivationFunctor, metaclass=sigmoidMeta):
binding_type = ActivationOp.Sigmoid
##############################################################################
# SiLu operator
class siluMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
        return x * sigmoidMeta.numpy(x)  # SiLU(x) = x * sigmoid(x)
@classmethod
    def torch(cls, x):
return F.silu(x)
class silu(ActivationFunctor, metaclass=siluMeta):
binding_type = ActivationOp.SiLU
##############################################################################
# Hardswish operator
class hardswishMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0)
return x * relu6 / 6.0
@classmethod
def torch(cls, x):
return F.hardswish(x)
class hardswish(ActivationFunctor, metaclass=hardswishMeta):
binding_type = ActivationOp.HardSwish
##############################################################################
# GELU operator
class geluMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
from scipy.special import erf
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
@classmethod
def torch(cls, x):
return F.gelu(x)
class gelu(ActivationFunctor, metaclass=geluMeta):
binding_type = ActivationOp.Gelu


@ -0,0 +1,34 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend


@ -0,0 +1,38 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter
import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes
from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter
import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes


@ -0,0 +1,159 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Epilogue Visitor Emitter
"""
from cutlass_library import DataTypeTag
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
class FusionCallbacks:
def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
"""
Emit the EVT fusion callbacks
:param dag_ir: the DAG IR holding the epilogue visitor
:param cc: compute capability
        :param emit_CD: whether to emit nodes C & D as part of the fusion callbacks.
            For Sm90, set emit_CD=False: tensors C & D are hardcoded in the collective API
            so that their shared memory can be explicitly reused.
            For Sm89 (and earlier), set emit_CD=True: they are treated as normal AuxLoad & AuxStore nodes.
"""
self.dag_ir = dag_ir
self.emit_CD = emit_CD
self.cc = cc
self.evt_cc = 90 if cc >= 90 else cc
if self.cc < 90:
self.namespace = "threadblock"
else:
self.namespace = "fusion"
#
# Helper functions
#
def get_visitor_name(self, node: str):
"""
Get the visitor name
"""
meta = self.dag_ir.get_node_meta(node)
if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
return f"EVT{meta.name_camel}"
else:
return meta.name_camel
def emit(self):
node_metas = self.dag_ir.node_metas_topological_order()
epilogue_str = ""
# Step 1: emit individual node type decl
# emit the EVT & DAG connector
for meta in node_metas:
if not meta.disabled:
epilogue_str += self.emit_node(meta)
if not self.emit_CD and meta.name == "D":
continue
if isinstance(meta, TopoVisitorNode):
epilogue_str += self.emit_dag(meta)
else:
epilogue_str += self.emit_evt(meta)
# Step 2: post-processing & get callback name
if not self.emit_CD:
if not self.dag_ir.has_node("C"):
epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
output_node = self.dag_ir.get_all_inputs("D")[0]
# The callback is the src of node D
callback_name = self.get_visitor_name(output_node)
else:
# The callback is the last node in the topological order
callback_name = self.get_visitor_name(node_metas[-1].name)
return epilogue_str, callback_name
def emit_evt(self, node):
if self.dag_ir.in_degree(node.name) == 0:
return ""
evt_tmp = f"""
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
{node.name_camel},
"""
sorted_children = self.dag_ir.get_all_inputs(node.name)
evt_node_strs = [f" {self.get_visitor_name(child_name)}" for child_name in sorted_children]
evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
return evt_tmp
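    # Illustrative output (hedged): for a compute node "compute_0" whose only input is the
    # accumulator, emit_evt produces roughly
    #   using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
    #       Compute0,
    #       Accum>;
    # with the namespace and Sm<cc> prefix taken from self.namespace / self.evt_cc and the
    # identifiers taken from each node's name_camel.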
def emit_dag(self, node):
subgraph = node.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Emit the Edge Tuple
edge_tuples = "cute::tuple<\n"
for n in subgraph_nodes[:-1]:
in_edges = subgraph.in_edges(n)
edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
edge_tuple = " cute::seq<"
edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
edge_tuple += ", ".join(edge_str) + ">,\n"
edge_tuples += edge_tuple
edge_tuples += " >"
# Emit the node list
dag_nodes = ""
dag_node_strs = []
for n in subgraph_nodes[:-1]:
n_meta = subgraph.get_node_meta(n)
if n_meta.disabled:
dag_node_strs.append(f" {self.get_visitor_name(n)}")
else:
dag_node_strs.append(f" {n_meta.name_camel}")
dag_nodes = ",\n".join(dag_node_strs)
return f"""
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
{DataTypeTag[node.subgraph.element_compute]},
{edge_tuples},
{dag_nodes}
>;
"""
def emit_node(self, node):
if isinstance(node, TopoVisitorNode):
emission = ""
for node in node.subgraph.node_metas_topological_order():
if not node.disabled:
emission += self.emit_node(node)
return emission
else:
return node.underlying_impl.type_decl


@ -0,0 +1,116 @@
#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm100 Epilogue Visitor
"""
from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
from cutlass_cppgen.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend.evt.ir.node import TupleEmitter
class Sm100CollectiveEpilogue:
def __init__(self, tile_description,
kernel_schedule,
epilogue_schedule,
element_accumulator,
element_d,
fusion_callbacks) -> None:
self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule)
self.element_accumulator = element_accumulator
if fusion_callbacks.dag_ir.has_node("C"):
self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element
else:
self.element_c = DataType.void
self.element_d = element_d
self.schedule = epilogue_schedule
self.fusion_callbacks = fusion_callbacks
self.opclass = tile_description.math_instruction.opcode_class
@property
def CtaTileMNK(self) -> str:
"""
The threadblock shape
"""
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
@property
def EpilogueTileType(self) -> str:
"""
The epilogue tile type
"""
return "cutlass::epilogue::collective::EpilogueTileAuto"
@property
def Schedule(self) -> str:
return EpilogueScheduleTag[self.schedule]
def emit(self):
tuple_emitter = TupleEmitter("int64_t")
stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl
stride_C_str = stride_D_str
if self.fusion_callbacks.dag_ir.has_node("C"):
stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl
callback_decl, callback_name = self.fusion_callbacks.emit()
return callback_name, f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
{OpcodeClassTag[self.opclass]},
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}, {stride_C_str}, {stride_D_str},
false /* IsPerColScaleSupported */,
false /* IsBlockScaleSupported */
>;
{callback_decl}
"""
class Sm100Emitter:
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
self.collective_epilogue = Sm100CollectiveEpilogue(
tile_description=operation.tile_description,
kernel_schedule=operation.tile_description.kernel_schedule,
epilogue_schedule=operation.tile_description.epilogue_schedule,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_d=fusion_callbacks.dag_ir.get_node_meta("D").element,
fusion_callbacks=fusion_callbacks
)
def emit(self):
return self.collective_epilogue.emit()


@ -0,0 +1,134 @@
#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from pycute import product
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
from cutlass_cppgen.backend.library import FloatRoundStyleTag
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
class Sm100AuxLoadImpl(AuxLoadImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Load
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
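    # Worked example (hedged): an f16 aux-load (16-bit elements) with stages_c = 2 and an
    # epilogue tile of (64, 32) requests 16 * 2 * (64 * 32) // 8 = 8192 bytes of shared
    # memory; the 128 is the required byte alignment of that allocation.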
class Sm100AuxStoreImpl(AuxStoreImpl):
@property
def descriptor(self) -> str:
"""
        Descriptor for Aux Store
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)


@ -0,0 +1,47 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm80 Epilogue Visitor
"""
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend import GemmOperationUniversal
class Sm80Emitter:
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
self.fusion_callbacks = FusionCallbacks(graph, cc=80)
def emit(self):
callback_decl, callback_name = self.fusion_callbacks.emit()
return callback_name, callback_decl


@ -0,0 +1,258 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)
from cutlass_cppgen.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
class Sm80AccumulatorImpl(AccumulatorImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
return self._type_decl
class Sm80AuxLoadImpl(AuxLoadImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
return self._type_decl
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
pass
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
def __init__(self, node: LoadNode) -> None:
super().__init__(node)
self.broadcast_count = 1
self.reduction_fn = FunctionalOp.Multiplies
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
return self._type_decl
class Sm80RowBroadcastImpl(RowBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ComputeImpl(ComputeImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
return self._type_decl
class Sm80AuxStoreImpl(AuxStoreImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80StoreDImpl(Sm80AuxStoreImpl):
pass
class Sm80ColumnReductionImpl(ColumnReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80RowReductionImpl(RowReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ScalarReductionImpl(ScalarReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl


@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm90 Epilogue Visitor
"""
from cutlass_library import DataTypeTag, EpilogueScheduleTag
from cutlass_cppgen.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
class CollectiveEpilogue:
def __init__(self, tile_description,
schedule,
element_c,
element_d,
fusion_callbacks) -> None:
self.cta_tile_mnk = tile_description.threadblock_shape
self.element_c = element_c
self.element_d = element_d
self.schedule = schedule
self.fusion_callbacks = fusion_callbacks
@property
def CtaTileMNK(self) -> str:
"""
The threadblock shape
"""
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
@property
def EpilogueTileType(self) -> str:
"""
The epilogue tile type
"""
return "cutlass::epilogue::collective::EpilogueTileAuto"
@property
def Schedule(self) -> str:
return EpilogueScheduleTag[self.schedule]
def emit(self):
callback_decl, callback_name = self.fusion_callbacks.emit()
return callback_name, f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}
>;
{callback_decl}
"""
class Sm90Emitter:
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
self.collective_epilogue = CollectiveEpilogue(
tile_description=operation.tile_description,
schedule=operation.tile_description.epilogue_schedule,
element_c=operation.C.element,
element_d=operation.C.element,
fusion_callbacks=fusion_callbacks
)
def emit(self):
return self.collective_epilogue.emit()


@ -0,0 +1,329 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from pycute import product
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
ComputeNode,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl,
StoreNode,
StoreDImpl,
)
from cutlass_cppgen.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
class Sm90AccumulatorImpl(AccumulatorImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
return self._type_decl
class Sm90LoadSrcImpl(LoadSrcImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
return self._type_decl
class Sm90AuxLoadImpl(AuxLoadImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Load
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
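# Editorial note (illustrative arithmetic, not part of the original source): for an
# f16 auxiliary tensor (DataTypeSize = 16 bits), an epilogue tile of (64, 32) and
# stages_c = 2, the formula above gives 16 * 2 * (64 * 32) // 8 = 8192 bytes of
# shared memory, with a 128-byte alignment requirement.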
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
def __init__(self, node: LoadNode) -> None:
super().__init__(node)
self.broadcast_count = 1
self.reduction_fn = FunctionalOp.Multiplies
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
return self._type_decl
class Sm90RowBroadcastImpl(RowBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ComputeImpl(ComputeImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
return self._type_decl
class Sm90AuxStoreImpl(AuxStoreImpl):
@property
def descriptor(self) -> str:
"""
        Descriptor for Aux Store
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
class Sm90StoreDImpl(StoreDImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
return f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""
class Sm90ColumnReductionImpl(ColumnReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90RowReductionImpl(RowReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ScalarReductionImpl(ScalarReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
{DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
return self._type_decl
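# Editorial sketch (hypothetical node name and types, not part of the original source):
# for a compute node named "mult_0" that multiplies in f32 with round-to-nearest,
# Sm90ComputeImpl.type_decl would produce a declaration along these lines:
#
#   using Mult0 = cutlass::epilogue::fusion::Sm90Compute<
#       cutlass::multiplies, float, float,
#       cutlass::FloatRoundStyle::round_to_nearest
#   >;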

View File

@ -0,0 +1,168 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Epilogue Visitor interface for compiling and running visitor-based epilogues.
"""
import ctypes
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import DataType
import numpy as np
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
import cutlass_cppgen.backend.evt.backend
from cutlass_cppgen.backend.frontend import TensorFrontend
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.backend.evt.passes.util import cc_map
class EpilogueFunctorVisitor(EpilogueFunctorBase):
"""
Apply an epilogue functor described by the epilogue EVT
:param cc: compute capability
    :param visitor: user-provided visitor frontend
"""
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
# Type of Emitter based on CC
self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
# Visitor Types
self.visitor = visitor
self.graph = visitor.dag_ir
# Data types
self.element_epilogue = element_compute # element compute
self.element_output = self.graph.get_node_meta('D').underlying_impl.element
# Epilogue Thread Type
epilogue_thread_type = self.visitor.epilogue_thread_type
if cc_map[cc] in [90, 100]:
self.arg_c_type = self.visitor.arg_c_type
self.arg_d_type = self.visitor.arg_d_type
output_names = self.visitor.return_names
reduction_names = self.visitor.reduction_names
# Epilogue stages specialized for sm80 kernel
if cc == 80:
if hasattr(self.visitor, "epilogue_stages"):
self.epilogue_stages = self.visitor.epilogue_stages
assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"
# Epilogue Argument Type
class _Arguments(ctypes.Structure):
"""
Concepts:
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _Arguments), <- this class
("ptr_C", ctypes.c_void_p),
("stride_C", StrideBatched_),
("ptr_D", ctypes.c_void_p),
("stride_D", StrideBatched_)
]
"""
_fields_ = [
("output_op", epilogue_thread_type)
]
def __init__(self, kwargs: dict) -> None:
                # The user-provided kwargs is a dict of {name: tensor}.
                # We first convert each entry to a device pointer.
ptr_kwargs = {}
for key in kwargs.keys():
is_output = key in output_names and key not in reduction_names
ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
# Initialize the thread arguments
self.output_op = epilogue_thread_type(ptr_kwargs)
def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
"""
Helper function for extracting device pointer
"""
# Skip the special tensors
if cc in [90, 100]:
if tensor_name in ["C", "D"]:
return 0
if tensor_name not in kwargs.keys():
raise ValueError(f"Tensor {tensor_name} is not provided.")
tensor = kwargs[tensor_name]
# For float scalar constant, directly return the value
if isinstance(tensor, float):
return tensor
# The tensor frontend returns a device buffer for np.ndarray
# and device ptr for other frontends
buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
if is_numpy_tensor(tensor):
# Remember the host tensor for later synchronization
setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
setattr(self, f"{tensor_name}_host", tensor)
return int(buffer_or_ptr.ptr)
else:
return int(buffer_or_ptr)
def sync(self):
"""
Synchronize the results from device to host
"""
for name in output_names:
if hasattr(self, f"{name}_host"):
host_tensor = getattr(self, f"{name}_host")
tensor_ptr = getattr(self, f"{name}_buffer").ptr
(err,) = cuda.cuMemcpyDtoH(
host_tensor,
tensor_ptr,
host_tensor.size * host_tensor.itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
self.epilogue_type = _Arguments
def emit(self, operation):
"""
Emit the C++ code
"""
emitter = self.emit_cls(operation, self.graph)
return emitter.emit()
def get_smem_size(self, tile_description):
"""
Get the shared memory size in bytes
"""
return self.visitor.get_smem_size(tile_description)
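# Editorial sketch (hypothetical tensor names, not part of the original source):
# typical use binds a traced EVT frontend to a kernel through this class.
#
#   visitor = ...  # an EVT frontend on which trace() has already been run
#   epilogue = EpilogueFunctorVisitor(cc=90, visitor=visitor)
#   args = epilogue.epilogue_type({"alpha": 1.0, "beta": 0.5, "C": C, "D": D})
#   # After the kernel runs, args.sync() copies any numpy outputs back to the host.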

View File

@ -0,0 +1,33 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend

View File

@ -0,0 +1,272 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Python EVT Frontend
"""
from typing import Union
from cutlass_library import DataType
from cutlass_cppgen.backend.evt.ir import (
ComputeNode,
DAGIR,
LayoutNode,
LoadNode,
StoreNode,
)
from cutlass_cppgen.backend.evt.passes import (
EVTGraphDrawer,
EVTPassManager,
GetSmemSize,
PassDAG2Tree,
PassGetArgumentType,
PassGetImpl,
PassFixElementD,
PassLayoutManipulateElimination,
PassPreprocessRed,
PassShapeTypePropagation,
)
from cutlass_cppgen.backend.evt.passes.util import cc_map
from cutlass_cppgen.backend.utils import device_cc
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
from cutlass_cppgen.utils.datatypes import library_type
class EVTFrontendBase:
layout_fns = {
"permute": permute,
"reshape": reshape
}
def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None:
self.cc = cc
self.element_compute = library_type(element_compute)
self.dag_ir = DAGIR(self.cc, self.element_compute)
self.compute_cnt = 0
self.layout_cnt = 0
self.imm_cnt = 0
self.pass_manager = EVTPassManager(
self.dag_ir,
[
PassPreprocessRed,
PassGetArgumentType,
PassShapeTypePropagation,
PassLayoutManipulateElimination,
PassGetImpl,
PassDAG2Tree,
PassFixElementD
] + additional_passes)
if self.cc == 80:
self._epilogue_stages = 1
else:
self._epilogue_stages = None
@property
def epilogue_stages(self):
return self._epilogue_stages
@epilogue_stages.setter
def epilogue_stages(self, stages):
self._epilogue_stages = stages
def parse(self, *args, **kwargs):
        raise NotImplementedError("The 'parse' method must be overridden by the frontend subclass")
def trace(self, *args, **kwargs):
# Parse the input
self.parse(*args, **kwargs)
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
if (self.cc >= 90):
if (self.dag_ir.out_degree("D") != 0):
raise RuntimeError(
f"On SM90 or higher, D is expected to be a output node with 0 users to "
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
# Run the passes
self.pass_manager()
# Set the epilogue type
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
if cc_map[self.cc] in [90, 100]:
self.arg_c_type = self.dag_ir.arg_c_type
self.arg_d_type = self.dag_ir.arg_d_type
self.reduction_names = self.dag_ir.reduction_names
#
# Helper functions for DAG IR manipulation
#
def add_node(self, node):
self.dag_ir.add_node(node)
def add_edge(self, src, tgt, weight=0):
self.dag_ir.add_edge(src, tgt, weight=weight)
def set_tensor(self, node_name, example):
"""
Add an example tensor to node {node_name} in the DAG IR
"""
meta = self.dag_ir.get_node_meta(node_name)
meta.tensor = {"tensor": example}
def set_store_tensor(self, node_name, example):
"""
        Add an example store (output) tensor to node {node_name} in the DAG IR
"""
meta = self.dag_ir.get_node_meta(node_name)
meta.store_tensor = {"tensor": example}
def mark_output(self, node_name):
"""
Mark a store node as output
"""
meta = self.dag_ir.get_node_meta(node_name)
if not isinstance(meta, StoreNode):
raise ValueError(
f"Only StoreNodes can be marked as output. "
f"Got {type(meta).__name__}: {node_name}")
meta.is_output = True
# Add node with specific type
def add_load_node(self, name, example):
"""
Add a Load node to DAG IR
:param name: name of the loaded variable
:type name: str
:param example: example input
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
"""
if name is None:
raise ValueError(f"Name is not provided.")
if example is None:
raise ValueError(f"Example input for {name} is not provided.")
load_node = LoadNode(name)
load_node.tensor = {"tensor": example}
# Special logics for accumulator
if name == "accum":
if load_node.tensor.rank == 2:
new_shape = tuple([1, ] + list(load_node.tensor.shape))
load_node.tensor.broadcast(new_shape)
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
self.add_node(load_node)
def add_imm(self, value: Union[float,int]):
"""
Add an immediate scalar value to DAG IR
:param value: the value of the immediate scalar
:type value: float
"""
try:
value = float(value)
except:
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
self.imm_cnt += 1
load_node = LoadNode(name)
load_node.tensor = {"tensor": value, "is_constant": True}
self.add_node(load_node)
return name
def add_compute_node(self, op, name=None):
"""
Add a compute node.
:param op: the computation op
:param name: the node name (optional)
:type name: str
:return: the name of the compute node
"""
if name is None:
name = f"compute_{self.compute_cnt}"
self.compute_cnt += 1
compute_node = ComputeNode(
name=name, fn=op,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
return compute_node.name
def add_layout_node(self, op, kwargs, name=None):
"""
Add a layout node.
:param op: the layout op
:type op: evt_ops
:param name: the node name (optional)
:type name: str
:return: the name of the layout node
"""
if name is None:
name = f"layout_{self.layout_cnt}"
self.layout_cnt += 1
layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
self.add_node(layout_node)
return layout_node.name
def add_store_node(self, name):
store_node = StoreNode(name)
self.add_node(store_node)
#
    # Visualization of the DAG IR
#
def visualize(self, name="dag_ir"):
"""
        Visualize the DAG IR as an SVG file
:param name: the name of the graph
"""
drawer = EVTGraphDrawer(self.dag_ir, name)
try:
for name, graph in drawer.get_dot_graph():
graph.write_svg(f"./{name}.svg")
except:
raise RuntimeError(
"'dot' is not found in path. GraphDrawer is disabled. "
"Please install it with 'sudo apt-get install graphviz'."
)
#
# Get shared memory size
#
def get_smem_size(self, tile_description):
"""
Get the shared memory size of the epilogue
"""
smem_size = GetSmemSize(self.dag_ir)(tile_description)
return smem_size
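# Editorial sketch (hypothetical example tensors; FunctionalOp comes from
# cutlass_cppgen.backend.library): the helpers above can assemble `D = alpha * accum`
# by hand, without a parsing frontend.
#
#   frontend = EVTFrontendBase(cc=90)
#   frontend.add_load_node("accum", accum_example)          # rank-2 or rank-3 example
#   frontend.add_load_node("alpha", 1.0)
#   mult = frontend.add_compute_node(FunctionalOp.Multiplies)
#   frontend.add_edge("alpha", mult, weight=0)
#   frontend.add_edge("accum", mult, weight=1)
#   frontend.add_store_node("D")
#   frontend.add_edge(mult, "D")
#   frontend.set_store_tensor("D", d_example)
#   frontend.mark_output("D")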

View File

@ -0,0 +1,194 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python AST frontend that parses input into DAG IR
"""
import ast
import inspect
import textwrap
from cutlass_library import DataType
import cutlass_cppgen
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass_cppgen.backend.library import FunctionalOp
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
super().__init__(cc, element_compute, **kwargs)
# Flags
# If this state is True, visit_Constant returns values without creating imm node
self.no_imm = False
self.visiting_return = False
def parse(self, example_inputs):
self.example_inputs = example_inputs
self.source = textwrap.dedent(inspect.getsource(self.__call__))
self.ast = ast.parse(self.source)
self.visit(self.ast)
#
# Helper functions
#
@staticmethod
def ast_op_to_bindings(op):
mapping = {
ast.Add: FunctionalOp.Plus,
ast.Sub: FunctionalOp.Minus,
ast.Mult: FunctionalOp.Multiplies,
ast.Div: FunctionalOp.Divides,
"maximum": FunctionalOp.Maximum,
"minimum": FunctionalOp.Minimum,
"identity": identity.binding_type,
"relu": relu.binding_type,
"tanh": tanh.binding_type,
"sigmoid": sigmoid.binding_type,
"silu": silu.binding_type,
"hardswish": hardswish.binding_type,
"gelu": gelu.binding_type,
"multiply_add": FunctionalOp.MultiplyAdd,
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
"exp": FunctionalOp.Exp
}
return mapping[op]
#
# Visiting different node types
#
def visit_FunctionDef(self, node: ast.FunctionDef):
# Visit args and register load nodes
for arg in node.args.args:
self.visit(arg)
for expr in node.body:
self.visit(expr)
def visit_arg(self, node: ast.arg):
# Name of the argument
name = node.arg
try:
example_tensor = self.example_inputs[name]
except:
raise RuntimeError(f"Example input for {name} is not provided.")
self.add_load_node(name, example_tensor)
def visit_Name(self, node: ast.Name):
return node.id
def visit_Constant(self, node: ast.Constant):
if self.no_imm:
return node.value
else:
name = self.add_imm(node.value)
return name
def visit_Tuple(self, node: ast.Tuple):
results = []
for elt in node.elts:
results.append(self.visit(elt))
return tuple(results)
def visit_keyword(self, node: ast.keyword):
return {node.arg: self.visit(node.value)}
def visit_BinOp(self, node: ast.BinOp):
if self.visiting_return:
raise SyntaxError("Return value cannot be an expression")
lhs = self.visit(node.left)
rhs = self.visit(node.right)
op = self.ast_op_to_bindings(type(node.op))
name = self.add_compute_node(op)
# Add edges
# The edge weights are used to sort the input args
self.add_edge(lhs, name, weight=0)
self.add_edge(rhs, name, weight=1)
return name
    def visit_Assign(self, node: ast.Assign):
target = self.visit(node.targets[0])
value = self.visit(node.value)
# Create the assign node
self.add_store_node(target)
# Add edges
self.add_edge(value, target)
return target
def visit_Call(self, node: ast.Call):
if self.visiting_return:
raise SyntaxError("Return value cannot be an expression")
func = self.visit(node.func)
args = [self.visit(arg) for arg in node.args]
if func in self.layout_fns.keys():
# Parse kwargs
# By default, visiting imm automatically creates a load node
# However, in function call, keyword args are used to set
# specific function attributes such as indices for permute
# So no_imm is set to True temporarily
self.no_imm = True
kwargs = {}
for kw in node.keywords:
kwargs.update(self.visit(kw))
self.no_imm = False
op = self.layout_fns[func]
name = self.add_layout_node(op, kwargs)
else:
op = self.ast_op_to_bindings(func)
name = self.add_compute_node(op)
# Add edges
for idx, arg in enumerate(args):
self.add_edge(arg, name, weight=idx)
return name
def visit_Return(self, node: ast.Return):
self.visiting_return = True
results = self.visit(node.value)
self.visiting_return = False
self.return_names = results
if not isinstance(results, tuple):
results = (results,)
for rst in results:
try:
example_tensor = self.example_inputs[rst]
except:
raise RuntimeError(f"Example input for {rst} is not provided.")
self.set_store_tensor(rst, example_tensor)
self.mark_output(rst)
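# Editorial sketch (hypothetical subclass and example inputs, not part of the original
# source): the frontend parses the body of __call__, so an EVT can be written as
# ordinary Python and traced against example tensors.
#
#   class LinearCombination(PythonASTFrontend):
#       def __call__(self, accum, alpha, C, beta):
#           D = alpha * accum + beta * C
#           return D
#
#   evt = LinearCombination(cc=90)
#   evt.trace({"accum": accum_ex, "alpha": 1.0, "C": c_ex, "beta": 0.5, "D": d_ex})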

View File

@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
from cutlass_cppgen.backend.evt.ir.load_nodes import (
LoadNode,
AccumulatorImpl,
LoadSrcImpl,
AuxLoadImpl,
RowBroadcastImpl,
ColumnBroadcastImpl,
ScalarBroadcastImpl
)
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass_cppgen.backend.evt.ir.store_nodes import (
StoreNode,
StoreDImpl,
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)

View File

@ -0,0 +1,91 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python registration for compute nodes in EVT
"""
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
from cutlass_cppgen.backend.library import FloatRoundStyle
class ComputeImplBase(ImplBase):
"""
Base class for compute implementation
"""
def __init__(self, node) -> None:
super().__init__(node)
class ComputeImpl(ComputeImplBase):
"""
Implementation for Compute Node
"""
def __init__(self, node) -> None:
super().__init__(node)
self.fn = node.fn
self.element_output = node.element_output
self.element_compute = node.element_compute
self.round_style = node.round_style
@staticmethod
def match(node, problem_size: tuple):
return True
class ComputeNode(NodeBase):
"""
Compute Node in DAG IR
"""
possible_impls = [
ComputeImpl
]
def __init__(
self, name: str, fn, element_output,
element_compute,
round_style=FloatRoundStyle.ToNearest) -> None:
super().__init__(name)
self.op = "compute"
self.fn = fn
self.element_compute = element_compute
self.round_style = round_style
def type_propagation(self, *args, **kwargs):
"""
        Compute nodes operate in element_compute precision; element_output defaults to element_compute unless a later pass overrides it.
"""
self.element = self.element_compute
# In general, the compute nodes have element_output = element_compute
# In certain cases like producer of D it is overwritten by other passes
if not hasattr(self, "element_output"):
self.element_output = self.element

View File

@ -0,0 +1,254 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
DAG IR used by Python EVT
"""
import networkx as nx
from cutlass_library import DataType
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.library import ActivationOp
from cutlass_cppgen.backend.utils import device_cc
class DAGIR:
"""
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
    It consists of a series of ``Node`` s, each representing an epilogue visitor node.
    In the DAG IR, ``node`` is the string name of a node; ``node_meta`` is the underlying node class.
"""
def __init__(self, cc, element_compute=DataType.f32) -> None:
        # The EVT DAG IR is managed through the networkx DiGraph class
self._graph = nx.DiGraph()
self.element_compute = element_compute
self.reduction_names = []
self.cc = cc
self.identity_counter = 0
#
# IR manipulator
#
def add_node(self, meta: NodeBase):
"""
Add a node to dag ir
"""
if self.has_node(meta.name):
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
self._graph.add_node(meta.name, meta=meta)
def add_edge(self, src: str, dst: str, weight: int=0):
"""
Add an edge src -> dst to dag ir with weight
"""
if not self.has_node(src):
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
if self._graph.has_edge(src, dst):
# The DiGraph doesn't support multiple edges between two nodes
# We insert an identity node in such case as a workaround
identity_name = f"autogen_identity_{self.identity_counter}"
self.identity_counter += 1
compute_node = ComputeNode(
name=identity_name, fn=ActivationOp.Identity,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
self.add_edge(src, identity_name, 0)
self.add_edge(identity_name, dst, weight)
else:
self._graph.add_edge(src, dst, weight=weight)
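        # Editorial note: an expression such as `x * x` requests two parallel edges
        # x -> mult; the second request above is rerouted as
        # x -> autogen_identity_N -> mult so the underlying DiGraph stays simple.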
def remove_node(self, node: str):
"""
Remove node from dag ir
"""
self._graph.remove_node(node)
def remove_edge(self, src: str, dst: str):
"""
Remove edge src -> dst
"""
self._graph.remove_edge(src, dst)
#
# Helper functions for getting attrs
#
def has_node(self, node: str) -> bool:
"""
Check if the node is in the graph
"""
return self._graph.has_node(node)
def in_degree(self, node: str):
"""
Get the input degree of node
"""
return self._graph.in_degree(node)
def in_edges(self, node: str):
"""
Get the input edges of node
"""
return [edge for edge in self._graph.in_edges(node)]
def out_degree(self, node: str):
"""
Get the output degree of node
"""
return self._graph.out_degree(node)
def out_edges(self, node: str):
"""
Get the output edges of node
"""
return [edge for edge in self._graph.out_edges(node)]
def get_node_meta(self, node: str):
"""
Get the meta data of the node
"""
return self._graph.nodes[node]["meta"]
def get_edge_weight(self, src, dst):
"""
Get the edge weight of edge src->dst
"""
return self._graph.get_edge_data(src, dst)["weight"]
#
# High-level helper functions
#
def all_reachable_nodes(self, node: str):
"""
        Get all the nodes reachable from the given node (the source node itself is included)
"""
return list(nx.dfs_preorder_nodes(self._graph, source=node))
def get_users(self, node: str):
"""
Get all users of the current node
"""
return [edge[1] for edge in self.out_edges(node)]
def get_all_inputs(self, node: str):
"""
Get all the input nodes sorted by edge weight
"""
in_edges = self.in_edges(node)
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
def get_all_inputs_meta(self, node: str):
"""
Get all the input node metas sorted by edge weight
"""
return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]
def replace_all_uses_with(self, node1, node2):
"""
Replace all uses of node1 with node2
"""
for edge in self.out_edges(node1):
weight = self.get_edge_weight(*edge)
user = edge[1]
self.add_edge(node2, user, weight)
self.remove_edge(node1, user)
self.remove_node(node1)
#
# Node accessor
#
def nodes_topological_order(self):
"""
Get the nodes in the unique lexicographical topological order
It generates a unique ordering of nodes by first sorting topologically
and then additionally by sorting lexicographically.
Although topological_sort alone also works, this generates a unique key
for each epilogue visitor pattern and ensures the compilation cache can be reused.
:return: list[str]
"""
return list(nx.lexicographical_topological_sort(self._graph))
def node_metas_topological_order(self):
"""
Get the node metas in topological order
:return: list[NodeBase]
"""
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
@property
def nodes(self):
"""
Get all nodes
:return: list[str]
"""
return list(self._graph.nodes)
@property
def nodes_meta(self):
"""
Get all node metas
:return: list[NodeBase]
"""
return [data[1]['meta'] for data in self._graph.nodes.data()]
@property
def edges(self):
"""
Get all edges
:return: list[(str, str)]
"""
return list(self._graph.edges)
#
# Path
#
def has_path(self, src: str, target: str) -> bool:
"""
        Return True if a path exists from src to target
"""
return nx.has_path(self._graph, src, target)

View File

@ -0,0 +1,324 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout algebras
"""
from pycute import Layout, composition, make_layout, flatten, product
def _infer_split(old_shape, new_shape):
old_shape = _tuple_to_list(old_shape)
new_shape = _tuple_to_list(new_shape)
if len(old_shape) == 0 and len(new_shape) == 0:
return []
if len(old_shape) == 0:
if product(tuple(new_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return new_shape
if len(new_shape) == 0:
if product(tuple(old_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return old_shape
    # This is done recursively by processing only the last dimension at a time
old_dim = old_shape[-1]
new_dim = new_shape[-1]
# Exact match
if old_dim == new_dim:
return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
def _infer_merge(flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = 0
merged_shape = []
for dim in shape:
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat += 1
# Need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while(residual > 1):
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat += 1
merged_shape.append(group)
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape
def _list_to_tuple(nested_list):
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
return tuple(_list_to_tuple(item) for item in nested_list)
return nested_list
def _tuple_to_list(nested_tuple):
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
return list(_tuple_to_list(item) for item in nested_tuple)
return nested_tuple
def _reverse_tuple(nested_tuple: tuple):
if isinstance(nested_tuple, tuple):
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
return nested_tuple
def _get_first_lhs_nonzero_stride(stride_list, idx):
    # Return the nearest nonzero stride to the left of idx, or None if there is none
    for i in reversed(range(idx)):
        if stride_list[i] != 0:
            return stride_list[i]
    return None
def _get_first_rhs_nonzero_stride(stride_list, idx):
    # Return the nearest nonzero stride to the right of idx, or None if there is none
    for i in range(idx + 1, len(stride_list)):
        if stride_list[i] != 0:
            return stride_list[i]
    return None
def reshape(layout, new_shape):
"""
General reshape of input layout.
It takes two steps:
1. split the dimensions of the old layout
    2. merge the split dimensions according to the new shape
"""
#
# Step 1: Split the dimensions of the old layout
#
# 1.1 Flat old and new shape
old_flatten_shape = list(flatten(layout.shape))
new_flatten_shape = list(flatten(new_shape))
# 1.2 Infer the flatten splitted shape
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
# 1.3 Unflat the splitted shape based on the old shape
splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
# 1.4 Infer the type of each split
# If the split type is in row-major (R), the dimension list is reversed because
# the cute::composition only support column-major split
split_type = [] # the type of each split (ColumnMajor or RowMajor)
permuted_splitted_shape = []
old_flatten_stride = list(flatten(layout.stride))
for idx, dim in enumerate(splited_shape):
if not isinstance(dim, list):
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
# Special case for single tuple
# Use column-major by default
if lhs_stride is None and rhs_stride is None:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
if lhs_stride is not None and rhs_stride is not None:
# We consider shape[idx]:stride[idx]
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
raise NotImplementedError()
elif lhs_stride is None:
# Case 1: dim's stride < dim+1's stride, expand in column major
if old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
# Case 1: dim's stride > dim-1's stride
if old_flatten_stride[idx] < lhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
# 1.4 Generate the splitted layout
permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
# 1.5 Reverse the permutation in 1.4 before merge
splitted_shape = []
splitted_stride = []
for shape_dim, stride_dim, type in zip(
permuted_splitted_layout.shape,
permuted_splitted_layout.stride,
split_type):
if type == "C":
splitted_shape.append(shape_dim)
splitted_stride.append(stride_dim)
else:
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
#
# Step 2: Merge the splitted dimensions according to the new shape
#
# 2.1 Merge layout
merged_layout = composition(splitted_layout, Layout(new_shape))
# 2.2 Cleaning up
output_layout = composition(merged_layout, Layout(new_shape))
return output_layout
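# Editorial example (assumes pycute's default compact column-major strides):
#   reshape(Layout((4, 6)), (4, 2, 3))   # (4, 6):(1, 4)  ->  (4, 2, 3):(1, 4, 8)
# The trailing mode of extent 6 is split into (2, 3) without moving any data.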
def permutation(layout, permutation):
"""
Permute the layout
"""
new_shape = tuple([layout.shape[idx] for idx in permutation])
new_stride = tuple([layout.stride[idx] for idx in permutation])
return Layout(new_shape, new_stride)
def _broadcast(layout, new_shape):
if len(layout) == 1 and isinstance(new_shape, int):
old_dim = layout.shape
old_stride = layout.stride
new_dim = new_shape
if old_dim == new_dim:
return Layout(old_dim, old_stride)
elif old_dim == 1:
return Layout(new_dim, 0)
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
# Align the dimensions
old_shape = layout.shape
if isinstance(old_shape, int):
old_shape = (old_shape,)
sub_layouts = [layout,]
else:
sub_layouts = [sub_layout for sub_layout in layout]
rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
# Get the broadcasted layout
broadcast_layouts = []
try:
layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
broadcast_layouts = []
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    except NotImplementedError:
        layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
        # Discard any partial results appended before falling back to this ordering
        broadcast_layouts = []
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
return make_layout(*broadcast_layouts)
def broadcast(layout, new_shape):
"""
    Broadcast the input layout to the new shape.
    The broadcast shape equals the new shape; broadcast dimensions are given stride 0.
"""
return _broadcast(layout, new_shape)
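# Editorial example: broadcasting a rank-1 layout of extent 3 to shape (4, 3)
# prepends a stride-0 mode, so every row reads the same three elements:
#   broadcast(Layout(3), (4, 3))   # 3:1  ->  (4, 3):(0, 1)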
def debroadcast(layout, dims):
"""
Squeeze the 0-stride
"""
for dim in dims:
if layout.stride[dim] != 0:
raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
return Layout(new_shape, new_stride)
def canonicalization_(shapes, strides):
if isinstance(shapes, tuple):
c_shapes = []
c_strides = []
for shape, stride in zip(shapes, strides):
c_shape, c_stride = canonicalization_(shape, stride)
c_shapes.append(c_shape)
c_strides.append(c_stride)
return tuple(c_shapes), tuple(c_strides)
else:
if shapes == 1:
return 1, 0
else:
return shapes, strides
def canonicalization(layout):
"""
Canonicalize the input layout
    1. set the stride of every extent-1 dimension to 0
"""
new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
return Layout(new_shape, new_stride)

View File

@ -0,0 +1,336 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout manipulation nodes and implementations
The layout nodes change the layouts of intermediate tensors in the epilogue visitor graph
"""
from copy import deepcopy
from cutlass_library import LayoutType
from pycute import product, flatten
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class PermutationImpl:
"""
Detailed implementation and helper functions for permutation
"""
def __init__(self, node) -> None:
assert "indices" in node.kwargs.keys()
self.indices = list(node.kwargs["indices"])
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.indices = self.inverse_indices
inverse_impl.inverse_indices = self.indices
return inverse_impl
def update(self, shape):
num_dim = len(shape)
indices = self.indices
num_old_dim = len(indices)
# Add offset
for i, idx in enumerate(indices):
indices[i] = idx + num_dim - num_old_dim
# Add broadcast dims
for i in range(num_dim - num_old_dim):
indices = [i,] + indices
self.indices = indices
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_indices(self, indices):
"""
Get the indices for inverse permutation
"""
num_dim = len(indices)
inverse_indices = [0] * num_dim
for i in range(num_dim):
inverse_indices[indices[i]] = i
return inverse_indices
def shape_propagation(self, input_node_meta):
input_shape = input_node_meta.tensor.shape
output_shape = tuple([input_shape[idx] for idx in self.indices])
return output_shape
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape
"""
self.update(shape)
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
node_meta.tensor.broadcast(inverse_shape)
def apply_to_user(self, usr_meta: NodeBase):
"""
        Propagate the permutation to the users of the current node
"""
usr_meta.tensor.permute(self.inverse_indices)
if hasattr(usr_meta, "store_tensor"):
if usr_meta.store_tensor is not None:
usr_meta.store_tensor.permute(self.inverse_indices)
def apply_to_input(self, input_meta: NodeBase):
"""
        Propagate the permutation to the inputs of the current node
"""
input_meta.tensor.permute(self.indices)
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.permute(self.indices)
class ReshapeImpl:
"""
Detailed implementation and helper functions for reshape
"""
def __init__(self, node) -> None:
self.node = node
assert "new_shape" in node.kwargs.keys()
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.output_shape = self.input_shape
inverse_impl.input_shape = self.output_shape
return inverse_impl
def shape_propagation(self, input_node_meta):
self.input_shape = input_node_meta.tensor.shape
return _list_to_tuple(self.output_shape)
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape.
"""
# Step 1: infer split
flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
# broadcast shape -> split_output_shape -> flatten_split_shape
if len(shape) - len(split_output_shape) > 0:
for _ in range(len(shape) - len(split_output_shape)):
split_output_shape = [1,] + split_output_shape
flatten_split_shape = [1,] + flatten_split_shape
split_input_shape = [1,] + split_input_shape
broadcast_factor = []
for dim, old_dim in zip(shape, split_output_shape):
if not isinstance(dim, list):
dim = [dim,]
if not isinstance(old_dim, list):
old_dim = [old_dim,]
if product(tuple(dim)) == product(tuple(old_dim)):
broadcast_factor += [1] * len(old_dim)
elif product(tuple(old_dim)) == 1:
assert len(dim) == 1
broadcast_factor.append(dim[0])
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
# flatten_split_shape -> split_input_shape
factor_idx = 0
broadcast_split_input_shape = []
for dim in split_input_shape:
if isinstance(dim, list):
new_dim = []
for d in dim:
new_dim.append(d * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape.append(new_dim)
else:
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
node_meta.tensor.broadcast(broadcast_split_input_shape)
# Last reshape op to clean up
broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
node_meta.tensor.reshape(broadcast_input_shape)
# Update the input shape and output shape
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
self.output_shape = _list_to_tuple(shape)
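# Illustrative sketch (hypothetical shapes): for a reshape of (6,) -> (2, 3) whose
# output is broadcast against (4, 2, 3), infer_split yields [2, 3], the input side is
# regrouped as ((2, 3),) and padded to (1, (2, 3)), the broadcast factor is [4, 1, 1],
# the tensor is reshaped to (1, (2, 3)), broadcast to (4, (2, 3)), and finally
# flattened back to the input-side shape (4, 6).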
def apply_to_user(self, user_meta: NodeBase):
"""
Propagate the reshape to user nodes
"""
user_meta.tensor.reshape(tuple(self.input_shape))
if hasattr(user_meta, "store_tensor"):
if user_meta.store_tensor is not None:
user_meta.store_tensor.reshape(tuple(self.input_shape))
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the reshape to input nodes
"""
input_meta.tensor.reshape(tuple(self.output_shape))
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.reshape(tuple(self.output_shape))
#
# Helper functions
#
def infer_split(self, input_shape, output_shape):
"""
Infer the flattened split shape that can be merged into both input_shape and output_shape
"""
input_shape = _tuple_to_list(input_shape)
output_shape = _tuple_to_list(output_shape)
if len(input_shape) == 0 and len(output_shape) == 0:
return []
if len(input_shape) == 0:
if product(tuple(output_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return output_shape
if len(output_shape) == 0:
if product(tuple(input_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return input_shape
# This is done recursively by processing only the last dimension at a time
old_dim = input_shape[-1]
new_dim = output_shape[-1]
# Exact match
if old_dim == new_dim:
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
def infer_merge(self, flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = len(flatten_shape) - 1
merged_shape = []
for dim in reversed(shape):
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat -= 1
# Needs grouping
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while(residual > 1):
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat -= 1
merged_shape.append(group[::-1])
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape[::-1]
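# Illustrative sketch (hypothetical shapes): for a reshape of (6, 4) -> (2, 12),
# infer_split returns the flat split [2, 3, 4]; infer_merge then regroups it as
# [[2, 3], 4] against (6, 4) and as [2, [3, 4]] against (2, 12).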
class LayoutNode(NodeBase):
"""
Layout manipulation nodes
"""
fn_to_impl = {
"permute": PermutationImpl,
"reshape": ReshapeImpl
}
def __init__(self, name: str, fn, kwargs: dict) -> None:
super().__init__(name)
self.op = "layout"
self.fn = fn
self.kwargs = kwargs
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
def get_inverse_node(self):
inverse_node = deepcopy(self)
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
return inverse_node
def shape_propagation(self, input_node_metas):
if self._tensor is not None:
return
assert len(input_node_metas) == 1, "Layout node can only have one input node"
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
self._tensor = Tensor(
element=self.element_output,
shape=output_shape, layout_tag=LayoutType.RowMajor
)
return super().shape_propagation(input_node_metas)
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
The layout node has element_output = element_input
"""
assert len(input_node_metas) == 1, "Layout node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Propagate the broadcast in the reversed topological order
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
shape = self.tensor.shape
for child in input_node_metas:
self.underlying_impl.broadcast(shape, child)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the permutation to user nodes
"""
self.underlying_impl.apply_to_user(usr_meta)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the permutation to input nodes
"""
self.underlying_impl.apply_to_input(input_meta)

View File

@ -0,0 +1,294 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Load nodes and implementations
"""
import ctypes
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
class LoadImplBase(ImplBase):
"""
Base class for load node implementations
"""
reserved_names = ["accum", "C"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.tensor.stride
class AccumulatorImpl(LoadImplBase):
"""
Accumulator node implementation
"""
@staticmethod
def match(node, problem_size: tuple):
return node.name == "accum" and node.tensor.shape == problem_size
class LoadSrcImpl(LoadImplBase):
"""
Load C implementation
"""
@property
def name_camel(self) -> str:
return "TensorC"
@property
def argument_type_c(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_C", ctypes.c_void_p),
("stride_C", tuple_type)
]
def __init__(self, ptr) -> None:
self.ptr_C = ptr
self.stride_C = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
return node.name == "C" and node.tensor.shape == problem_size
class AuxLoadImpl(LoadImplBase):
"""
Load arbitrary tensor
"""
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dAux", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.null_default = to_ctype_value(0, element_type)
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if (strideMN[0] == 1 and strideMN[1] != 0 or
strideMN[0] != 0 and strideMN[1] == 1 ):
return True
else:
return False
class RowBroadcastImpl(LoadImplBase):
"""
Broadcast a row vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_row", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dRow", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_row = ptr
self.null_default = to_ctype_value(0, element_type)
self.dRow = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ColumnBroadcastImpl(LoadImplBase):
"""
Broadcast a column vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_col", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dCol", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_col = int(ptr)
self.null_default = to_ctype_value(0, element_type)
self.dCol = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class ScalarBroadcastImpl(LoadImplBase):
"""
Broadcast a scalar
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
if self.tensor.is_constant:
value = self.tensor.value
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type)
]
def __init__(self, kwargs) -> None:
self.scalars = to_ctype_value(value, element_type)
self.scalar_ptrs = 0
self.dScalar = tuple_type(stride_mnl)
else:
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type)
]
def __init__(self, kwargs) -> None:
scalar_or_ptr = kwargs[name]
if isinstance(scalar_or_ptr, float):
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
self.scalar_ptrs = 0
else:
self.scalar_ptrs = int(scalar_or_ptr)
self.dScalar = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class LoadNode(NodeBase):
"""
Load Node
"""
cnt = 0
possible_impls = [
AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
RowBroadcastImpl, ColumnBroadcastImpl,
ScalarBroadcastImpl
]
def __init__(self, name: str) -> None:
if name is None:
name = f"load{LoadNode.cnt}"
LoadNode.cnt += 1
super().__init__(name)
self.op = "load"
def type_propagation(self, *args, **kwargs):
"""
The load node loads the tensor with type `tensor.element` and returns an array of type `tensor.element`.
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
self.element = self.tensor.element
self.element_output = self.tensor.element

View File

@ -0,0 +1,306 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base & visitor classes of DAGIR Nodes
"""
import ctypes
from re import sub
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class TupleEmitter:
"""
Emit the cute tuple to C++ code
"""
def __init__(self, stride_dtype):
self.stride_dtype = stride_dtype
def emit(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple in [0, 1]:
return f"cute::Int<{py_tuple}>"
else:
return f"{self.stride_dtype}"
elif isinstance(py_tuple, tuple):
decl = "cute::Stride<"
for item in py_tuple:
decl += self.emit(item) + ", "
return decl[:-2] + ">"
else:
raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
class ImplBase:
"""
Base class for Node Implementation
"""
def __init__(self, node) -> None:
self.node = node
self.name = node.name
self.tensor = node.tensor
self._type_decl = None
self.tuple_emitter = TupleEmitter("int64_t")
@property
def stride_dtype(self):
return self.tuple_emitter.stride_dtype
@stride_dtype.setter
def stride_dtype(self, stride_dtype):
self.tuple_emitter.stride_dtype = stride_dtype
@staticmethod
def match(node, problem_size: tuple):
"""
Match function used in get_underlying_impl
"""
raise NotImplementedError(f"The `match` function is not defined.")
@property
def argument_type(self):
"""
Default class for Argument Type
"""
class _Argument(ctypes.Structure):
_fields_ = []
def __init__(self, *args, **kwargs) -> None:
pass
return _Argument
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
@property
def stride_mnl(self):
"""
Typename StrideMNL
"""
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
return self.tuple_emitter.emit(stride)
def get_non_constant_stride(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple not in [0, 1]:
return py_tuple
else:
return None
non_constant_stride = []
for item in py_tuple:
item_out = self.get_non_constant_stride(item)
if item_out:
non_constant_stride.append(item_out)
return tuple(non_constant_stride)
def get_stride_mnl(self):
"""
Get the non-zero stride mnl. This is used in argument construction
"""
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
return stride
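# Illustrative sketch (hypothetical row-vector stride): for self.stride == (0, 0, 1)
# in (L, M, N) order, get_stride_mnl reorders it to the MNL tuple (0, 1, 0), and
# stride_mnl emits "cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>".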
def get_smem_size(self, *args, **kwargs):
"""
Get the shared memory size and alignment of current node
"""
return (0, 1)
class NoOpImpl(ImplBase):
"""
The NoOpImpl does nothing but forward its input to users
"""
def __init__(self, node) -> None:
super().__init__(node)
@staticmethod
def match(node, problem_size: tuple):
if node.op == "store":
# A store node that is not an output is a no-op
return not node.is_output
return False
class NodeBase:
"""
Base class of DAG Node
"""
def __init__(self, name: str) -> None:
self.name = name
self.underlying_impl = None
self._tensor = None
# Whether the node is disabled for emit
self.disabled = False
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return self.underlying_impl.name_camel
@property
def tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
"""
return self._tensor
@tensor.setter
def tensor(self, kwargs):
"""
Setting the tensor
"""
self._tensor = Tensor(**kwargs)
#
# Helper functions for type/shape propagation
#
def shape_propagation(self, input_node_metas):
"""
Infer shape from input nodes
General Broadcasting Rules from NumPy
When operating on two arrays, we compare their shapes element-wise.
It starts with the trailing (i.e. rightmost) dimension and works its
way left. Two dimensions are compatible when
1. they are equal
2. one of them is 1
"""
if self._tensor is not None:
return
shape = None
for src in input_node_metas:
src_shape = src.tensor.shape
if shape is None:
shape = src_shape
else:
len_difference = len(shape) - len(src_shape)
if len_difference > 0:
for _ in range(len_difference):
src_shape = [1, ] + list(src_shape)
elif len_difference < 0:
for _ in range(-len_difference):
shape = [1, ] + list(shape)
broadcasted_shape = []
# Infer broadcast shape
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
if shape_dim == 1:
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
elif src_dim == 1:
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
elif shape_dim == src_dim:
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
else:
error_msg = "Dimension mismatch between "
for src_ in input_node_metas:
error_msg += f"{src_.name}{src_.tensor.shape}, "
error_msg = error_msg[:-2] + "."
raise RuntimeError(error_msg)
shape = tuple(broadcasted_shape)
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
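# Illustrative sketch (hypothetical shapes): broadcasting inputs of shape (3, 1, 4)
# and (5, 4) first left-pads the shorter one to (1, 5, 4), then compares dimensions
# right to left, so the inferred output shape is (3, 5, 4).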
def type_propagation(self, *args, **kwargs):
"""
Each node is associated with two data types: `element` and `element_output`.
The `element_output` is the type of return array of the node. The `element`
has specific meaning for different node types.
* Load Node: data type of tensor in gmem
* Compute Node: element compute
* Store Node: data type of tensor in gmem
This function must be overloaded in the derived classes
"""
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Propagate the broadcast in the reversed topological order.
For example:
C[l, m, n] = A[m, 1] + B[l, m, n]
After the broadcast propagation, it will become
C[l, m, n] = A[l, m, n] + B[l, m, n]
and each tensor will have a proper stride accessing the underlying tensor
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
for child in input_node_metas:
child.tensor.broadcast(self.tensor.shape)
def get_underlying_impl(self, problem_size: tuple):
"""
Get the underlying implementation of the current node.
"""
if self.tensor is None:
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
for impl in self.possible_impls:
if impl.match(self, problem_size):
self.underlying_impl = impl(self)
break
if self.underlying_impl is None:
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
#
# Visitor Nodes & Impls
#
class TopoVisitorImpl(ImplBase):
"""
Impl for topological visitor
"""
def __init__(self, node) -> None:
super().__init__(node.output_node)
self.name = node.name
self.element_output = node.output_node.element_output
class TopoVisitorNode(NodeBase):
def __init__(self, name: str, subgraph, output_node) -> None:
super().__init__(name)
self.subgraph = subgraph
self.output_node = output_node
self.op = "dag"
self.underlying_impl = TopoVisitorImpl(self)

View File

@ -0,0 +1,277 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Store node and implementations
"""
import ctypes
from cutlass_library import DataType
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
class StoreImplBase(ImplBase):
"""
Base class for store node implementation
"""
reserved_names = ["D"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.store_tensor.stride
class StoreDImpl(StoreImplBase):
"""
Store D implementation
"""
@property
def argument_type_d(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_D", ctypes.c_void_p),
("stride_D", tuple_type)
]
def __init__(self, ptr: int) -> None:
self.ptr_D = ptr
self.stride_D = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name == "D" and node.store_tensor.shape == problem_size:
return True
return False
class AuxStoreImpl(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.round_style = FloatRoundStyle.ToNearest
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("dAux", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if (strideMN[0] == 1 and strideMN[1] != 0 or
strideMN[0] != 0 and strideMN[1] == 1 ):
return True
else:
return False
class ReductionImplBase(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.store_tensor.element
self.element_compute = node.element_compute
self.reg_reduce_fn = self.node.reg_reduce_fn
self.gmem_reduce_fn = self.node.gmem_reduce_fn
self.round_style = node.round_style
self.stride_dtype = "int"
def get_reduce_identity(self):
"""
Return the reduction identity of the current reduce_fn
"""
maxes = {
DataType.f32: (2 ** 31) - 1,
DataType.f16: (2 ** 15),
DataType.s32: (2 ** 31) - 1,
DataType.s8: (2 ** 7) - 1
}
mins = {
DataType.f32: -maxes[DataType.f32],
DataType.f16: -maxes[DataType.f16],
DataType.s32: -maxes[DataType.s32],
DataType.s8: -maxes[DataType.s8]
}
if self.reg_reduce_fn == FunctionalOp.Maximum:
if self.element_compute not in mins:
raise Exception(f"No min entry for data type {self.element_compute}")
return to_ctype_value(mins[self.element_compute], self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
return to_ctype_value(1., self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Minimum:
if self.element_compute not in maxes:
raise Exception(f"No max entry for data type {self.element_compute}")
return to_ctype_value(maxes[self.element_compute], self.element_compute)
else:
return to_ctype_value(0., self.element_compute)
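# Illustrative note: for FunctionalOp.Maximum over f32 the identity above is
# -(2**31 - 1), for Multiplies it is 1.0, for Minimum it is 2**31 - 1, and the
# default (sum-like) case uses 0.0.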
@property
def argument_type(self):
self.get_reduce_identity()
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_compute = self.element_compute
reduce_identity = self.get_reduce_identity()
class _Argument(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("reduce_identity", dtype2ctype[element_compute]),
("dMNL", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr = ptr
self.reduce_identity = reduce_identity
self.dMNL = tuple_type(stride_mnl)
return _Argument
class ColumnReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class RowReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ScalarReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class StoreNode(NodeBase):
"""
Store node
"""
possible_impls = [
AuxStoreImpl, RowReductionImpl,
ColumnReductionImpl, ScalarReductionImpl,
NoOpImpl, StoreDImpl
]
def __init__(self, name: str) -> None:
super().__init__(name)
self.op = "store"
self.is_output = False
self._store_tensor = None
@property
def store_tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
"""
return self._store_tensor
@store_tensor.setter
def store_tensor(self, kwargs):
"""
Setting the tensor
"""
self._store_tensor = Tensor(**kwargs)
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
The store node has element_output = element_input
"""
if self.is_output:
if self.store_tensor is None:
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
self.element = self.store_tensor.element
assert len(input_node_metas) == 1, "Store node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
super().broadcast_propagation(input_node_metas)
if self.is_output:
self._store_tensor.broadcast(self.tensor.shape)

View File

@ -0,0 +1,137 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level class for tensor
"""
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
permutation,
reshape,
_reverse_tuple
)
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
class Tensor:
"""
The tensor abstracts the element data type and the layout (shape and stride) of an operand
"""
def __init__(self, tensor=None, element=None, shape=None, stride=None, layout_tag=None, is_constant=False) -> None:
if element is not None and tensor is not None:
raise Exception(f"Must not specify both element and tensor")
elif shape is not None and tensor is not None:
raise Exception(f"Must not specify both shape and tensor")
elif layout_tag is not None and tensor is not None:
raise Exception(f"Must not specify both layout_tag and tensor")
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
elif stride is not None and tensor is not None:
raise Exception(f"Must not specify both stride and tensor")
elif stride is not None and layout_tag is not None:
raise Exception(f"Must not specify layout_tag when stride is provided")
if isinstance(tensor, Tensor):
# Directly copy all the attributes
self.__dict__.update(vars(tensor))
else:
if tensor is None:
self.element = library_type(element)
else:
self.element, layout_tag = get_datatype_and_layout(tensor)
shape = get_tensor_shape(tensor)
if stride is not None:
self.layout = Layout(shape[::-1], stride[::-1])
else:
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
self.layout = canonicalization(self.layout)
self.is_constant = is_constant
# Save the tensor value if it is constant
if is_constant and tensor is not None:
self.value = tensor
@property
def shape(self):
"""
Returns the RowMajor layout shape
"""
return _reverse_tuple(self.layout.shape)
@property
def stride(self):
"""
Returns the RowMajor layout stride
"""
return _reverse_tuple(self.layout.stride)
@property
def rank(self):
"""
Returns the rank of the tensor
"""
return len(self.shape)
#
# Layout Algorithms
#
def broadcast(self, shape):
"""
Broadcast self.layout to shape
"""
assert isinstance(shape, tuple)
self.layout = broadcast(self.layout, _reverse_tuple(shape))
def reshape(self, shape):
"""
Reshape self.layout to shape
"""
assert isinstance(shape, tuple)
reverse_shape = _reverse_tuple(shape)
self.layout = reshape(self.layout, reverse_shape)
def permute(self, indices):
"""
Permute self.layout according to indices
"""
length = len(indices)
indices = [length - idx - 1 for idx in indices]
self.layout = permutation(self.layout, indices[::-1])
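# Illustrative note: the layout is stored in reversed (cute-style) order, so a
# RowMajor tensor constructed with shape=(2, 3, 4) keeps Layout((4, 3, 2)) internally,
# while the shape/stride properties reverse it back to (2, 3, 4) order for callers.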

View File

@ -0,0 +1,42 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize

View File

@ -0,0 +1,143 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
import subprocess
from cutlass_library import DataTypeTag
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
_COLOR_MAP = {
"load": '"AliceBlue"',
"compute": "LemonChiffon1",
"accumulator": "LightGrey",
"store": "PowderBlue",
"layout": "lightseagreen",
"dag": "darkorange"
}
class EVTGraphDrawer:
"""
Visualize an EVT DAGIR with graphviz
"""
def __init__(
self,
graph: DAGIR,
name: str
):
self._name = name
self._dot_graphs = {}
self._dot_graphs[name] = self._to_dot(graph, name)
def _get_node_style(self, node):
template = {
"shape": "record",
"fillcolor": "#CAFFE3",
"style": '"filled,rounded"',
"fontcolor": "#000000",
}
if node.op in _COLOR_MAP:
template["fillcolor"] = _COLOR_MAP[node.op]
else:
raise NotImplementedError("unknown node op")
if node.disabled:
template["fontcolor"] = "grey"
template["fillcolor"] = "white"
return template
def _get_node_label(self, node):
label = "{" + f"name={node.name}|op={node.op}"
if node.op == "layout":
label += f"|fn={node.fn.__name__}"
for key in node.kwargs:
label += f"|{key}={node.kwargs[key]}"
if node.underlying_impl is not None:
label += f"|impl={type(node.underlying_impl).__name__}"
if node.op == "load":
label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
elif node.op == "compute":
label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
elif node.op == "store":
label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
elif node.op == "dag":
label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
if node.tensor is not None:
shape = node.tensor.shape
stride = node.tensor.stride
label += f"|shape={shape}|stride={stride}"
if hasattr(node, "store_tensor"):
if node.store_tensor is not None:
store_shape = node.store_tensor.shape
store_stride = node.store_tensor.stride
label += f"|store_shape={store_shape}|stride_stride={store_stride}"
label += "}"
return label
def _to_dot(
self,
graph: DAGIR,
name: str
):
import pydot
dot_graph = pydot.Dot(name, randir="TB")
for node in graph.nodes_meta:
style = self._get_node_style(node)
label = self._get_node_label(node)
dot_node = pydot.Node(
node.name, label=label, **style
)
dot_graph.add_node(dot_node)
if node.op == "dag":
dot_subgraph = self._to_dot(node.subgraph, name=node.name)
self._dot_graphs[node.name] = dot_subgraph
# Add edges
for src, dst in graph.edges:
weight = graph.get_edge_weight(src, dst)
dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
return dot_graph
def get_dot_graph(self) -> list[tuple[str, pydot.Dot]]:
return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]
def get_dot_graph_by_name(self, name) -> pydot.Dot:
return self._dot_graphs[name]
def get_main_dot_graph(self) -> pydot.Dot:
return self._dot_graphs[self._name]
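# Hypothetical usage sketch (assumes an already constructed DAGIR `dag_ir` and an
# installed graphviz backend):
#   drawer = EVTGraphDrawer(dag_ir, "epilogue")
#   for name, dot in drawer.get_dot_graph():
#       dot.write_png(f"{name}.png")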

View File

@ -0,0 +1,120 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Construct the epilogue visitor argument type
"""
from cutlass_cppgen.backend.c_types import visitor_factory
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.util import cc_map
class PassGetArgumentType(EVTPassBase):
"""
Construct the epilogue visitor argument type
"""
dependencies = [
PassShapeTypePropagation, # The Layout of all nodes must be set
PassDAG2Tree, # The type of each node must be set
PassGetImpl # The DAG subgraphs must be set
]
def requires(self) -> None:
# Check "D" is in the node list
if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
raise SyntaxError(
"Sm90+ EVT requires the epilogue to have a returned tensor D, "
"but the variable 'D' is not found in the return values.")
def call(self):
nodes = self.dag_ir.nodes_topological_order()
self.argument_types = {}
for node in nodes:
meta = self.dag_ir.get_node_meta(node)
if not meta.disabled:
self.argument_types[node] = meta.underlying_impl.argument_type
if node == "D" and cc_map[self.cc] in [90, 100]:
continue
if isinstance(meta, TopoVisitorNode):
self.get_dag_argument_type(node)
else:
self.get_evt_argument_type(node)
self.cc_specific_method(self.set_argument_type)()
def get_evt_argument_type(self, node):
# Sort the input nodes by edge weight
input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
if len(input_types) > 0:
self.argument_types[node] = visitor_factory(
input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])
def get_dag_argument_type(self, node):
meta = self.dag_ir.get_node_meta(node)
subgraph = meta.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Visit the unvisited nodes in subgraph
for n in subgraph_nodes:
m = subgraph.get_node_meta(n)
if m.disabled:
continue
else:
self.argument_types[n] = m.underlying_impl.argument_type
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
if len(input_types) > 0:
self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])
def set_argument_type(self):
pass
def sm90_set_argument_type(self):
self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
# Get the tensorD argument type
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
# Get the tensorC argument type
if self.dag_ir.has_node("C"):
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
else:
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
def sm100_set_argument_type(self):
self.sm90_set_argument_type()
def sm80_set_argument_type(self):
nodes = self.dag_ir.nodes_topological_order()
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]

View File

@ -0,0 +1,169 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
"""
from copy import deepcopy
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassDAG2Tree(EVTPassBase):
"""
Convert the DAG IR to Tree by fusing subgraphs
"""
dependencies = [
PassShapeTypePropagation,
PassGetImpl
]
def call(self):
# Step 1: find the nodes that have multiple parents
multi_parent_nodes = []
for node in self.dag_ir.nodes_topological_order():
if self.dag_ir.out_degree(node) > 1:
multi_parent_nodes.append(node)
# Step 2: find the lowest common ancestor (LCA) of all its parents
for node in multi_parent_nodes:
# A multi-parent node could be already fused by the previous node
if not self.dag_ir.has_node(node):
continue
# A node uncovered by the previous fusions can have out degree change
# Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
# Case 2: it has more than one edges to the previously fused subgraph, degree drops
if self.dag_ir.out_degree(node) <= 1:
continue
# Otherwise, the node still has more than one user and needs to be handled
reachable_nodes = []
# Complexity: O(Dout*N)
for parent in self.dag_ir.get_users(node):
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
# get the common reachable objects
common_items = set.intersection(*reachable_nodes)
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
lca = None
# If common ancestor exists, find the lowest one
if len(common_items) > 0:
topo_order = self.dag_ir.nodes_topological_order()
topo_idx = -1
for item in common_items:
if lca is None:
lca = item
topo_idx = topo_order.index(item)
else:
if topo_idx > topo_order.index(item):
lca = item
topo_idx = topo_order.index(item)
else:
# There is no common ancestor for all the parents, so we pack all the reachable
# nodes into a single DAG node as a fallback. The lca should be the input node of
# one of the output nodes with out_degree = 0
potential_output_nodes = []
for node in node_to_fuse:
if self.dag_ir.out_degree(node) == 0:
potential_output_nodes.append(node)
if len(potential_output_nodes) == 0:
raise RuntimeError(f"No output node with out degree = 0 found.")
output_node = None
if (self.dag_ir.cc >= 90):
# For SM90+, the lca should be the input node of D
if (not self.dag_ir.has_node("D")):
raise RuntimeError(f"D is not a node in the DAG IR.")
output_node = "D"
else:
output_node = potential_output_nodes[0]
if (output_node is None):
raise RuntimeError(f"No output node found.")
lca = self.dag_ir.get_all_inputs(output_node)[0]
node_to_fuse.remove(output_node)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}", subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca))
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
def ensures(self) -> None:
# Ensure that after the pass, the resulting DAG becomes a tree
for node in self.dag_ir.nodes:
out_degree = self.dag_ir.out_degree(node)
if out_degree > 1:
raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")

View File

@ -0,0 +1,64 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Fix the element_output of the producer of D.
In the Sm90 epilogue visitor, the node writing D to gmem does not have an internal
element converter, so the compute node producing D must have element_output = type(D).
"""
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassFixElementD(EVTPassBase):
"""
In the Sm90 epilogue visitor, the node writing D to gmem does not have an internal
element converter, so the compute node producing D must have
element_output = type(D)
"""
dependencies = [
PassLayoutManipulateElimination
]
def get_producer(self, node, element_D):
node_meta = self.dag_ir.get_node_meta(node)
if node_meta.op == "compute":
node_meta.element_output = element_D
elif node_meta.op == "store":
self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)
def call(self):
if self.dag_ir.has_node("D"):
node_d_meta = self.dag_ir.get_node_meta("D")
element_D = node_d_meta.store_tensor.element
self.get_producer("D", element_D)

View File

@ -0,0 +1,90 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Infer the underlying implementation of each node.
While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have a different underlying implementation based
on its layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying impl of each node.
"""
import cutlass_cppgen.backend.evt.backend as evt_backend
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.util import cc_map
class PassGetImpl(EVTPassBase):
"""
    While the frontend only distinguishes between Load/Store/Compute nodes,
    each of these nodes can have a different underlying implementation based
    on its layout. For instance, a LoadNode can be an AuxLoad, Row/Col/Scalar broadcast, etc.
    This pass infers the underlying implementation of each node.
"""
dependencies = [
PassShapeTypePropagation, # The shape and type info are required for inference
PassFixElementD
]
def __init__(self, dag_ir: DAGIR) -> None:
super().__init__(dag_ir)
self.no_op_elimination = PassNoOpElimination(dag_ir)
def requires(self) -> None:
# Verify "accum" is in the arg list
if not self.dag_ir.has_node("accum"):
raise SyntaxError("Cannot find 'accum' in the argument list.")
def call(self):
# The loop structure of the epilogue is determined by the
# accumulator shape
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
problem_size = accumulator.tensor.shape
for node_meta in self.dag_ir.node_metas_topological_order():
node_meta.get_underlying_impl(problem_size)
def ensures(self) -> None:
# Some nodes will be lowered to NoOp, eliminate them
self.no_op_elimination()
# Lower to cc-specific impl
for node_meta in self.dag_ir.nodes_meta:
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
node_meta.underlying_impl = getattr(
node_impl_ccs,
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
)(node_meta)
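# A minimal sketch of the name mangling used in `ensures` above (illustrative only;
# `AuxLoadImpl` is a hypothetical impl class name and is not defined in this file):
if __name__ == "__main__":
    cc = 90
    impl_cls_name = "AuxLoadImpl"
    # For cc == 90, node_impl_ccs would be evt_backend.sm90_nodes and the lowered
    # class name is formed by prefixing the impl class name with "Sm90".
    lowered_name = f"Sm{cc_map[cc]}" + impl_cls_name
    assert lowered_name == "Sm90AuxLoadImpl"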

View File

@ -0,0 +1,217 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Eliminate layout manipulation nodes
"""
from copy import deepcopy
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassLayoutManipulateElimination(EVTPassBase):
"""
Eliminate layout manipulation nodes
"""
dependencies = [PassShapeTypePropagation]
def __init__(self, dag_ir: DAGIR) -> None:
super().__init__(dag_ir)
self.copy_cnt = 0
def call(self):
self.layout_nodes_worklist = self.get_all_layout_nodes()
        # Run the while loop until all layout nodes are eliminated
while(len(self.layout_nodes_worklist) > 0):
node = self.layout_nodes_worklist.pop(0)
# for node in layout_nodes:
# Step 1: get the propagation direction
direction = self.get_propagation_direction(node)
self.visited = []
getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
# Eliminate the current node
input_node = self.dag_ir.get_all_inputs(node)[0]
self.dag_ir.replace_all_uses_with(node, input_node)
# layout_nodes = self.get_all_layout_nodes()
def get_all_layout_nodes(self):
layout_nodes = []
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
if isinstance(node_meta, LayoutNode):
layout_nodes.append(node_meta.name)
return layout_nodes
def get_propagation_direction(self, node: str):
"""
        The logic propagates all layout nodes away from the accumulator node.
"""
self.visited = []
self.get_influenced_users(node)
nodes_influenced_dir_users = self.visited
self.visited = []
self.get_influenced_inputs(node)
nodes_influenced_dir_inputs = self.visited
if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
return "inputs"
elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
return "users"
else:
            raise RuntimeError("Unresolved propagation direction")
# Get all influenced nodes if we propagate along the user direction
def get_influenced_users(self, node: str):
if node in self.visited:
return
self.visited.append(node)
users = self.dag_ir.get_users(node)
for user in users:
self.get_influenced_users(user)
user_inputs = []
for user in users:
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
if len(user_inputs) > 0:
user_inputs = set.union(*user_inputs)
user_inputs.remove(node)
for input in user_inputs:
self.get_influenced_inputs(input)
# Get all influenced nodes if we propagate along the input direction
def get_influenced_inputs(self, node: str):
if node in self.visited:
return
self.visited.append(node)
inputs = self.dag_ir.get_all_inputs(node)
for input in inputs:
self.get_influenced_inputs(input)
input_users = []
for input in inputs:
input_users.append(set(self.dag_ir.get_users(input)))
if len(input_users) > 0:
input_users = set.union(*input_users)
input_users.remove(node)
for user in input_users:
self.get_influenced_users(user)
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
copied_node_meta = deepcopy(layout_node_meta)
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
self.copy_cnt += 1
copied_node_meta.name = copied_node
self.dag_ir.add_node(copied_node_meta)
# Add edges
target_inputs = self.dag_ir.get_all_inputs(target)
for src in target_inputs:
self.dag_ir.remove_edge(src, target)
self.dag_ir.add_edge(src, copied_node)
self.dag_ir.add_edge(copied_node, target)
self.layout_nodes_worklist.append(copied_node)
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
copied_node_meta = deepcopy(layout_node_meta)
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
self.copy_cnt += 1
copied_node_meta.name = copied_node
self.dag_ir.add_node(copied_node_meta)
# Add edges
users = self.dag_ir.get_users(target)
for user in users:
self.dag_ir.remove_edge(target, user)
self.dag_ir.add_edge(copied_node, user)
self.dag_ir.add_edge(target, copied_node)
self.layout_nodes_worklist.append(copied_node)
# Propagate the layout `node` along the user direction
def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
"""
Propagate layout node to users
"""
if node in self.visited:
# Avoid applying twice
return
self.visited.append(node)
node_meta = self.dag_ir.get_node_meta(node)
if layout_node_meta.name != node:
if isinstance(node_meta, LayoutNode):
                # A layout node is not transparent to another layout node
self.add_copy_before(layout_node_meta, node)
return
else:
layout_node_meta.apply_to_user(node_meta)
users = self.dag_ir.get_users(node)
user_inputs = []
for user in users:
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
for user in users:
self.propagate_to_users(layout_node_meta, user)
if len(user_inputs) > 0:
user_inputs = set.union(*user_inputs)
user_inputs.remove(node)
for input in user_inputs:
self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
# Propagate the layout `node` along the input direction
def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
"""
Propagate layout node to inputs
"""
if node in self.visited:
# Avoid applying twice
return
self.visited.append(node)
node_meta = self.dag_ir.get_node_meta(node)
if layout_node_meta.name != node:
if isinstance(node_meta, LayoutNode):
                # A layout node is not transparent to another layout node
self.add_copy_after(layout_node_meta, node)
return
else:
layout_node_meta.apply_to_input(node_meta)
inputs = self.dag_ir.get_all_inputs(node)
input_users = []
for input in inputs:
input_users.append(set(self.dag_ir.get_users(input)))
for input in inputs:
self.propagate_to_inputs(layout_node_meta, input)
if len(input_users) > 0:
input_users = set.union(*input_users)
input_users.remove(node)
for user in input_users:
self.propagate_to_users(layout_node_meta.get_inverse_node(), user)

View File

@ -0,0 +1,164 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Pass manager for DAG IR.
"""
from typing import Any
import networkx as nx
from cutlass_cppgen.backend.evt.ir import DAGIR
from cutlass_cppgen.backend.evt.passes.util import cc_map
class EVTPassBase:
"""
Base class for EVT Passes
"""
dependencies = []
def __init__(self, dag_ir: DAGIR) -> None:
self.dag_ir = dag_ir
self.cc = self.dag_ir.cc
def requires(self) -> None:
"""
This function will be called before the pass is run.
"""
pass
def call(self) -> None:
"""
The pass that is run through the self.dag_ir
"""
raise NotImplementedError(
f"__call__ is not overwritten in Pass {self.__class__.__name__}")
def ensures(self) -> None:
"""
This function will be called after the pass is run.
"""
pass
def __call__(self) -> Any:
self.requires()
self.call()
self.ensures()
def cc_specific_method(self, func):
"""
        This enables defining a function that behaves differently under different compute capabilities.
        The simplest example of using this function is the following:
.. highlight:: python
.. code-block:: python
class ExamplePass(EVTPassBase):
                def call(self):
                    # This automatically selects the smXX_func based on the current cc
self.cc_specific_method(self.func)()
# Interface func, can be empty
def func(self):
pass
# Sm90 specific func
def sm90_func(self):
                    # sm90 specific method
return
# Sm80 specific func
def sm80_func(self):
                    # sm80 specific method
return
"""
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
if hasattr(self, func_name):
return getattr(self, func_name)
else:
raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
class EVTPassManager(nx.DiGraph):
"""
Topological-based Pass Manager.
Each registered pass has a list of dependencies. The pass manager organizes
the passes as a DAG and launch the compiler passes under topological order.
"""
def __init__(self, dag_ir: DAGIR, pass_list):
super().__init__()
self.dag_ir = dag_ir
for pass_cls in pass_list:
self.add_pass(pass_cls)
self.sorted_passes = self.schedule()
def get_callable(self, pass_name):
"""
Return the callable of the pass
"""
return self.nodes[pass_name]["callable"]
def add_pass(self, pass_cls):
"""
Add a pass to the pass manager
:param pass_cls: the class of pass
:type pass_cls: derived class of EVTPassBase
"""
name = pass_cls.__name__
pass_callable = pass_cls(self.dag_ir)
self.add_node(name, callable=pass_callable)
def schedule(self):
"""
Schedule the added passes under topological order
"""
# Add edges
for pass_name in self.nodes:
callable = self.get_callable(pass_name)
for dependency_cls in callable.dependencies:
self.add_edge(
dependency_cls.__name__,
type(callable).__name__)
# Topological sort
return list(nx.topological_sort(self))
def __call__(self) -> Any:
"""
Launch the registered passes
"""
for pass_name in self.sorted_passes:
callable = self.get_callable(pass_name)
callable()
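# A minimal sketch of how pass dependencies drive scheduling (illustrative only;
# the real DAGIR is replaced by a stand-in object that carries just `cc`).
if __name__ == "__main__":
    from types import SimpleNamespace

    class _PassA(EVTPassBase):
        def call(self):
            pass

    class _PassB(EVTPassBase):
        dependencies = [_PassA]

        def call(self):
            pass

    # _PassB is registered first, but its dependency on _PassA forces _PassA to run first.
    manager = EVTPassManager(SimpleNamespace(cc=90), [_PassB, _PassA])
    assert manager.sorted_passes == ["_PassA", "_PassB"]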

View File

@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
No-op elimination pass
"""
from typing import Any
from cutlass_cppgen.backend.evt.ir import NoOpImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassNoOpElimination(EVTPassBase):
"""
The dead node elimination pass removes nodes with NoOpImpl in DAG IR
"""
dependencies = []
def call(self) -> Any:
for node in self.dag_ir.nodes_topological_order():
node_meta = self.dag_ir.get_node_meta(node)
if isinstance(node_meta.underlying_impl, NoOpImpl):
self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])

View File

@ -0,0 +1,97 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Preprocess the reduction nodes.
The parser treats a reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) -> Store().
This pass fuses these into a single store node, and then replaces all uses of the
original compute node with the new store node.
"""
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassPreprocessRed(EVTPassBase):
"""
    Preprocess reduction nodes
"""
def call(self):
# Step 1: find the compute nodes with op=red
red_compute_nodes = []
for node_meta in self.dag_ir.nodes_meta:
if isinstance(node_meta, ComputeNode):
if type(node_meta.fn) == tuple:
                    # To keep the frontend simple, reduction nodes are parsed
                    # into compute nodes by default.
                    # The simple heuristic to distinguish between a compute node
                    # and a reduction node is that a compute node's op is a single
                    # function, while a reduction node's op is a tuple of functions
                    # for in-register reduction and atomic global-memory reduction.
red_compute_nodes.append(node_meta.name)
# Step 2: for each compute, merge it with the succeeding store
for node in red_compute_nodes:
# Verify
users = self.dag_ir.get_users(node)
inputs = self.dag_ir.get_all_inputs(node)
# Has a single user
assert len(users) == 1
assert len(inputs) == 1
user = users[0]
input = inputs[0]
user_meta = self.dag_ir.get_node_meta(user)
# Must be a store node
assert isinstance(user_meta, StoreNode)
# With output degree == 0
assert self.dag_ir.out_degree(user) == 0
# Register the reduce op
node_meta = self.dag_ir.get_node_meta(node)
user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
user_meta.element_compute = node_meta.element_compute
user_meta.round_style = node_meta.round_style
# Replace all uses
self.dag_ir.remove_edge(input, node)
input_users = self.dag_ir.get_users(input)
for iu in input_users:
weight = self.dag_ir.get_edge_weight(input, iu)
self.dag_ir.add_edge(user, iu, weight)
self.dag_ir.remove_edge(input, iu)
self.dag_ir.add_edge(input, user)
self.dag_ir.remove_node(node)
# Register the reduction name
self.dag_ir.reduction_names.append(user)
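# Illustration of the rewrite above (hedged; node and op names are examples only):
# a pattern such as
#     input -> Compute(fn=(plus, atomic_add)) -> Store("out")
# is fused into a single reduction store,
#     input -> Store("out", reg_reduce_fn=plus, gmem_reduce_fn=atomic_add),
# any other users of `input` are re-routed to consume the fused store node,
# and "out" is recorded in dag_ir.reduction_names.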

View File

@ -0,0 +1,59 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Shape and type propagation pass
"""
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
class PassShapeTypePropagation(EVTPassBase):
"""
Propagate the shape and type of all nodes
"""
dependencies = [PassPreprocessRed]
def call(self):
# Propagate the node shape and type
for node in self.dag_ir.nodes_topological_order():
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
node_meta.type_propagation(input_node_metas)
node_meta.shape_propagation(input_node_metas)
for node in reversed(self.dag_ir.nodes_topological_order()):
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
node_meta.broadcast_propagation(input_node_metas)

View File

@ -0,0 +1,319 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Compute the shared memory size in bytes
"""
from math import gcd
import cutlass_library
from pycute import flatten, shape_div, product
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass_cppgen.backend.library import DataType, DataTypeSize
class GetSmemSize:
"""
    Get the size in bytes of shared memory used by the kernel
"""
def __init__(self, dag_ir: DAGIR) -> None:
self.dag_ir = dag_ir
self.cc = self.dag_ir.cc
#
# Sm90 epilogue specific
#
def sm90_epilogue_tile(self, tile_description):
# Get the epilogue tile size
schedule = tile_description.epilogue_schedule
if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
element_d = self.dag_ir.get_node_meta("D").element
nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
epi_tile_m = min(64, tile_description.threadblock_shape[0])
epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
epi_tile_m = min(128, tile_description.threadblock_shape[0])
epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
else:
raise NotImplementedError(f"Unsupported schedule: {schedule}")
# Get the pipeline stages
stages_d = 2
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
if self.dag_ir.has_node("C"):
element_c = self.dag_ir.get_node_meta("C").element
else:
element_c = None
element_d = self.dag_ir.get_node_meta("D").element
if element_c == element_d:
reuse_smem_c = True
else:
reuse_smem_c = False
stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
# Record the epilogue tile
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
self.epilogue_tile_mn = epilogue_tile_mn
self.epi_tiles = epi_tiles
self.stages_c = stages_c
self.stages_d = stages_d
self.reuse_smem_c = reuse_smem_c
self.element_c = element_c
self.element_d = element_d
self.is_source_supported = element_c is not None
def sm90_or_sm100_epilogue_smem_size(self, tile_description):
# Get the Fusion Storage
nodes = self.dag_ir.nodes_topological_order()
self.smem_types = {}
for node in nodes:
meta = self.dag_ir.get_node_meta(node)
if not meta.disabled:
self.smem_types[node] = meta.underlying_impl.get_smem_size(
self.cta_tile_mnk, self.epilogue_tile_mn,
self.stages_c, self.stages_d, self.epi_tiles)
if node == "D":
continue
if isinstance(meta, TopoVisitorNode):
self.get_dag_smem_type(node)
else:
self.get_evt_smem_type(node)
thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
# Get the Tensor Storage
tensors = []
if self.is_source_supported:
smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
tensors.append((smem_C, 128))
else:
tensors.append((0, 1))
if self.reuse_smem_c:
tensors.append((0, 128))
else:
smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
tensors.append((smem_D, 128))
tensors.append((thread_smem_size, 128))
tensor_smem_size = self.get_struct_size(tensors)
# Get pipeline storage size
        # sizeof(uint64_t) * stages_c * 2, alignment of uint64_t
# 2 is for FullBarrier and EmptyBarrier
pipeline_smem_size = (8 * self.stages_c * 2, 8)
# get SharedStorage size
smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
return smem_size[0]
def sm90_epilogue_smem_size(self, tile_description):
"""
Compute the shared memory size of sm90 collective epilogue
"""
self.sm90_epilogue_tile(tile_description)
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
#
# Sm100 epilogue specific
#
def sm100_epilogue_tile(self, tile_description):
cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
mma_tile = cta_tile
if tile_description.is_2sm:
cta_tile = (cta_tile[0] // 2, cta_tile[1])
if tile_description.is_2sm and mma_tile[0] == 128:
tmem_warps = (2, 2)
else:
tmem_warps = (4, 1)
if self.dag_ir.has_node("C"):
element_c = self.dag_ir.get_node_meta("C").element
element_c_size = DataTypeSize[element_c]
else:
element_c = None
element_c_size = 0
element_d = self.dag_ir.get_node_meta("D").element
DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void
CtaM = cta_tile[0]
CtaN = cta_tile[1]
WarpM = tmem_warps[0]
WarpN = tmem_warps[1]
MaxBits = max(element_c_size, DataTypeSize[element_d])
DpFull = 32
M = min(CtaM, DpFull * WarpM)
if DisableSource:
# Epilogues w/o residual load are less sensitive to smem allocation
# Target a fixed amount of compute per epilogue iteration
if MaxBits == 4:
# Make epilogue tile larger to reduce the epilogue iterations.
# 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
ComputeElts = 8192
Nperf = ComputeElts // M
else:
ComputeElts = 4096
Nperf = ComputeElts // M
else:
# Epilogues w/ residual load are more sensitive to smem allocation
# Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
if MaxBits == 32:
Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
elif MaxBits == 16:
Nperf = 32 if CtaN <= 128 else 64
else:
Nperf = 64
def is_m_major(layout):
return flatten(layout.stride[0]) == 1
if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
N_min_C = 8 * WarpN
elif element_c_size == 6:
N_min_C = 128 * WarpN
else:
N_min_C = (128 // element_c_size) * WarpN
if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
N_min_D = 8 * WarpN
elif DataTypeSize[element_d] == 6:
N_min_D = 128 * WarpN
else:
N_min_D = (128 // DataTypeSize[element_d]) * WarpN
N = min(CtaN, max(Nperf, N_min_C, N_min_D))
tile_m = M
tile_n_size = N // WarpN * WarpN
epilogue_tile_mn = (tile_m, tile_n_size)
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
stages_d = min(epi_tiles, 2)
reuse_smem_c = (element_c_size > 8)
if reuse_smem_c:
stages_c = max(min(epi_tiles, 4), stages_d + 1)
else:
stages_c = min(epi_tiles, 4)
# Record the epilogue tile
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
self.epilogue_tile_mn = epilogue_tile_mn
self.epi_tiles = epi_tiles
self.stages_c = stages_c
self.stages_d = stages_d
self.reuse_smem_c = reuse_smem_c
self.element_c = element_c
self.element_d = element_d
self.is_source_supported = not DisableSource
def sm100_epilogue_smem_size(self, tile_description):
"""
Compute the shared memory size of sm100 collective epilogue
"""
self.sm100_epilogue_tile(tile_description)
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
def __call__(self, tile_description):
return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
#
# Helper functions
#
@staticmethod
def get_visitor_size(members: list, ebo: bool):
"""
Get the size of struct in bytes
"""
offset = 0
max_alignment = 1
if len(members) > 0:
# Get alignment
for _, alignment in members:
max_alignment = max(max_alignment, alignment)
for type_size, _ in members:
if type_size != 0:
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
if type_size == 0 and not ebo:
offset += 1
else:
offset += type_size
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
return (offset, max_alignment)
else:
# Struct size is at least 1
return (1, 1)
def get_struct_size(self, members: list):
"""
Get the size of struct in bytes
"""
return self.get_visitor_size(members, False)
def get_evt_smem_type(self, node):
# Sort the input nodes by edge weight
input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
input_types.append(self.smem_types[node])
if len(input_types) > 1:
ebo = len(input_types) > 4
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
def get_dag_smem_type(self, node):
meta = self.dag_ir.get_node_meta(node)
subgraph = meta.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Visit the unvisited nodes in subgraph
for n in subgraph_nodes:
m = subgraph.get_node_meta(n)
if m.disabled:
continue
else:
self.smem_types[n] = m.underlying_impl.get_smem_size(
self.cta_tile_mnk, self.epilogue_tile_mn,
self.stages_c, self.stages_d, self.epi_tiles)
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
if len(input_types) > 0:
ebo = len(input_types) > 4
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
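# A small worked example of the struct-size helper above (illustrative): members are
# (size_in_bytes, alignment) pairs; the struct is padded to the largest member
# alignment, and a zero-sized member still occupies one byte unless EBO applies.
if __name__ == "__main__":
    assert GetSmemSize.get_visitor_size([(10, 4), (0, 1), (6, 8)], False) == (24, 8)
    assert GetSmemSize.get_visitor_size([], False) == (1, 1)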

View File

@ -0,0 +1,46 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for passes
"""
# Map from the CC of the kernel to the EVT implementation that the CC targets
cc_map = {
80: 80,
86: 80,
89: 80,
90: 90,
100: 100,
101: 100,
103: 100,
}
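# Illustrative check of the mapping: SM86/SM89 kernels reuse the SM80 EVT node
# implementations, while SM101/SM103 kernels reuse the SM100 ones.
if __name__ == "__main__":
    assert cc_map[86] == 80 and cc_map[89] == 80
    assert cc_map[101] == 100 and cc_map[103] == 100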

View File

@ -0,0 +1,109 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
import numpy as np
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
class NumpyFrontend:
"""
Frontend node for numpy
"""
@staticmethod
def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
"""Convert the input numpy tensor to CUDA device pointer
:param np_tensor: input numpy nd array
:param is_output: whether the tensor is output
:return: CUDA device pointer
"""
# copy the data to device
if is_output:
return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
else:
return todevice(np_tensor)
class TorchFrontend:
"""
Frontend node for torch
"""
@staticmethod
def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
"""Convert the input torch tensor to CUDA device pointer
:param torch_tensor: input torch tensor
:return: CUDA device pointer
"""
# check the device of torch_tensor
if not torch_tensor.is_cuda:
torch_tensor = torch_tensor.to("cuda")
return cuda.CUdeviceptr(torch_tensor.data_ptr())
class CupyFrontend:
"""
Frontend node for cupy
"""
@staticmethod
def argument(cupy_ndarray: "cp.ndarray"):
return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))
class TensorFrontend:
"""
    Universal frontend for client-provided tensors
"""
@staticmethod
def argument(tensor, is_output=False):
if is_numpy_tensor(tensor):
return NumpyFrontend.argument(tensor, is_output)
elif is_torch_tensor(tensor):
return TorchFrontend.argument(tensor)
elif is_cupy_tensor(tensor):
return CupyFrontend.argument(tensor)
else:
raise NotImplementedError("Unknown Tensor Type")
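# A minimal usage sketch (illustrative; assumes numpy is installed and a CUDA-capable
# device is available). TensorFrontend dispatches on the tensor type: numpy inputs are
# copied to device memory, numpy outputs only receive a device allocation, and
# torch/cupy tensors are wrapped as raw device pointers.
if __name__ == "__main__":
    a = np.ones((4, 4), dtype=np.float32)
    a_dev = TensorFrontend.argument(a)                   # host-to-device copy
    d_dev = TensorFrontend.argument(a, is_output=True)   # device allocation only
    print(type(a_dev), type(d_dev))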

File diff suppressed because it is too large

View File

@ -0,0 +1,509 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Common data types and string names/tags for them
"""
import enum
from cutlass_library import (
    ComplexTransform,
    DataType,
    DataTypeSize,
    EpilogueScheduleType,
    KernelScheduleSuffixes,
    KernelScheduleType,
    MathOperation,
    OpcodeClass,
    OperationKind,
    TileSchedulerType
)
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
try:
from enum import auto as enum_auto
except ImportError:
__cutlass_library_auto_enum = 0
def enum_auto() -> int:
global __cutlass_library_auto_enum
i = __cutlass_library_auto_enum
__cutlass_library_auto_enum += 1
return i
class DataTypeSizeBytes:
"""
Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
data type key is less than a full byte or a non-integer number of bytes.
"""
@staticmethod
def __class_getitem__(datatype):
"""
Returns the number of bytes in size the data type is. Raises an exception if the data type
is either less than a full byte or a non-integer number of bytes in size.
:param datatype: data type to query
:return: number of bytes the data type occupies
:rtype: int
"""
bits = DataTypeSize[datatype]
if bits < 8:
raise Exception(
f"Data type {datatype} is less than one byte in size."
)
elif bits % 8 != 0:
raise Exception(
f"Data type datatype is not an integer number of bytes."
)
return bits // 8
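# For example (illustrative):
#   DataTypeSizeBytes[DataType.f16]  -> 2 (16 bits -> 2 bytes)
#   DataTypeSizeBytes[DataType.s4]   -> raises, since s4 is less than one byte in size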
class SchedulerMode(enum.Enum):
Device = enum_auto()
Host = enum_auto()
SchedulerModeTag = {
SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
class FunctionalOp(enum.Enum):
AtomicAdd = enum_auto()
AtomicMaximum = enum_auto()
Divides = enum_auto()
Maximum = enum_auto()
Minimum = enum_auto()
Minus = enum_auto()
Multiplies = enum_auto()
MultiplyAdd = enum_auto()
Plus = enum_auto()
Exp = enum_auto()
FunctionalOpTag = {
FunctionalOp.AtomicAdd: "cutlass::atomic_add",
FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
FunctionalOp.Divides: "cutlass::divides",
FunctionalOp.Maximum: "cutlass::maximum",
FunctionalOp.Minimum: "cutlass::minimum",
FunctionalOp.Minus: "cutlass::minus",
FunctionalOp.Multiplies: "cutlass::multiplies",
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
FunctionalOp.Plus: "cutlass::plus",
FunctionalOp.Exp: "cutlass::fast_exp_op",
}
class ActivationOp(enum.Enum):
DGelu = enum_auto()
Gelu = enum_auto()
GeluTaylor = enum_auto()
HardSwish = enum_auto()
Identity = enum_auto()
LeakyReLU = enum_auto()
ReLU = enum_auto()
Sigmoid = enum_auto()
SiLU = enum_auto()
Tanh = enum_auto()
ActivationOpTag = {
ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}
def op_tag(op) -> str:
"""
Dispatches `op` to the appropriate *Tag dictionary depending on whether
`op` is an ActivationOp or FunctionalOp. This is useful for cases in which
either type can be used.
:param op: operation to emit a tag for
:type op: ActivationOp | FunctionalOp
:return: tag corresponding to op
:rtype: str
"""
if isinstance(op, ActivationOp):
return ActivationOpTag[op]
elif isinstance(op, FunctionalOp):
return FunctionalOpTag[op]
else:
raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
class FloatRoundStyle(enum.Enum):
ToNearest = enum_auto()
ToNearestSatfinite = enum_auto()
Indeterminate = enum_auto()
TowardZero = enum_auto()
TowardInfinity = enum_auto()
TowardNegInfinity = enum_auto()
HalfUlpTruncDntz = enum_auto()
HalfUlpTruncate = enum_auto()
FloatRoundStyleTag = {
FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
}
class MathInstruction:
"""
    Description of the lowest-level matrix-multiply-accumulate operation to be used in a kernel
"""
def __init__(
self,
instruction_shape,
element_a,
element_b,
element_accumulator,
opcode_class=OpcodeClass.Simt,
math_operation=MathOperation.multiply_add,
):
"""
:param instruction_shape: size of the [M, N, K] dimensions of the instruction
:type instruction_shape: list or tuple
:param element_a: data type of operand A
:param element_b: data type of operand B
:param element_accumulator: data type used in accumulation
:param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
:type opcode_class: cutlass_library.library.OpcodeClass
:param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
:type math_operation: MathOperation
"""
self.instruction_shape = instruction_shape
self.element_a = element_a
self.element_b = element_b
self.element_accumulator = element_accumulator
self.opcode_class = opcode_class
self.math_operation = math_operation
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
blackwell_threadblock_shape = tile_description.threadblock_shape
is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
if cluster_shape[0] > 0:
blackwell_threadblock_shape = [
tile_description.threadblock_shape[0] // cluster_shape[0],
tile_description.threadblock_shape[1] // cluster_shape[1],
tile_description.threadblock_shape[2] // cluster_shape[2]
]
if is_2sm:
blackwell_threadblock_shape[0] *= 2
else:
blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape
return blackwell_threadblock_shape, is_2sm
class TileDescription:
"""
Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
stage count, and math instruction specification
"""
def __init__(
self,
threadblock_shape,
stages,
warp_count,
math_instruction,
cluster_shape=[1, 1, 1],
kernel_schedule: KernelScheduleType = None,
epilogue_schedule: EpilogueScheduleType = None,
tile_scheduler: TileSchedulerType = None
):
"""
        :param threadblock_shape: shape of a threadblock tile
:type threadblock_shape: list or tuple
        :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
number of stages that can be supported for an operation on a given architecture will be computed at a later time
:type stages: int or None
:param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
:type warp_count: list, tuple, or None
:param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
:type math_instruction: MathInstruction
:param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
:param kernel_schedule: type of kernel schedule to use (only available for SM90+)
:type kernel_schedule: cutlass_library.KernelScheduleType
:param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
:type epilogue_schedule: cutlass_library.EpilogueScheduleType
:param tile_scheduler: type of tile scheduler to use (only available for SM90+)
:type tile_scheduler: cutlass_library.TileSchedulerType
"""
if ((kernel_schedule is None and epilogue_schedule is not None) or
(kernel_schedule is not None and epilogue_schedule is None)):
raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")
self.threadblock_shape = threadblock_shape
self.cluster_shape = cluster_shape
self.kernel_schedule = kernel_schedule
self.epilogue_schedule = epilogue_schedule
self.tile_scheduler = tile_scheduler
self.stages = stages
self.math_instruction = math_instruction
self.instruction_shape = math_instruction.instruction_shape
# Number of warps along x, y, z directions
self.warp_count = warp_count
self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)
def clone_and_update(self, td: dict):
attrs = {
"cluster_shape": None,
"threadblock_shape": None,
"warp_count": None,
"stages": None,
"instruction_shape": None,
"kernel_schedule": None,
"epilogue_schedule": None,
"tile_scheduler": None
}
for key in attrs.keys():
if key in td.keys():
attrs[key] = td[key]
else:
attrs[key] = getattr(self, key)
attrs["math_instruction"] = MathInstruction(
attrs["instruction_shape"],
self.math_instruction.element_a,
self.math_instruction.element_b,
self.math_instruction.element_accumulator,
self.math_instruction.opcode_class,
self.math_instruction.math_operation
)
# Remove the instruction shape
del attrs["instruction_shape"]
return TileDescription(**attrs)
@property
def num_threads(self):
"""
Returns the number of threads in the threadblock
:return: number of threads in the threadblock
:rtype: int or None (if warp count is None)
"""
if self.warp_count is not None:
threads = 32
for cnt in self.warp_count:
threads *= cnt
return threads
return None
def procedural_name(self):
"""
Returns a name identifying the tile description
:return: name identifying the tile description
        :rtype: str
"""
emit_stages = 0 if self.stages is None else self.stages
name = "%dx%dx%d_%dx%d_%dx%d" % (
self.cluster_shape[0],
self.cluster_shape[1],
self.cluster_shape[2],
self.threadblock_shape[0],
self.threadblock_shape[1],
self.threadblock_shape[2],
emit_stages
)
return name
def procedural_name_2x(self):
"""
Returns a name identifying the tile description
:return: name identifying the tile description
        :rtype: str
"""
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
def __str__(self):
"""
        Returns a string containing each of the tile description's values
:return: contents of tile description
:rtype: str
"""
if self.kernel_schedule is not None:
kschedule = self.kernel_schedule
else:
kschedule = KernelScheduleType.ScheduleAuto
if self.epilogue_schedule is not None:
eschedule = self.epilogue_schedule
else:
eschedule = EpilogueScheduleType.ScheduleAuto
if self.tile_scheduler is not None:
tschedule = self.tile_scheduler.name
else:
tschedule = "None"
return f"""
{{
ClusterShape: {self.cluster_shape}
ThreadblockShape: {self.threadblock_shape}
WarpCount: {self.warp_count}
Stages: {self.stages if self.stages is not None else 'Auto'}
InstructionShape: {self.math_instruction.instruction_shape}
Kernel schedule: {kschedule.name}
  Epilogue schedule: {eschedule.name}
TileScheduler: {tschedule}
}}"""
class TensorDescription:
def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
self.element = element
self.layout = layout
if element != DataType.void:
self.alignment = min(128 // DataTypeSize[self.element], alignment)
else:
self.alignment = alignment
self.complex_transform = complex_transform
def CalculateSmemUsagePerStage(operation):
"""
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
:param op: operation for which the maximum stages should be computed. If stages are
set via the `op.tile_description.stages` parameter, this setting is ignored
in the present calculation
:type op: cutlass_cppgen.backend.Operation
:return: number of bytes of shared memory consumed by a single stage
:rtype: int
"""
m, n, k = operation.tile_description.threadblock_shape
if operation.operation_kind == OperationKind.Gemm:
stage_barrier_bytes = 32
return (
(DataTypeSize[operation.A.element] * m * k // 8)
+ (DataTypeSize[operation.B.element] * k * n // 8)
+ stage_barrier_bytes
)
else:
raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))
def CalculateSmemUsage(operation):
"""
Returns the amount of shared memory in bytes consumed by a kernel.
:param op: operation for which the maximum stages should be computed. If stages are
set via the `op.tile_description.stages` parameter, this setting is ignored
in the present calculation
:type op: cutlass_cppgen.backend.Operation
    :return: number of bytes of shared memory consumed by the kernel
    :rtype: int
"""
return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)
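# Worked example (illustrative): for a GEMM with f16 A and B operands and a
# 128x128x64 threadblock shape, a single stage consumes
#   (16 * 128 * 64 // 8) + (16 * 64 * 128 // 8) + 32 = 32800 bytes,
# so a 3-stage kernel consumes 3 * 32800 = 98400 bytes of shared memory.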
class ApiVersion(enum.Enum):
"""
Differentiate between CUTLASS 2.x and 3.x API versions
"""
v2x = enum_auto()
v3x = enum_auto()
def api_version(arch, opclass, dtype):
"""
Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
or 3.x for code emission.
:param arch: compute capability of device on which to run
:type arch: int
:param opclass: class of the operation being performed
:type opclass: cutlass_library.OpcodeClass
:param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
:type dtype: cutlass_library.DataType
:return: API version to be used in code emission
:rtype: ApiVersion
"""
if (arch in [90, 100, 101, 103] and
opclass == OpcodeClass.TensorOp and
(dtype != DataType.f64)):
return ApiVersion.v3x
else:
return ApiVersion.v2x
class EmissionType(enum.Enum):
"""
Tags for whether to emit a kernel- or device-level operation
"""
Kernel = enum_auto()
Device = enum_auto()

View File

@ -0,0 +1,121 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import numpy as np
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.utils.lazy_import import lazy_import
if cutlass_cppgen.use_rmm:
import rmm
else:
cudart = lazy_import("cuda.cudart")
class PoolMemoryManager:
def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
self.pool = rmm.mr.PoolMemoryResource(
rmm.mr.CudaMemoryResource(),
initial_pool_size=init_pool_size,
maximum_pool_size=max_pool_size
)
self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
rmm.mr.set_current_device_resource(self.mr)
def pool_size(self):
return self.pool.pool_size()
class DevicePtrWrapper:
"""
Wrapper around a raw pointer to device memory that mimics the portion of the RMM
DeviceBuffer interface used by the CUTLASS Python interface (namely, the `ptr` property)
"""
def __init__(self, dev_ptr):
self.dev_ptr = dev_ptr
@property
def ptr(self):
return self.dev_ptr
def _todevice(host_data):
"""
Helper for transferring host data to device memory
"""
if cutlass_cppgen.use_rmm:
return rmm.DeviceBuffer.to_device(host_data.tobytes())
else:
nbytes = len(host_data.tobytes())
dev_ptr_wrapper = device_mem_alloc(nbytes)
err, = cudart.cudaMemcpy(
dev_ptr_wrapper.ptr,
host_data.__array_interface__['data'][0],
nbytes,
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
)
if err != cudart.cudaError_t.cudaSuccess:
raise Exception(f"cudaMemcpy failed with error {err}")
return dev_ptr_wrapper
def todevice(host_data, dtype=np.float32):
"""
Transfers ``host_data`` to device memory. Python lists are first converted to NumPy arrays of type ``dtype``.
"""
if isinstance(host_data, list):
return _todevice(np.array(host_data, dtype=dtype))
elif is_numpy_tensor(host_data):
return _todevice(host_data)
def device_mem_alloc(size):
if cutlass_cppgen.use_rmm:
return rmm.DeviceBuffer(size=size)
else:
err, ptr = cudart.cudaMalloc(size)
if err != cudart.cudaError_t.cudaSuccess:
raise Exception(f"cudaMalloc failed with error {err}")
return DevicePtrWrapper(ptr)
def align_size(size, alignment=256):
return ((size + alignment - 1) // alignment) * alignment
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
if cutlass_cppgen.use_rmm:
memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
return memory_pool
else:
return None
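# Illustrative usage sketch (assumes a CUDA-capable device; `pool`, `buf`, and `padded`
# are hypothetical local names, not part of this module):
#
#   import numpy as np
#   pool = create_memory_pool()                    # PoolMemoryManager when cutlass_cppgen.use_rmm is set, otherwise None
#   buf = todevice(np.ones(16, dtype=np.float32))  # rmm.DeviceBuffer or DevicePtrWrapper
#   padded = align_size(1000)                      # -> 1024 (rounded up to a 256-byte multiple)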

View File

@ -0,0 +1,140 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_cppgen.backend.utils.device import device_cc
_supports_cluster_launch = None
def supports_cluster_launch():
"""
Returns whether the active device and the installed CUDA Python bindings support launching
kernels with cluster dimensions (a device of compute capability 90, 100, 101, or 103 and
CUDA bindings version 11.8 or newer). The result is cached after the first call.
"""
from cuda import __version__
_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
global _supports_cluster_launch
if _supports_cluster_launch is None:
major, minor = _version_splits[0], _version_splits[1]
_supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
return _supports_cluster_launch
class LaunchConfiguration:
def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0):
self.grid = grid
self.block = block
self.shared_memory_capacity = smem
class ExecutableOperation:
def __init__(self, operation):
self.operation = operation
self.module = None
self.kernel = None
def name(self):
return self.operation.procedural_name()
def emit(self):
return ""
def can_implement(self, configuration, arguments):
raise NotImplementedError()
def get_host_workspace_size(self, arguments):
raise NotImplementedError()
def get_device_workspace_size(self, arguments):
raise NotImplementedError()
def plan(self, arguments):
raise NotImplementedError()
def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
raise NotImplementedError()
def run_with_clusters(self, launch_config, kernel_params, stream=None):
if not stream:
stream = cuda.CUstream(0)
if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
attr = cuda.CUlaunchAttribute()
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attrs = [attr]
# Allow for non-portable cluster sizes
err, = cuda.cuFuncSetAttribute(
self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
if err != cuda.CUresult.CUDA_SUCCESS:
return err
else:
attrs = []
config = cuda.CUlaunchConfig()
config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
config.sharedMemBytes = launch_config.shared_memory_capacity
config.hStream = stream
config.attrs = attrs
config.numAttrs = len(attrs)
err, = cuda.cuLaunchKernelEx(
config, f=self.kernel, kernelParams=kernel_params, extra=0)
return err
def run_without_clusters(self, launch_config, kernel_params, stream=None):
if not stream:
stream = cuda.CUstream(0)
err, = cuda.cuLaunchKernel(
self.kernel,
launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
launch_config.block[0], launch_config.block[1], launch_config.block[2],
launch_config.shared_memory_capacity,
stream,
kernel_params,
0)
return err
def run(self, host_workspace, device_workspace, launch_config, stream=None):
if not stream:
stream = cuda.CUstream(0)
cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
packed = (ctypes.c_void_p * 1)()
packed[0] = ctypes.addressof(cArg)
if supports_cluster_launch():
return self.run_with_clusters(launch_config, packed, stream)
else:
return self.run_without_clusters(launch_config, packed, stream)
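# Illustrative call sequence (hypothetical names; assumes `rt_op` is a compiled
# ExecutableOperation whose `kernel` has been loaded and whose argument buffer has been
# packed into `host_workspace` as a bytearray):
#
#   cfg = LaunchConfiguration(grid=[108, 1, 1], block=[256, 1, 1], smem=49152)
#   err = rt_op.run(host_workspace, None, cfg)
#   if err != cuda.CUresult.CUDA_SUCCESS:
#       raise RuntimeError(f"Kernel launch failed: {err}")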

View File

@ -0,0 +1,455 @@
################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from __future__ import annotations
import ctypes
from typing import Union
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
from cutlass_library import (
DataTypeNames,
DataTypeSize,
DataTypeTag,
LayoutType,
SubstituteTemplate
)
import cutlass_cppgen
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.library import TensorDescription
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.shape import MatrixCoord
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
# Forward declaration so that annotations below can reference ReductionOperation before
# its full definition later in this file.
class ReductionOperation:
pass
class ReductionArguments:
"""
Arguments of reduction
"""
def __init__(
self,
operation: ReductionOperation,
problem_size: "list[int]",
partitions: int,
workspace: cuda.CUdeviceptr,
destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
**kwargs,
) -> None:
# tensor_C can be interpreted as the bias with bias=True in keyword args
if "bias" in kwargs.keys():
self.bias = kwargs["bias"]
else:
# by default, tensor_C is not bias
self.bias = False
if "stream" in kwargs.keys():
self.stream = kwargs["stream"]
else:
self.stream = cuda.CUstream(0)
self.operation = operation
self.ptr_workspace = workspace
# number of split-k partitions
self.partitions = partitions
if is_numpy_tensor(destination):
self.host_D = destination
self.destination_buffer = NumpyFrontend.argument(destination, True)
self.source_buffer = NumpyFrontend.argument(source, False)
self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
elif is_torch_tensor(destination):
self.ptr_destination = TorchFrontend.argument(destination)
self.ptr_source = TorchFrontend.argument(source)
elif isinstance(destination, cuda.CUdeviceptr):
self.ptr_destination = destination
self.ptr_source = source
else:
raise TypeError(f"Unsupported destination type: {type(destination).__name__}")
self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])
self.partition_stride = (
problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
)
if "output_op" in kwargs.keys():
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
self.get_arguments()
@staticmethod
def get_tensor_ref(
extent: "tuple[int]",
device_ptr: cuda.CUdeviceptr,
layout: LayoutType,
):
if layout == LayoutType.RowMajor:
return TensorRef2D_(int(device_ptr), extent[1])
else:
raise ValueError(f"Unknown layout type {layout}")
def get_arguments(self):
ref_workspace = ReductionArguments.get_tensor_ref(
extent=[
self.problem_size.row,
self.problem_size.column,
],
device_ptr=self.ptr_workspace,
layout=LayoutType.RowMajor,
)
if self.bias:
ref_source = ReductionArguments.get_tensor_ref(
extent=[0, 0],
device_ptr=self.ptr_source,
layout=LayoutType.RowMajor,
)
else:
ref_source = ReductionArguments.get_tensor_ref(
extent=[
self.problem_size.row,
self.problem_size.column,
],
device_ptr=self.ptr_source,
layout=LayoutType.RowMajor,
)
ref_destination = ReductionArguments.get_tensor_ref(
extent=[
self.problem_size.row,
self.problem_size.column,
],
device_ptr=self.ptr_destination,
layout=LayoutType.RowMajor,
)
self.c_arguments = self.operation.argument_type(
self.problem_size,
self.partitions,
self.partition_stride,
ref_workspace,
ref_destination,
ref_source,
self.output_op,
)
params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
self.host_workspace = bytearray(params_.contents)
def sync(self):
(err,) = cudart.cudaDeviceSynchronize()
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(f"CUDA Error {str(err)}")
if hasattr(self, "host_D"):
(err,) = cuda.cuMemcpyDtoH(
self.host_D,
self.ptr_destination,
self.host_D.size * self.host_D.itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
self.free()
def free(self):
"""
Frees allocated device-side memory
"""
# Free any device memory allocated manually
if not cutlass_cppgen.use_rmm:
for attr in ["destination_buffer", "source_buffer"]:
if hasattr(self, attr):
buf = getattr(self, attr)
if isinstance(buf, DevicePtrWrapper):
err, = cudart.cudaFree(buf.ptr)
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(f"cudaFree failed with error {err}")
del buf
class ReductionRT(ExecutableOperation):
"""
ReductionRT manages the CUTLASS runtime components for reduction
"""
KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
${operation_name}${operation_suffix} op;
op(params, *shared_storage);
}
"""
HostTemplate = r"""
extern "C" {
// Get the size of params in bytes
int ${operation_name}_get_param_size(){
return sizeof(${operation_name}${operation_suffix}::Params);
}
// Get the size of dynamic shared memory in bytes
int ${operation_name}_shared_memory_size() {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];
return output;
}
}
"""
def __init__(self, operation: ReductionOperation):
super().__init__(operation)
self.operation: ReductionOperation = operation
self.emitter = EmitReductionInstance("_type")
self.elements_per_access = self.operation.count
(
self.argument_type,
self.epilogue_type,
) = get_reduction_params(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type)]
def emit(self):
return self.emitter.emit(self.operation)
def plan(self, arguments: ReductionArguments):
block_shape = [
self.operation.shape.column // self.elements_per_access,
self.operation.shape.row,
1,
]
grid_shape = [
(arguments.problem_size.row + self.operation.shape.row - 1)
// self.operation.shape.row,
(arguments.problem_size.column + self.operation.shape.column - 1)
// self.operation.shape.column,
1,
]
return LaunchConfiguration(
grid_shape,
block_shape,
self.shared_memory_capacity,
)
def initialize(self):
(err,) = cuda.cuFuncSetAttribute(
self.kernel,
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
value=self.shared_memory_capacity,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"CUDA Error: {err}")
class ReductionOperation:
"""
CUTLASS reduction Operation
"""
def __init__(
self,
shape: MatrixCoord,
C: TensorDescription,
element_accumulator,
element_workspace=None,
element_compute=None,
epilogue_functor=None,
count: int = 1,
partitions_per_stage: int = 4,
) -> None:
self.shape = shape
self.epilogue_functor = epilogue_functor
self.element_accumulator = element_accumulator
if element_workspace is None:
self.element_workspace = element_accumulator
else:
self.element_workspace = element_workspace
if element_compute is None:
self.element_compute = element_accumulator
else:
self.element_compute = element_compute
self.element_output = C.element
self.C: TensorDescription = C
# Reduce op processing size
self.count: int = count
# Number of partitions to reduce per stage
self.partitions_per_stage: int = partitions_per_stage
self.rt_module: ReductionRT = ReductionRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
def extended_name(self):
extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
return SubstituteTemplate(
extend_name,
{
"element_workspace": DataTypeNames[self.element_workspace],
"element_accumulator": DataTypeNames[self.element_accumulator],
"element_compute": DataTypeNames[self.element_compute],
"element_output": DataTypeNames[self.element_output],
},
)
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size"""
configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
threadblock = "%dx%d" % (
self.shape.row,
self.shape.column,
)
return SubstituteTemplate(
configuration_name,
{
"extended_name": self.extended_name(),
"threadblock": threadblock,
},
)
def procedural_name(self):
"""The full procedural name indicates architeture, extended name, tile size"""
return self.configuration_name()
def run(self, arguments: ReductionArguments) -> cuda.CUresult:
"""
Configure and launch the cuda kernel with input arguments
"""
launch_config = self.rt_module.plan(arguments)
host_workspace = arguments.host_workspace
device_workspace = None
err = self.rt_module.run(
host_workspace,
device_workspace,
launch_config,
arguments.stream
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"CUDA Error {str(err)}")
return err
class EmitReductionInstance:
def __init__(self, operation_suffix="") -> None:
self.operation_suffix = operation_suffix
self.includes = [
"cutlass/cutlass.h",
"cutlass/numeric_types.h",
"cutlass/arch/arch.h",
"cutlass/arch/mma.h",
"cutlass/layout/matrix.h",
"cutlass/gemm/device/gemm.h",
"cutlass/gemm/device/gemm_universal_adapter.h",
"cutlass/gemm/kernel/default_gemm_universal.h",
"cutlass/reduction/kernel/reduce_split_k.h",
"cutlass/reduction/thread/reduction_operators.h",
]
self.template = """
// Reduction kernel instance
using ${operation_name}_base =
typename cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
${epilogue_functor},
cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_output},
${count}>,
${partition_per_stage}>;
struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""
def emit(self, operation: ReductionOperation):
vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
values = {
"operation_name": operation.configuration_name(),
"operation_suffix": self.operation_suffix,
"shape_row": str(operation.shape.row),
"shape_column": str(operation.shape.column),
"epilogue_functor": operation.epilogue_functor.emit(),
"element_output": DataTypeTag[operation.element_output],
"epilogue_vector_length": str(epilogue_vector_length),
"element_accumulator": DataTypeTag[operation.element_accumulator],
"element_compute": DataTypeTag[operation.element_compute],
"element_workspace": DataTypeTag[operation.element_workspace],
"count": str(operation.count),
"partition_per_stage": str(operation.partitions_per_stage),
}
return SubstituteTemplate(self.template, values)
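# Illustrative sketch (hypothetical `red_op`): given a ReductionOperation describing, e.g.,
# a 4x64 threadblock shape reducing f32 accumulators into an f16 output tensor, the emitter
# above instantiates cutlass::reduction::kernel::ReduceSplitK with the corresponding template
# arguments and names the resulting struct after red_op.configuration_name():
#
#   code = EmitReductionInstance().emit(red_op)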

View File

@ -0,0 +1,35 @@
################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"
Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"

View File

@ -0,0 +1,33 @@
################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc

View File

@ -0,0 +1,126 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for interacting with the device
"""
from __future__ import annotations
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
def check_cuda_errors(result: list):
"""
Checks whether `result` contains a CUDA error and, if so, raises the error as an exception.
Otherwise, returns the result contained in the remaining fields of `result`.
:param result: the results of the `cudart` method, consisting of an error code and any method results
:type result: list
:return: non-error-code results from the `results` parameter
"""
# `result` is of the format : (cudaError_t, result...)
err = result[0]
if err.value:
raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
def device_cc(device: int = -1) -> int:
"""
Returns the compute capability of the device with ID `device`.
:param device: ID of the device to query
:type device: int
:return: compute capability of the queried device (e.g., 80 for SM80)
:rtype: int
"""
if device == -1:
device = cutlass_cppgen.device_id()
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
major = str(deviceProp.major)
minor = str(deviceProp.minor)
return int(major + minor)
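# Illustrative values (the result depends on the GPU actually present): an SM90 part reports
# major=9, minor=0, so device_cc() returns 90; an SM103 part (major=10, minor=3) returns 103.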
def device_sm_count(device: int = -1) -> int:
"""
Returns the number of streaming multiprocessors (SMs) on the device with ID `device`.
:param device: ID of the device to query (-1 indicates the currently-active device)
:type device: int
:return: number of SMs on the queried device
:rtype: int
"""
if device == -1:
device = cutlass_cppgen.device_id()
err, device_sm_count = cuda.cuDeviceGetAttribute(
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise Exception(
"Failed to retireve SM count. "
f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
)
return device_sm_count
def to_device_ptr(tensor) -> cuda.CUdeviceptr:
"""
Converts a tensor to a CUdeviceptr
:param tensor: tensor to convert
:type tensor: np.ndarray | torch.Tensor | cp.ndarray | int
:return: device pointer
:rtype: cuda.CUdeviceptr
"""
if is_numpy_tensor(tensor):
ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
elif is_torch_tensor(tensor):
ptr = cuda.CUdeviceptr(tensor.data_ptr())
elif is_cupy_tensor(tensor):
ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
elif isinstance(tensor, cuda.CUdeviceptr):
ptr = tensor
elif isinstance(tensor, int):
ptr = cuda.CUdeviceptr(tensor)
else:
raise NotImplementedError(tensor)
return ptr
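# Illustrative usage (hypothetical `host_array` and `torch_tensor`; each call yields a
# cuda.CUdeviceptr):
#
#   ptr_a = to_device_ptr(host_array)      # NumPy array: pointer taken from __array_interface__
#   ptr_b = to_device_ptr(torch_tensor)    # PyTorch tensor: tensor.data_ptr()
#   ptr_c = to_device_ptr(0x7f0000000000)  # a plain integer address is wrapped directly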