Rename python/cutlass to python/cutlass_cppgen (#2652)

Committed by Haicheng Wu
parent 4260d4aef9
commit 177a82e251
python/cutlass_cppgen/__init__.py (new file, 213 lines)
@@ -0,0 +1,213 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import logging
import os
import sys

import cutlass_library


def _cuda_install_path_from_nvcc() -> str:
    import subprocess
    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
    result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True)
    if result.returncode != 0:
        raise Exception('Unable to find nvcc via `which` utility.')

    cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0]
    if not os.path.isdir(cuda_install_path):
        raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, '
                        f'and default path of {cuda_install_path} does not exist.')

    return cuda_install_path


CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path)

# Alias CUTLASS_PATH as source_path
source_path = CUTLASS_PATH

_NVCC_VERSION = None
def nvcc_version():
    global _NVCC_VERSION
    if _NVCC_VERSION is None:
        import subprocess

        # Attempt to get NVCC version
        result = subprocess.run(['nvcc', '--version'], capture_output=True)
        if result.returncode != 0:
            raise Exception('Unable to run `nvcc --version`')
        _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0]
    return _NVCC_VERSION

_CUDA_INSTALL_PATH = None
def cuda_install_path():
    """
    Helper method for on-demand fetching of the CUDA installation path. This allows
    the import of CUTLASS to proceed even if NVCC is not available, preferring to
    raise this error only when an operation that needs NVCC is being performed.
    """
    global _CUDA_INSTALL_PATH
    if _CUDA_INSTALL_PATH is None:
        _CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc())
    return _CUDA_INSTALL_PATH

CACHE_FILE = "compiled_cache.db"

from cutlass_library import (
    DataType,
    EpilogueScheduleType,
    KernelScheduleType,
    MathOperation,
    LayoutType,
    OpcodeClass,
    TileDescription,
    TileSchedulerType,
)

this = sys.modules[__name__]
this.logger = logging.getLogger(__name__)

# RMM is only supported for Python 3.9+
if (sys.version_info.major == 3 and sys.version_info.minor > 8) or sys.version_info.major > 3:
    try:
        import rmm
        this.use_rmm = True
    except ImportError:
        this.use_rmm = False
else:
    this.use_rmm = False


def set_log_level(level: int):
    """
    Sets the log level

    :param level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options
    :type level: int
    """
    this.logger.setLevel(level)

set_log_level(logging.ERROR)

from cutlass_cppgen.library_defaults import OptionRegistry
from cutlass_cppgen.backend.utils.device import device_cc

this._option_registry = None
def get_option_registry():
    """
    Helper method for on-demand initialization of the options registry. This avoids building
    the registry when CUTLASS is imported.
    """
    if this._option_registry is None:
        this.logger.info("Initializing option registry")
        this._option_registry = OptionRegistry(device_cc())
    return this._option_registry

this.__version__ = '4.2.1'

from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.utils.lazy_import import lazy_import


this.memory_pool = None
def get_memory_pool():
    """
    Helper method for on-demand memory pool. This avoids allocating the memory pool unnecessarily
    when CUTLASS is imported.
    """
    if this.use_rmm and this.memory_pool is None:
        this.memory_pool = create_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32)
    return this.memory_pool


base_cuda = lazy_import("cuda")
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")

this._device_id = None
this._nvcc_version = None

def check_cuda_versions():
    # Strip any additional information from the CUDA version
    _cuda_version = base_cuda.__version__.split("rc")[0]
    # Check that Python CUDA version exceeds NVCC version
    this._nvcc_version = nvcc_version()
    _cuda_list = _cuda_version.split('.')
    _nvcc_list = this._nvcc_version.split('.')
    for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
        if int(val_cuda) < int(val_nvcc):
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")

    if len(_nvcc_list) > len(_cuda_list):
        if len(_nvcc_list) != len(_cuda_list) + 1:
            raise Exception(f"Malformed NVCC version of {this._nvcc_version}")
        if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")

def initialize_cuda_context():
    check_cuda_versions()

    if this._device_id is not None:
        return

    if this.use_rmm:
        # This also covers initializing the CUDA context
        get_memory_pool()

    device_id = os.getenv("CUTLASS_CUDA_DEVICE_ID")
    if device_id is None:
        if not this.use_rmm:
            # Manually call cuInit() and create context by making a runtime API call
            err, = cudart.cudaFree(0)
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError(f"cudaFree failed with error {err}")

        err, device_count = cuda.cuDeviceGetCount()
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise Exception(f"cuDeviceGetCount failed with error {err}")
        if device_count <= 0:
            raise Exception("No CUDA devices found")
        device_id = 0

    this._device_id = int(device_id)


def device_id() -> int:
    initialize_cuda_context()
    return this._device_id
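A minimal usage sketch (an editorial addition, not part of the commit): the module defers NVCC detection and CUDA-context creation until first use, so a bare import stays cheap.

import logging
import cutlass_cppgen

cutlass_cppgen.set_log_level(logging.INFO)  # raise verbosity from the ERROR default
print(cutlass_cppgen.device_id())           # first call runs initialize_cuda_context()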
python/cutlass_cppgen/backend/__init__.py (new file, 48 lines)
@@ -0,0 +1,48 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from cutlass_cppgen.backend.arguments import *
from cutlass_cppgen.backend.c_types import *
from cutlass_cppgen.backend.compiler import ArtifactManager
from cutlass_cppgen.backend.conv2d_operation import *
from cutlass_cppgen.backend.epilogue import *
from cutlass_cppgen.backend.frontend import *
from cutlass_cppgen.backend.gemm_operation import *
from cutlass_cppgen.backend.library import *
from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool
from cutlass_cppgen.backend.operation import *
from cutlass_cppgen.backend.reduction_operation import *
from cutlass_cppgen.backend.type_hint import *
from cutlass_cppgen.backend.utils import *
from cutlass_cppgen.backend.utils.device import device_cc

compiler = ArtifactManager()
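The package exposes a single shared ArtifactManager. A hedged sketch of reaching it (the operation list `ops` is a hypothetical input built elsewhere):

from cutlass_cppgen.backend import compiler  # module-level ArtifactManager singleton
# compiler.add_module(ops)  # compiles and caches the kernels for `ops`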
python/cutlass_cppgen/backend/arguments.py (new file, 136 lines)
@@ -0,0 +1,136 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from math import prod
from typing import Union

from cutlass_cppgen.utils.lazy_import import lazy_import

cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np

import cutlass_cppgen
from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor


class ArgumentBase:
    """
    Base class for operation arguments
    """

    def __init__(
        self,
        A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        **kwargs,
    ) -> None:
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        self.bias = kwargs.get("bias", False)

        self.stream = kwargs.get("stream", cuda.CUstream(0))

        # RMM buffers used to track tensor lifetime
        self.buffers = {}
        # Host tensor to copy the computed result back
        self.host_tensors = {}

        self.ptr_A = self.tensor_to_ptr(A, "A")
        self.ptr_B = self.tensor_to_ptr(B, "B")
        self.ptr_C = self.tensor_to_ptr(C, "C")
        self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True)
        if C is not None:
            if not isinstance(C, cuda.CUdeviceptr):
                self.tensor_c_numel = prod(C.shape)

    def tensor_to_ptr(self, tensor, name, is_output=False):
        """
        Convert and remember the input tensor to cuda.CUdeviceptr used by cuda python
        For numpy.ndarray, it also remembers the host buffer for synchronization
        """
        if tensor is None:
            return cuda.CUdeviceptr(0)
        if is_numpy_tensor(tensor):
            if is_output:
                assert name
            self.buffers[name] = NumpyFrontend.argument(tensor, is_output)
            if is_output:
                self.host_tensors[name] = tensor
            return self.buffers[name].ptr
        elif is_torch_tensor(tensor):
            return TorchFrontend.argument(tensor)
        elif isinstance(tensor, cuda.CUdeviceptr):
            return tensor
        elif is_cupy_tensor(tensor):
            return CupyFrontend.argument(tensor)
        else:
            raise TypeError("Unsupported frontend. Only NumPy, PyTorch, CuPy, and raw CUdeviceptr inputs are supported.")

    def sync(self, stream_sync=True):
        if stream_sync:
            (err,) = cudart.cudaDeviceSynchronize()
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

        for key in self.host_tensors.keys():
            host_tensor = self.host_tensors[key]
            (err,) = cuda.cuMemcpyDtoH(
                host_tensor,
                self.buffers[key].ptr,
                host_tensor.size * host_tensor.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

        self.free()

    def free(self):
        """
        Frees allocated device-side memory
        """
        # Free any device memory allocated manually
        if not cutlass_cppgen.use_rmm:
            for name, buf in self.buffers.items():
                if isinstance(buf, DevicePtrWrapper):
                    err, = cudart.cudaFree(buf.ptr)
                    if err != cudart.cudaError_t.cudaSuccess:
                        raise RuntimeError(f"cudaFree failed with error {err}")

            if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper):
                err, = cudart.cudaFree(self.workspace_buffer.ptr)
                if err != cudart.cudaError_t.cudaSuccess:
                    raise RuntimeError(f"cudaFree failed with error {err}")
                del self.workspace_buffer
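A hedged sketch of the frontend dispatch above (`SomeArguments` is a hypothetical ArgumentBase subclass): NumPy inputs are copied to fresh device buffers and, for outputs, remembered for the device-to-host copy in sync(); Torch/CuPy tensors and raw CUdeviceptr values pass through as device pointers.

import numpy as np
A = np.zeros((128, 64), dtype=np.float32)
# args = SomeArguments(A=A, B=B, C=C, D=D)  # hypothetical subclass
# args.sync()                               # synchronizes, copies D back, frees buffers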
python/cutlass_cppgen/backend/c_types.py (new file, 625 lines)
@@ -0,0 +1,625 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import ctypes

from cutlass_library import (
    DataType,
    KernelScheduleType,
    TileSchedulerType
)
from cutlass_cppgen.backend.library import DataTypeSizeBytes


class GemmCoord_(ctypes.Structure):
    _fields_ = [
        ("m", ctypes.c_int),
        ("n", ctypes.c_int),
        ("k", ctypes.c_int)
    ]

    def __init__(self, m, n, k) -> None:
        self.m = m
        self.n = n
        self.k = k


class GemmCoordBatched_(ctypes.Structure):
    """
    Wrapper around a GemmCoord that also contains batch count. This is used for encoding
    batched GEMM inputs to CUTLASS 3 GEMMs.
    """

    _fields_ = [
        ("m", ctypes.c_int),
        ("n", ctypes.c_int),
        ("k", ctypes.c_int),
        ("batch_count", ctypes.c_int)
    ]

    def __init__(self, gemm_coord, batch_count) -> None:
        self.m = gemm_coord.m
        self.n = gemm_coord.n
        self.k = gemm_coord.k
        self.batch_count = batch_count


class MatrixCoord_(ctypes.Structure):
    _fields_ = [
        ("row", ctypes.c_int),
        ("column", ctypes.c_int)
    ]


class dim3_(ctypes.Structure):
    _fields_ = [
        ("x", ctypes.c_int),
        ("y", ctypes.c_int),
        ("z", ctypes.c_int)
    ]


class StrideBatched_(ctypes.Structure):
    """
    CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The
    variable dimensions represent the stride along the non-unit-stride dimension of the row/column-major
    layout, and the batch stride. This structure encodes the two variable dimensions.
    """
    _fields_ = [
        ("major_stride", ctypes.c_int64),
        ("batch_stride", ctypes.c_int64)
    ]
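
# Editorial sketch (not part of the original file): how StrideBatched_ is
# typically populated for a row-major M x K operand repeated over a batch.
# The unit-stride (contiguous) dimension is implicit and not stored.
def _example_stride_batched(M: int = 128, K: int = 64) -> StrideBatched_:
    return StrideBatched_(major_stride=K, batch_stride=M * K)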


class GenericMainloopArguments3x_(ctypes.Structure):
    """
    Structure representing the superset of possible mainloop arguments.
    This structure should not be passed to kernels directly, but, rather,
    be used as an input to one of the more specific schedule arguments, which
    will each select those arguments relevant to the particular schedule.
    """
    _fields_ = [
        ("ptr_A", ctypes.c_void_p),
        ("stride_A", StrideBatched_),
        ("ptr_B", ctypes.c_void_p),
        ("stride_B", StrideBatched_),
        ("mma_promotion_interval", ctypes.c_int)
    ]


class _PersistentTileSchedulerArguments(ctypes.Structure):
    _fields_ = [
        ("max_swizzle_size", ctypes.c_int),
        ("raster_order_option", ctypes.c_int),
    ]


class _PersistentTileSchedulerStreamKArguments(ctypes.Structure):
    _fields_ = [
        ("splits", ctypes.c_int),
        ("max_swizzle_size", ctypes.c_int),
        ("raster_order_option", ctypes.c_int),
        ("reduction_mode", ctypes.c_int),
        ("decomposition_mode", ctypes.c_int),
    ]


def get_tile_scheduler_arguments_3x(
        tile_scheduler: TileSchedulerType,
        splits: int = 1):
    max_swizzle_size = 1
    raster_order_option = 0  # Heuristic
    if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]:
        return _PersistentTileSchedulerArguments(
            max_swizzle_size,
            raster_order_option,
        )
    elif tile_scheduler == TileSchedulerType.StreamK:
        reduction_mode = 0  # Deterministic
        decomposition_mode = 0  # Heuristic
        return _PersistentTileSchedulerStreamKArguments(
            splits,
            max_swizzle_size,
            raster_order_option,
            reduction_mode,
            decomposition_mode,
        )
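
# Editorial sketch (not part of the original file): requesting scheduler
# arguments for a stream-K launch with two splits; for Default/Persistent
# schedulers the `splits` argument is ignored.
def _example_tile_scheduler_args():
    return get_tile_scheduler_arguments_3x(TileSchedulerType.StreamK, splits=2)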


def get_mainloop_arguments_3x(
        kernel_schedule: KernelScheduleType,
        element_A,
        element_B,
        alignment_A: int,
        alignment_B: int) -> ctypes.Structure:
    """
    Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters.

    :param kernel_schedule: type of kernel schedule to be used in the mainloop
    :type kernel_schedule: cutlass_library.KernelScheduleType
    :param element_A: data type of operand A
    :param element_B: data type of operand B
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int

    :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters
    :rtype: ctypes.Structure
    """
    class _MainloopArgumentsTma(ctypes.Structure):
        _fields_ = [
            ("ptr_A", ctypes.c_void_p),
            ("stride_A", StrideBatched_),
            ("ptr_B", ctypes.c_void_p),
            ("stride_B", StrideBatched_),
            ("mma_promotion_interval", ctypes.c_int)
        ]

        @staticmethod
        def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
            return _MainloopArgumentsTma(
                args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
                args.mma_promotion_interval
            )

    class _MainloopArgumentsMultistage(ctypes.Structure):
        _fields_ = [
            ("ptr_A", ctypes.c_void_p),
            ("stride_A", StrideBatched_),
            ("ptr_B", ctypes.c_void_p),
            ("stride_B", StrideBatched_),
        ]

        @staticmethod
        def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
            return _MainloopArgumentsMultistage(
                args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
            )

    # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure.
    # Should that become not the case, this is the place to return custom ctypes
    # structures based on selected kernel schedule.
    return _MainloopArgumentsTma


def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
    if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
        _EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
    else:
        _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    if hasattr(epilogue_functor, "visitor"):
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
                ("epilogue", _EpilogueOutputOpParams),
                ("arg_C", epilogue_functor.arg_c_type),
                ("arg_D", epilogue_functor.arg_d_type)
            ]

            def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None:
                self.epilogue = output_op
                self.arg_C = epilogue_functor.arg_c_type(ptr_c)
                self.arg_D = epilogue_functor.arg_d_type(ptr_d)
    else:
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
                ("epilogue", _EpilogueOutputOpParams),
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", StrideBatched_),
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", StrideBatched_),
            ]

    class _HardwareInfo(ctypes.Structure):
        _fields_ = [
            ("device_id", ctypes.c_int),
            ("sm_count", ctypes.c_int),
            ("max_active_clusters", ctypes.c_int),
            ("cluster_shape", dim3_),
            ("cluster_shape_fallback", dim3_),
        ]

    class _GemmArguments(ctypes.Structure):
        _fields_ = [
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoordBatched_),
            ("mainloop", mainloop_arguments),
            ("epilogue", _EpilogueArguments),
            ("hw_info", _HardwareInfo),
            ("scheduler", type(scheduler_args)),
        ]

    return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo


def get_gemm_arguments(epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _GemmArguments(ctypes.Structure):
        _fields_ = [
            # Arguments from UniversalArgumentsBase
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoord_),
            ("batch_count", ctypes.c_int),
            ("batch_stride_D", ctypes.c_longlong),
            # Remaining arguments
            ("epilogue", _EpilogueOutputOpParams),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("batch_stride_A", ctypes.c_longlong),
            ("batch_stride_B", ctypes.c_longlong),
            ("batch_stride_C", ctypes.c_longlong),
            ("stride_a", ctypes.c_longlong),
            ("stride_b", ctypes.c_longlong),
            ("stride_c", ctypes.c_longlong),
            ("stride_d", ctypes.c_longlong),
            ("lda", ctypes.c_longlong),
            ("ldb", ctypes.c_longlong),
            ("ldc", ctypes.c_longlong),
            ("ldd", ctypes.c_longlong),
            ("ptr_gather_A_indices", ctypes.c_void_p),
            ("ptr_gather_B_indices", ctypes.c_void_p),
            ("ptr_scatter_D_indices", ctypes.c_void_p)
        ]

    return _GemmArguments, _EpilogueOutputOpParams


def get_gemm_arguments_streamk(epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _GemmArguments(ctypes.Structure):
        _fields_ = [
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoord_),
            ("batch_count", ctypes.c_int),
            ("epilogue", _EpilogueOutputOpParams),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("batch_stride_A", ctypes.c_longlong),
            ("batch_stride_B", ctypes.c_longlong),
            ("batch_stride_C", ctypes.c_longlong),
            ("batch_stride_D", ctypes.c_longlong),
            ("stride_a", ctypes.c_longlong),
            ("stride_b", ctypes.c_longlong),
            ("stride_c", ctypes.c_longlong),
            ("stride_d", ctypes.c_longlong),
            ("lda", ctypes.c_longlong),
            ("ldb", ctypes.c_longlong),
            ("ldc", ctypes.c_longlong),
            ("ldd", ctypes.c_longlong),
            ("avail_sms", ctypes.c_int)
        ]

    return _GemmArguments, _EpilogueOutputOpParams
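
# Editorial sketch (not part of the original file): any object exposing an
# `epilogue_type` ctypes structure can parameterize get_gemm_arguments; real
# functors live in cutlass_cppgen.backend.epilogue. `_FakeEpilogue` is a stand-in.
def _example_gemm_arguments_size() -> int:
    class _FakeEpilogue:
        class epilogue_type(ctypes.Structure):
            _fields_ = [("alpha", ctypes.c_float), ("beta", ctypes.c_float)]

    GemmArgs, _ = get_gemm_arguments(_FakeEpilogue)
    return ctypes.sizeof(GemmArgs)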


###########################################################################################
# GEMM Grouped
###########################################################################################


def get_gemm_grouped_arguments(epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _GEMMGroupedArguments(ctypes.Structure):
        _fields_ = [
            ("problem_sizes", ctypes.c_void_p),
            ("problem_count", ctypes.c_int),
            ("threadblock_count", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("lda", ctypes.c_void_p),
            ("ldb", ctypes.c_void_p),
            ("ldc", ctypes.c_void_p),
            ("ldd", ctypes.c_void_p),
            ("host_problem_sizes", ctypes.c_void_p)
        ]

    return _GEMMGroupedArguments, _EpilogueOutputOpParams


############################################################################################
# Convolution2D
############################################################################################


class Conv2DProblemSize_(ctypes.Structure):
    _fields_ = [
        ("N", ctypes.c_int),
        ("H", ctypes.c_int),
        ("W", ctypes.c_int),
        ("C", ctypes.c_int),
        ("P", ctypes.c_int),
        ("Q", ctypes.c_int),
        ("K", ctypes.c_int),
        ("R", ctypes.c_int),
        ("S", ctypes.c_int),
        ("pad_h", ctypes.c_int),
        ("pad_w", ctypes.c_int),
        ("stride_h", ctypes.c_int),
        ("stride_w", ctypes.c_int),
        ("dilation_h", ctypes.c_int),
        ("dilation_w", ctypes.c_int),
        ("mode", ctypes.c_int),  # kCrossCorrelation: 0, kConvolution: 1
        ("split_k_slices", ctypes.c_int),
        ("groups", ctypes.c_int)
    ]

    def __init__(self, problem_size) -> None:
        for field_name, _ in self._fields_:
            setattr(self, field_name, getattr(problem_size, field_name))


class Layout4D(ctypes.Structure):
    _fields_ = [("stride", ctypes.c_int * 3)]

    def __init__(self, tensor_ref):
        stride = tensor_ref.stride()
        setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2)))


class TensorRef_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("layout", Layout4D)
    ]

    def __init__(self, tensor_ref):
        setattr(self, "ptr", tensor_ref.data())
        setattr(self, "layout", Layout4D(tensor_ref.layout()))


class TensorRef2D_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("stride", ctypes.c_int)
    ]


def get_conv2d_arguments(epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _Conv2dArguments(ctypes.Structure):
        _fields_ = [
            ("conv_kind", ctypes.c_int),
            ("problem_size", Conv2DProblemSize_),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("tensor_C_numel", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("split_k_mode", ctypes.c_int)
        ]

    return _Conv2dArguments, _EpilogueOutputOpParams
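
# Editorial sketch (not part of the original file): Conv2DProblemSize_ copies
# fields by name, so any object exposing the same attributes (here a
# SimpleNamespace standing in for a real problem-size object) can seed it.
def _example_conv2d_problem_size() -> Conv2DProblemSize_:
    from types import SimpleNamespace
    ps = SimpleNamespace(N=1, H=8, W=8, C=16, P=8, Q=8, K=32, R=3, S=3,
                         pad_h=1, pad_w=1, stride_h=1, stride_w=1,
                         dilation_h=1, dilation_w=1, mode=0,
                         split_k_slices=1, groups=1)
    return Conv2DProblemSize_(ps)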


############################################################################################
# Reduction
############################################################################################


def get_reduction_params(epilogue_functor):
    _EpilogueOutputParams = epilogue_functor.epilogue_type

    class _ReductionParams(ctypes.Structure):
        _fields_ = [
            ("problem_size", MatrixCoord_),
            ("partitions", ctypes.c_int),
            ("partition_stride", ctypes.c_longlong),
            ("workspace", TensorRef2D_),
            ("destination", TensorRef2D_),
            ("source", TensorRef2D_),
            ("output_op", _EpilogueOutputParams),
        ]

    return _ReductionParams, _EpilogueOutputParams


###########################################################################################
# Epilogue Visitor Type Factory
###########################################################################################

class Empty(ctypes.Structure):
    _fields_ = []

    def __init__(self, *arg) -> None:
        pass

class EmptyByte(ctypes.Structure):
    _fields_ = [
        ("byte", ctypes.c_byte)
    ]

    def __init__(self, *arg) -> None:
        pass

class EBO:
    def __init__(self, index: int, type) -> None:
        self.index = index
        self.type = type

    def __eq__(self, other) -> bool:
        if isinstance(other, EBO):
            return self.index == other.index and self.type == other.type
        return False

    def __hash__(self) -> int:
        return hash((self.index, self.type))

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self) -> str:
        return f"<{self.index}, {self.type}>"


def tuple_factory_(input_tuple, dtype, constants=[0,1]):
    """
    Factory function that generates a ctypes structure representing a cute::Tuple
    from the input tuple

    :param input_tuple: the input tuple
    :type input_tuple: tuple
    :param dtype: the data type for non-constant values
    :type dtype: str, "int32_t", "int", "int64_t"
    :param constants: the values that will be treated as constants
    :type constants: list[int]

    :return: ctype structure representing the cute::Tuple
    :return: the empty base classes of the tuple
    """

    # The empty base classes of the current tuple
    empty_bases = []
    # The first non-empty base class
    first_non_empty_base = None
    # The ctype fields of the current tuple
    ctype_fields = []

    for idx, entry in enumerate(input_tuple):
        # For nested tuples
        if isinstance(entry, tuple):
            sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
            if ctypes.sizeof(sub_tuple_ctype) == 0:
                # The empty tuple base class is also an empty EBO
                empty_bases.append(EBO(idx, entry))
            else:
                if first_non_empty_base is None:
                    first_non_empty_base = sub_empty_bases
                ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
        else:
            if entry in constants:
                empty_bases.append(EBO(idx, entry))
                ctype_fields.append((f"entry_{idx}", Empty))
            else:
                ctype_fields.append((f"entry_{idx}", dtype))
                if first_non_empty_base is None:
                    first_non_empty_base = []

    # Create the ctype tuple
    class TupleType(ctypes.Structure):
        _fields_ = ctype_fields

        def __init__(self, args) -> None:
            fields = self._fields_

            assert len(fields) == len(args)
            for field, arg in zip(fields, args):
                name = field[0]
                field_type = field[1]
                setattr(self, name, field_type(arg))

    return TupleType, empty_bases

def tuple_factory(input_tuple, dtype: str, constants=[0,1]):
    """
    Factory function that generates a ctypes structure representing a cute::Tuple
    from the input tuple

    :param input_tuple: the input tuple
    :type input_tuple: tuple
    :param dtype: the data type for non-constant values
    :type dtype: str, "int32_t", "int", "int64_t"
    :param constants: the values that will be treated as constants
    :type constants: list[int]

    :return: ctype structure representing the cute::Tuple
    :return: the empty base classes of the tuple
    """
    # Step 1: convert the dtype
    if dtype == "int64_t":
        dtype = ctypes.c_longlong
    elif dtype in ["int", "int32_t"]:
        dtype = ctypes.c_int32
    else:
        raise NotImplementedError(f"Type {dtype} is not supported")

    tuple_type, _ = tuple_factory_(input_tuple, dtype, constants)

    if ctypes.sizeof(tuple_type) == 0:
        return EmptyByte
    return tuple_type
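
# Editorial sketch (not part of the original file): constants (0 and 1 by
# default) are folded into zero-size placeholder fields, so only the dynamic
# extents occupy storage. Here the nested tuple (1, (16, 0)) collapses to a
# single int64 field, i.e. the struct size is 8 bytes.
def _example_tuple_factory() -> int:
    StrideTuple = tuple_factory((1, (16, 0)), "int64_t")
    return ctypes.sizeof(StrideTuple)  # 8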


def visitor_factory(node_types, node_names):
    """
    Creates the argument type of the epilogue visitor

    :param node_types: list of argument types as ctypes
    :param node_names: list of argument names as str

    :return: tuple type as a ctypes.Structure
    """
    ctypes_field = []
    # A struct is used when the number of nodes is <= 4,
    # because Sm90VisitorImplBase has specializations for up to 4 nodes
    # in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
    if len(node_types) <= 4:
        for idx, node_type in enumerate(node_types):
            if ctypes.sizeof(node_type) == 0:
                # Special case for empty struct
                # 1 byte placeholder is used for correct alignment
                ctypes_field.append((node_names[idx], ctypes.c_byte))
            else:
                ctypes_field.append((node_names[idx], node_type))

        class VisitorType(ctypes.Structure):
            _fields_ = ctypes_field

            def __init__(self, kwargs) -> None:
                for field in self._fields_:
                    fname, ftype = field
                    if ftype != ctypes.c_byte:
                        setattr(self, fname, ftype(kwargs))

    # For cases with more than 4 nodes, a tuple is used
    else:
        for idx, node_type in enumerate(node_types):
            ctypes_field.append((node_names[idx], node_type))

        class VisitorType(ctypes.Structure):
            _fields_ = ctypes_field

            def __init__(self, kwargs) -> None:
                for field in self._fields_:
                    fname, ftype = field
                    setattr(self, fname, ftype(kwargs))

    return VisitorType
python/cutlass_cppgen/backend/compiler.py (new file, 462 lines)
@@ -0,0 +1,462 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import ctypes
import json
import os
import sqlite3
import subprocess
import tempfile

from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
nvrtc = lazy_import("cuda.nvrtc")
from cutlass_library import SubstituteTemplate

import cutlass_cppgen
from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.backend.utils.device import device_cc

IncludeTemplate = r"""#include "${include}"
"""


def compile_with_nvcc(cmd, source, error_file):
    succeed = True
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        error_message = e.output.decode()
        with open(error_file, "w") as error_out:
            error_log = "Compilation error for the following kernel: \n"
            error_log += source
            error_log += "\nError Message:\n"
            error_log += error_message
            error_out.write(error_log)
        succeed = False
    if not succeed:
        # Print the error log to stdout if log level is set to warning or higher
        # verbosity. Otherwise, simply point to the error log file.
        logger.warning(error_log)
        raise Exception(f"Invalid Kernel. See '{error_file}' for details.")


class CompilationOptions:
    """
    Compilation options.
    """

    def __init__(self, flags, arch, include_paths=[]):
        self.includes = []
        self.include_paths = include_paths
        self.flags = flags
        self.arch = arch

    def get_str(self):
        opts = []
        for flag in self.flags:
            opts.append(flag)

        for incl in self.include_paths:
            opts.append(f"--include-path={incl}")

        arch_flag = f"-arch=sm_{self.arch}"
        if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12:
            arch_flag += "a"
        opts.append(arch_flag)

        return " ".join(opts)

    def get(self):
        options = []

        for flag in self.flags:
            options.append(bytes(str.encode(flag)))

        for incl in self.include_paths:
            options.append(bytes(str.encode(f" --include-path={incl}")))

        arch_flag = f" -arch=sm_{self.arch}"
        if self.arch in [90, 100, 101, 103, 120, 121]:
            arch_flag += "a"

        options.append(bytes(str.encode(arch_flag)))

        return options
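
# Editorial sketch (not part of the original file): constructing options for an
# SM80 target; get_str() feeds the nvcc command line, while get() produces the
# byte-encoded list NVRTC expects.
def _example_compilation_options() -> str:
    opts = CompilationOptions(["-std=c++17"], 80, include_paths=["/usr/local/cuda/include"])
    return opts.get_str()  # "-std=c++17 --include-path=/usr/local/cuda/include -arch=sm_80"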


def convertToBinaryData(filename):
    with open(filename, "rb") as file:
        blobData = file.read()
    return blobData


def CDLLBin(host_binary):
    tempfile.tempdir = "./"
    temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True)
    with open(temp_so.name, "wb") as file:
        file.write(host_binary)
    host_lib = ctypes.CDLL(temp_so.name)
    return host_lib


class ArtifactManager:
    """
    Artifact manager
    """

    def __init__(self) -> None:
        connection = sqlite3.connect(CACHE_FILE)
        cursor = connection.cursor()
        # Create the table if it does not already exist
        sqlite_create_table_query = """
        CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE,
                                                       cubin BLOB NOT NULL,
                                                       hostbin BLOB NOT NULL,
                                                       op_name TEXT NOT NULL,
                                                       op_attrs TEXT NOT NULL)
        """
        cursor.execute(sqlite_create_table_query)
        connection.commit()
        cursor.close()

        self._nvrtc_compile_options = ["-std=c++17", "-default-device"]
        self._nvcc_compile_options = [
            "-std=c++17",
            "--expt-relaxed-constexpr",
            "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
        ]
        self.nvcc()
        self.compiled_cache_device = {}
        self.compiled_cache_host = {}

    def nvrtc(self):
        self.backend = "nvrtc"
        self.default_compile_options = self._nvrtc_compile_options

    def nvcc(self):
        self.backend = "nvcc"
        self.default_compile_options = self._nvcc_compile_options

    def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
        connection = sqlite3.connect(CACHE_FILE)
        cursor = connection.cursor()
        sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""

        hostbin = convertToBinaryData(hostfile)

        data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))

        cursor.execute(sqlite_insert_blob_query, data_tuple)
        connection.commit()
        cursor.close()

    def load_operation(self, op_key, extra_funcs):
        connection = sqlite3.connect(CACHE_FILE)
        cursor = connection.cursor()
        sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
        cursor.execute(sqlite_fetch_blob_query, (op_key,))
        record = cursor.fetchall()
        if len(record) == 0:
            return False
        for row in record:
            key, cubin_image, host_binary, operation_name, op_attr = row
            op_attr = json.loads(op_attr)
            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("Cuda Error: {}".format(err))

            err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
            self.compiled_cache_device[key] = kernel

            compiled_host_fns = {}
            host_lib = CDLLBin(host_binary)

            func_name = operation_name + "_get_params"
            func = getattr(host_lib, func_name)
            func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
            compiled_host_fns["get_args"] = func

            func_name = operation_name + "_shared_memory_size"
            func = getattr(host_lib, func_name)
            compiled_host_fns["shared_memory_capacity"] = func()

            for attr in op_attr:
                if isinstance(attr, str):
                    func_name = operation_name + "_" + attr
                    func = getattr(host_lib, func_name)

                    # Set the return type of the function
                    if attr in extra_funcs and extra_funcs[attr] is not None:
                        func.restype = extra_funcs[attr]

                    compiled_host_fns[attr] = func

            self.compiled_cache_host[key] = compiled_host_fns
        return True
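
    # Editorial sketch (not part of the original file): the on-disk cache is a
    # plain SQLite database, so cached kernels can be listed directly.
    @staticmethod
    def _example_list_cached_operations():
        connection = sqlite3.connect(CACHE_FILE)
        rows = connection.execute(
            "SELECT op_name, op_attrs FROM compiled_operations").fetchall()
        connection.close()
        return rows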
    def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
        """
        Compile a list of kernels and store them into database
        """
        source_buffer_device = ""
        source_buffer_host = ""
        # 1. include
        includes = []
        for operation in operation_list:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes
        for incl in includes:
            source_buffer_device += SubstituteTemplate(
                IncludeTemplate,
                {"include": incl},
            )

        for incl in includes_host:
            source_buffer_host += SubstituteTemplate(
                IncludeTemplate,
                {"include": incl},
            )

        # 2. Operations
        for operation in operation_list:
            source_buffer_device += operation.emit()
            source_buffer_host += operation.emit()
            values = {
                "operation_name": operation.name(),
                "operation_suffix": operation.emitter.operation_suffix,
            }
            source_buffer_device += SubstituteTemplate(
                operation.KernelTemplate,
                values,
            )
            source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)

        if self.backend == "nvrtc":
            # 3. compile
            err, program = nvrtc.nvrtcCreateProgram(
                str.encode(source_buffer_device),
                bytes(str.encode("module.cu")),
                0, [], [])

            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError("NVRTC Error: {}".format(err))

            # Compile program
            options = compilation_options.get()

            err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                error_string = "NVRTC Error: {}\n".format(err)

                # Get log from compilation
                err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError("NVRTC Error: {}".format(err))

                log = b" " * logSize
                err, = nvrtc.nvrtcGetProgramLog(program, log)
                if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                    raise RuntimeError("NVRTC Error: {}".format(err))

                raise RuntimeError(error_string + log.decode() + source_buffer_device)

            # Get data from compilation
            err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError("NVRTC Error: {}".format(err))

            cubin_image = b" " * dataSize
            (err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError("NVRTC Error: {}".format(err))

        else:  # with nvcc backend
            # emit code
            tempfile.tempdir = "./"
            temp_cu = tempfile.NamedTemporaryFile(
                prefix="kernel", suffix=".cu", delete=True)
            temp_cubin = tempfile.NamedTemporaryFile(
                prefix="kernel", suffix=".cubin", delete=True)
            with open(temp_cu.name, "w") as file:
                file.write(source_buffer_device)

            # compile with nvcc
            cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}"
            values = {
                "cuda_install_path": cuda_install_path(),
                "options": compilation_options.get_str(),
                "srcfile": temp_cu.name,
                "tarfile": temp_cubin.name,
            }
            cmd = SubstituteTemplate(cmd_template, values)
            compile_with_nvcc(cmd.split(" "), source_buffer_device, "./cutlass_python_compilation_device_error.txt")

            # load the cubin image
            with open(temp_cubin.name, "rb") as file:
                cubin_image = file.read()

        tempfile.tempdir = "./"
        temp_src = tempfile.NamedTemporaryFile(
            prefix="host_src", suffix=".cu", delete=True)

        # Write the host source
        with open(temp_src.name, "w") as outfile:
            outfile.write(source_buffer_host)

        temp_dst = tempfile.NamedTemporaryFile(
            prefix="host_func", suffix=".so", delete=True)

        # Set up host compilation arguments
        cmd = []
        cmd.append(f"{cuda_install_path()}/bin/nvcc")
        cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"])
        cmd.extend(host_compilation_options.get_str().split(" "))
        cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"])

        # Compile and load the library
        compile_with_nvcc(cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt")
        host_lib = ctypes.CDLL(temp_dst.name)

        return cubin_image, host_lib, temp_dst
|
||||
def add_module(self, operations, compile_options=None, bypass_cache=False):
|
||||
"""
|
||||
Insert a new compiled device module
|
||||
"""
|
||||
include_paths = [
|
||||
cuda_install_path() + "/include",
|
||||
CUTLASS_PATH + "/include",
|
||||
CUTLASS_PATH + "/tools/util/include",
|
||||
CUTLASS_PATH + "/python/cutlass/cpp/include",
|
||||
]
|
||||
|
||||
cutlass_cppgen.initialize_cuda_context()
|
||||
arch = device_cc()
|
||||
|
||||
host_compile_options = CompilationOptions(
|
||||
self._nvcc_compile_options, arch, include_paths)
|
||||
if compile_options is None:
|
||||
compile_options = CompilationOptions(
|
||||
self.default_compile_options, arch, include_paths)
|
||||
# save the cubin
|
||||
operation_key = []
|
||||
operation_list = []
|
||||
for operation in operations:
|
||||
# step 1: get kernel string as key
|
||||
key = operation.rt_module.emit() + operation.procedural_name() + self.backend
|
||||
# step 1: check if the operation is in cache
|
||||
compiled_kernel = self.compiled_cache_device.get(key)
|
||||
|
||||
if compiled_kernel is None and not bypass_cache:
|
||||
hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {}))
|
||||
if hit:
|
||||
compiled_kernel = self.compiled_cache_device.get(key)
|
||||
assert compiled_kernel is not None
|
||||
if compiled_kernel is not None:
|
||||
operation.rt_module.kernel = compiled_kernel
|
||||
compiled_host_fns = self.compiled_cache_host.get(key)
|
||||
assert compiled_host_fns is not None
|
||||
for key in compiled_host_fns.keys():
|
||||
setattr(operation.rt_module, key, compiled_host_fns[key])
|
||||
operation.rt_module.initialize()
|
||||
else:
|
||||
operation_list.append(operation.rt_module)
|
||||
operation_key.append(key)
|
||||
|
||||
if len(operation_list) > 0:
|
||||
cubin_image, host_lib, host_file = self.emit_compile_(
|
||||
operation_list, compile_options, host_compile_options)
|
||||
|
||||
err, module = cuda.cuModuleLoadData(cubin_image)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("Cuda Error: {}".format(err))
|
||||
|
||||
operation_name = []
|
||||
operation_attr = []
|
||||
for operation, key in zip(operation_list, operation_key):
|
||||
# get device kernels
|
||||
err, operation.kernel = cuda.cuModuleGetFunction(
|
||||
module,
|
||||
bytes(str.encode(operation.name()))
|
||||
)
|
||||
operation_name.append(operation.name())
|
||||
self.compiled_cache_device[key] = operation.kernel
|
||||
# get host functions
|
||||
compiled_host_fns = {}
|
||||
op_attr = []
|
||||
|
||||
# get param size
|
||||
func_name = operation.name() + "_get_param_size"
|
||||
func = getattr(host_lib, func_name)
|
||||
param_size = func()
|
||||
|
||||
func_name = operation.name() + "_get_params"
|
||||
func = getattr(host_lib, func_name)
|
||||
func.argtype = operation.argtype
|
||||
func.restype = ctypes.POINTER(ctypes.c_char * param_size)
|
||||
setattr(operation, "get_args", func)
|
||||
compiled_host_fns["get_args"] = func
|
||||
|
||||
# set shared memory size
|
||||
func_name = operation.name() + "_shared_memory_size"
|
||||
func = getattr(host_lib, func_name)
|
||||
setattr(operation, "shared_memory_capacity", func())
|
||||
compiled_host_fns["shared_memory_capacity"] = func()
|
||||
# set the maximum dynamic shared size
|
||||
operation.initialize()
|
||||
|
||||
# get extra functions
|
||||
op_attr.append(param_size)
|
||||
|
||||
if hasattr(operation, "extra_funcs"):
|
||||
for suffix, ret_type in operation.extra_funcs.items():
|
||||
func_name = operation.name() + "_" + suffix
|
||||
func = getattr(host_lib, func_name)
|
||||
if ret_type is not None:
|
||||
func.restype = ret_type
|
||||
setattr(operation, suffix, func)
|
||||
compiled_host_fns[suffix] = func
|
||||
op_attr.append(suffix)
|
||||
|
||||
operation_attr.append(op_attr)
|
||||
self.compiled_cache_host[key] = compiled_host_fns
|
||||
|
||||
for key, name, attrs in zip(operation_key, operation_name, operation_attr):
self.insert_operation(
key, cubin_image, host_file.name, name, attrs)
|
||||
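# Illustrative sketch (not part of this diff): add_module keys both caches on
# (emitted kernel source + procedural name + backend). The same scheme reduces
# to a small memoizer; the names below are hypothetical.
class _KernelCacheSketch:
    def __init__(self):
        self.device = {}  # key -> device kernel handle
        self.host = {}    # key -> dict of host-side ctypes functions

    def get_or_compile(self, key, compile_fn):
        # compile_fn runs only on a cache miss, as in add_module above.
        if key not in self.device:
            self.device[key], self.host[key] = compile_fn()
        return self.device[key], self.host[key]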
700
python/cutlass_cppgen/backend/conv2d_operation.py
Normal file
@ -0,0 +1,700 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
from typing import Union
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import SubstituteTemplate
|
||||
import numpy as np
|
||||
|
||||
from cutlass_library import (
|
||||
ConvKindNames,
|
||||
ConvKindTag,
|
||||
DataTypeNames,
|
||||
DataTypeSize,
|
||||
DataTypeTag,
|
||||
IteratorAlgorithmNames,
|
||||
IteratorAlgorithmTag,
|
||||
LayoutTag,
|
||||
LayoutType,
|
||||
MathOperation,
|
||||
MathOperationTag,
|
||||
OpcodeClass,
|
||||
OpcodeClassNames,
|
||||
OpcodeClassTag,
|
||||
OperationKind,
|
||||
ShortDataTypeNames,
|
||||
ShortLayoutTypeNames,
|
||||
SplitKMode,
|
||||
StrideSupport,
|
||||
StrideSupportTag,
|
||||
SwizzlingFunctor,
|
||||
SwizzlingFunctorTag,
|
||||
get_complex_from_real,
|
||||
)
|
||||
|
||||
from cutlass_cppgen.backend.arguments import ArgumentBase
|
||||
from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments
|
||||
from cutlass_cppgen.backend.library import (
|
||||
EmissionType,
|
||||
TensorDescription,
|
||||
TileDescription,
|
||||
)
|
||||
from cutlass_cppgen.backend.memory_manager import device_mem_alloc
|
||||
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass_cppgen.backend.utils.device import to_device_ptr
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
|
||||
|
||||
class Conv2dArguments(ArgumentBase):
|
||||
"""
|
||||
Argument wrapper for Conv2d. It encodes problem information and
user-provided tensors into the kernel's arguments.
|
||||
|
||||
:param operation: the Conv2d operation to take the argument
|
||||
:type operation: :class:`cutlass_cppgen.backend.Conv2dOperation`
|
||||
:param problem_size: the Conv2d problem size
|
||||
:type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize`
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
:param B: tensor B
|
||||
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
:param C: tensor C
|
||||
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
:param D: tensor D
|
||||
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
:param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
|
||||
:type split_k_mode: cutlass_library.library.SplitKMode, optional
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
"""
|
||||
|
||||
def __init__(self, operation, problem_size, A, B, C, D,
|
||||
split_k_mode=SplitKMode.Serial, **kwargs, ) -> None:
|
||||
self.operation = operation
|
||||
self.conv_kind = operation.conv_kind
|
||||
self.layout_A = operation.A.layout
|
||||
self.layout_B = operation.B.layout
|
||||
self.layout_C = operation.C.layout
|
||||
|
||||
self.element_A = operation.A.element
|
||||
self.element_B = operation.B.element
|
||||
self.element_C = operation.C.element
|
||||
|
||||
if self.layout_C == LayoutType.TensorNC32HW32:
|
||||
raise Exception("Layout type TensorNC32HW32 is not currently supported")
|
||||
|
||||
super().__init__(A, B, C, D, **kwargs)
|
||||
|
||||
if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
|
||||
self.split_k_mode = split_k_mode
|
||||
self.split_k_slices = kwargs["split_k_slices"]
|
||||
else:
|
||||
self.split_k_mode = SplitKMode.Serial
|
||||
self.split_k_slices = 1
|
||||
|
||||
if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel:
|
||||
self.output_op = kwargs["output_op"]
|
||||
else:
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
self.problem_size = problem_size
|
||||
self.problem_size.split_k_slices = self.split_k_slices
|
||||
|
||||
self.initialize()
|
||||
|
||||
def get_arguments(self):
|
||||
tc_numel = -1
|
||||
if hasattr(self, "tensor_c_numel"):
|
||||
tc_numel = self.tensor_c_numel
|
||||
|
||||
self.c_arguments = self.operation.argument_type(
|
||||
int(self.conv_kind),
|
||||
self.problem_size.ctype,
|
||||
int(to_device_ptr(self.ptr_A)),
|
||||
int(to_device_ptr(self.ptr_B)),
|
||||
int(to_device_ptr(self.ptr_C)),
|
||||
int(to_device_ptr(self.ptr_D)),
|
||||
tc_numel,
|
||||
self.output_op,
|
||||
int(self.split_k_mode)
|
||||
)
|
||||
|
||||
def initialize(self):
|
||||
self.launch_config = self.operation.rt_module.plan(self)
|
||||
|
||||
self.get_arguments()
|
||||
|
||||
# Allocate and initialize device workspace
|
||||
device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments)
|
||||
if device_workspace_size > 0:
|
||||
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
||||
workspace_ptr = self.workspace_buffer.ptr
|
||||
err, = cuda.cuMemsetD32(
|
||||
workspace_ptr, 0, device_workspace_size // 4)
|
||||
else:
|
||||
workspace_ptr = None
|
||||
|
||||
self.semaphore = 0
|
||||
if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel:
|
||||
self.ptr_D = workspace_ptr
|
||||
# Reset arguments now that ptr_D has been updated
|
||||
self.get_arguments()
|
||||
elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial:
|
||||
self.semaphore = workspace_ptr
|
||||
|
||||
params_ = self.operation.rt_module.get_args(
|
||||
self.c_arguments, ctypes.c_void_p(int(self.semaphore)))
|
||||
self.host_workspace = bytearray(params_.contents)
|
||||
self.device_workspace = None
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
Synchronize the arguments. If the input tensor is in host,
|
||||
copy it from device to host.
|
||||
"""
|
||||
return super().sync()
|
||||
|
||||
|
||||
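# Illustrative usage sketch (not part of this diff). The NHWC shapes and the
# pre-built `operation` / `problem_size` objects below are assumptions;
# ArgumentBase takes care of moving numpy inputs to the device.
import numpy as np

def _make_conv2d_arguments_sketch(operation, problem_size):
    A = np.zeros((1, 8, 8, 16), dtype=np.float16)   # activations (N, H, W, C)
    B = np.zeros((16, 3, 3, 16), dtype=np.float16)  # filters (K, R, S, C)
    C = np.zeros((1, 8, 8, 16), dtype=np.float16)   # source (N, P, Q, K)
    D = np.zeros_like(C)                            # output
    return Conv2dArguments(
        operation=operation, problem_size=problem_size,
        A=A, B=B, C=C, D=D,
        output_op=operation.epilogue_type(1.0, 0.0))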
class Conv2dRT(ExecutableOperation):
|
||||
"""
|
||||
Conv2dRT manages the CUTLASS runtime components
|
||||
"""
|
||||
|
||||
KernelTemplate = r"""
|
||||
extern "C"
|
||||
__global__ void
|
||||
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
||||
|
||||
// Dynamic shared memory base pointer
|
||||
extern __shared__ int SharedStorageBase[];
|
||||
|
||||
// Declare pointer to dynamic shared memory.
|
||||
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
||||
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
||||
|
||||
${operation_name}${operation_suffix} op;
|
||||
|
||||
op(params, *shared_storage);
|
||||
}
|
||||
"""
|
||||
|
||||
HostTemplate = r"""
|
||||
extern "C" {
|
||||
// Get the size of params in bytes
|
||||
int ${operation_name}_get_param_size(){
|
||||
return sizeof(${operation_name}${operation_suffix}::Params);
|
||||
}
|
||||
|
||||
// Get the size of dynamic shared memory in bytes
|
||||
int ${operation_name}_shared_memory_size() {
|
||||
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
||||
}
|
||||
|
||||
using ElementA = typename ${operation_name}_base::ElementA;
|
||||
using ElementB = typename ${operation_name}_base::ElementB;
|
||||
using ElementC = typename ${operation_name}_base::ElementC;
|
||||
using LayoutA = typename ${operation_name}_base::LayoutA;
|
||||
using LayoutB = typename ${operation_name}_base::LayoutB;
|
||||
using LayoutC = typename ${operation_name}_base::LayoutC;
|
||||
using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp;
|
||||
|
||||
struct ${operation_name}_TemporaryArgs {
|
||||
int conv_kind;
|
||||
cutlass::conv::Conv2dProblemSize problem_size;
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
int tensor_c_numel;
|
||||
typename EpilogueOutputOp::Params epilogue_params;
|
||||
int split_k_mode;
|
||||
};
|
||||
|
||||
typename ${operation_name}${operation_suffix}::Arguments
|
||||
construct_arguments(${operation_name}_TemporaryArgs args) {
|
||||
cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
|
||||
auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
|
||||
auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
|
||||
auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
|
||||
auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
|
||||
|
||||
auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
|
||||
if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
|
||||
// C is interpreted as bias
|
||||
tc_C = {0, 0, 0, 0};
|
||||
}
|
||||
|
||||
cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
cutlass::TensorRef<ElementB, LayoutB> tref_B(args.ptr_B, LayoutB::packed(tc_B));
cutlass::TensorRef<ElementC, LayoutC> tref_C(args.ptr_C, LayoutC::packed(tc_C));
cutlass::TensorRef<ElementC, LayoutC> tref_D(args.ptr_D, LayoutC::packed(tc_D));
|
||||
|
||||
return {
|
||||
args.problem_size,
|
||||
tref_A,
|
||||
tref_B,
|
||||
tref_C,
|
||||
tref_D,
|
||||
args.epilogue_params,
|
||||
static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
|
||||
};
|
||||
}
|
||||
|
||||
// Get the params as byte array
|
||||
char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) {
|
||||
auto arguments = construct_arguments(args);
|
||||
typename ${operation_name}${operation_suffix}::Params* params;
|
||||
params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore);
|
||||
|
||||
char *bytes = ((char*)(params));
|
||||
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
|
||||
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i++)
|
||||
output[i] = bytes[i];
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
dim3 ${operation_name}_get_grid_shape(
|
||||
int conv_kind,
|
||||
cutlass::conv::Conv2dProblemSize problem_size,
|
||||
cutlass::gemm::GemmCoord tile_size,
|
||||
int split_k_slices
|
||||
) {
|
||||
|
||||
using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle;
|
||||
auto tiled_shape = Swizzle::get_tiled_shape(
|
||||
static_cast<cutlass::conv::Operator>(conv_kind),
|
||||
problem_size,
|
||||
tile_size,
|
||||
split_k_slices);
|
||||
|
||||
return Swizzle::get_grid_shape(tiled_shape);
|
||||
}
|
||||
|
||||
size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) {
|
||||
auto arguments = construct_arguments(args);
|
||||
|
||||
// Temporarily define a device-level Conv2d so that we can call get_workspace_size
|
||||
using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>;
|
||||
return DeviceConv::get_workspace_size(arguments);
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, operation: "Conv2dOperation"):
|
||||
super().__init__(operation)
|
||||
self.extra_funcs = {
|
||||
"get_grid_shape": dim3_,
|
||||
"get_workspace_size": ctypes.c_uint64
|
||||
}
|
||||
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
|
||||
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
|
||||
self.conv_kind = operation.conv_kind
|
||||
|
||||
self.operation: Conv2dOperation = operation
|
||||
|
||||
self.emitter = EmitConv2dInstance("_type")
|
||||
|
||||
self.threads = operation.tile_description.num_threads
|
||||
|
||||
self.swizzle_functor = operation.swizzling_functor
|
||||
|
||||
def emit(self):
|
||||
return self.emitter.emit(self.operation)
|
||||
|
||||
def plan(self, arguments: Conv2dArguments):
|
||||
tile_size = GemmCoord(
|
||||
self.operation.tile_description.threadblock_shape[0],
|
||||
self.operation.tile_description.threadblock_shape[1],
|
||||
self.operation.tile_description.threadblock_shape[2],
|
||||
)
|
||||
|
||||
grid = self.get_grid_shape(
|
||||
int(self.conv_kind),
|
||||
arguments.problem_size.ctype,
|
||||
tile_size.ctype,
|
||||
arguments.split_k_slices
|
||||
)
|
||||
|
||||
return LaunchConfiguration(
|
||||
[grid.x, grid.y, grid.z], [self.threads, 1, 1],
|
||||
self.shared_memory_capacity)
|
||||
|
||||
def initialize(self):
|
||||
err, = cuda.cuFuncSetAttribute(
|
||||
self.kernel,
|
||||
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
value=self.shared_memory_capacity)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"CUDA Error: {err}")
|
||||
|
||||
|
||||
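# Illustrative sketch (not part of this diff): extra_funcs above only declares
# the host-library entry points and their ctypes return types. Binding them is
# the same generic pattern the compiler applies after loading the shared
# library (compare add_module above):
import ctypes

def _bind_extra_funcs_sketch(host_lib, operation):
    for suffix, ret_type in operation.extra_funcs.items():
        func = getattr(host_lib, f"{operation.name()}_{suffix}")
        if ret_type is not None:
            func.restype = ret_type  # e.g. dim3_ for get_grid_shape
        setattr(operation, suffix, func)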
class Conv2dOperation:
|
||||
"""
|
||||
CUTLASS Conv2d operation description.
|
||||
|
||||
:param conv_kind: convolution operator
|
||||
:type conv_kind: :class:`cutlass_library.library.ConvKind`
|
||||
|
||||
:param iterator_algorithm: Selects among several implementation
|
||||
variants trading off performance with simplicity
|
||||
:type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm`
|
||||
|
||||
:param arch: GPU compute capability (sm_xx)
|
||||
:type arch: int
|
||||
|
||||
:param tile_description: tile description
|
||||
:type tile_description: :class:`cutlass_cppgen.backend.TileDescription`
|
||||
|
||||
:param A: tensor A description
|
||||
:type A: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param B: tensor B description
|
||||
:type B: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param C: tensor C description
|
||||
:type C: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param D: tensor D description
|
||||
:type D: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param element_epilogue: element type for computation in epilogue \
|
||||
:type element_epilogue: cutlass_library.library.DataType
|
||||
|
||||
:param stride_support: distinguish among partial specializations that \
|
||||
accelerate certain problems where convolution stride is unit \
|
||||
:type stride_support: :class:`cutlass_library.library.StrideSupport`
|
||||
|
||||
:param epilogue_functor: convolution epilogue functor
|
||||
:type epilogue_functor: :class:`EpilogueFunctor`
|
||||
|
||||
:param swizzling_functor: threadblock swizzling functor
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
conv_kind,
|
||||
iterator_algorithm,
|
||||
arch: int,
|
||||
tile_description: TileDescription,
|
||||
A: TensorDescription,
|
||||
B: TensorDescription,
|
||||
C: TensorDescription,
|
||||
stride_support,
|
||||
epilogue_functor,
|
||||
swizzling_functor=SwizzlingFunctor.Identity1,
|
||||
emission_type=EmissionType.Kernel,
|
||||
**kwargs
|
||||
):
|
||||
self.operation_kind: OperationKind = OperationKind.Conv2d
|
||||
self.arch: int = arch
|
||||
self.tile_description: TileDescription = tile_description
|
||||
self.conv_kind = conv_kind
|
||||
self.A: TensorDescription = A
|
||||
self.B: TensorDescription = B
|
||||
self.C: TensorDescription = C
|
||||
self.epilogue_functor = epilogue_functor
|
||||
self.iterator_algorithm = iterator_algorithm
|
||||
self.stride_support = stride_support
|
||||
self.swizzling_functor = swizzling_functor
|
||||
|
||||
self.emission_type = emission_type
|
||||
|
||||
self.rt_module: Conv2dRT = Conv2dRT(self)
|
||||
self.argument_type = self.rt_module.argument_type
|
||||
self.epilogue_type = self.rt_module.epilogue_type
|
||||
|
||||
def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
|
||||
"""
|
||||
Launch the cuda kernel with input arguments
|
||||
|
||||
:param arguments: conv2d arguments
|
||||
:type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments`
|
||||
"""
|
||||
|
||||
# launch the kernel
|
||||
err = self.rt_module.run(
|
||||
arguments.host_workspace,
|
||||
arguments.device_workspace,
|
||||
arguments.launch_config,
|
||||
arguments.stream
|
||||
)
|
||||
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"CUDA Error {err}")
|
||||
|
||||
return err
|
||||
|
||||
#
|
||||
# Get function name
|
||||
#
|
||||
|
||||
def procedural_name(self):
|
||||
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
||||
return self.configuration_name()
|
||||
|
||||
def configuration_name(self):
|
||||
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
||||
|
||||
opcode_class_name = OpcodeClassNames[
|
||||
self.tile_description.math_instruction.opcode_class
|
||||
]
|
||||
|
||||
threadblock = "%dx%d_%dx%d" % (
|
||||
self.tile_description.threadblock_shape[0],
|
||||
self.tile_description.threadblock_shape[1],
|
||||
self.tile_description.threadblock_shape[2],
|
||||
self.tile_description.stages,
|
||||
)
|
||||
|
||||
if self.stride_support == StrideSupport.Unity:
|
||||
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
|
||||
else:
|
||||
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
|
||||
|
||||
return SubstituteTemplate(
|
||||
configuration_name,
|
||||
{
|
||||
"arch": str(self.arch),
|
||||
"opcode_class": opcode_class_name,
|
||||
"extended_name": self.extended_name(),
|
||||
"threadblock": threadblock,
|
||||
"layout": self.layout_name(),
|
||||
"alignment": "%d" % self.A.alignment
|
||||
},
|
||||
)
|
||||
|
||||
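# Illustrative sketch (not part of this diff): SubstituteTemplate performs plain
# ${key} substitution, so a configuration name is assembled like this; all
# values below are made up for illustration.
from cutlass_library import SubstituteTemplate

name = SubstituteTemplate(
    "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
    {
        "arch": "80",
        "opcode_class": "tensorop",
        "extended_name": "f16_fprop",  # hypothetical extended name
        "threadblock": "128x128_32x3",
        "layout": "nhwc",
        "alignment": "8",
    },
)
# name == "cutlass_sm80_tensorop_f16_fprop_128x128_32x3_nhwc_align8"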
def extended_name(self):
|
||||
"""Append data types if they differ from compute type."""
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_${core_name}_${element_a}"
|
||||
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${core_name}_${element_a}"
|
||||
else:
|
||||
extended_name = "${core_name}"
|
||||
|
||||
extended_name = SubstituteTemplate(extended_name, {
|
||||
"element_a": DataTypeNames[self.A.element],
|
||||
"element_c": DataTypeNames[self.C.element],
|
||||
"core_name": self.core_name(),
|
||||
})
|
||||
|
||||
return extended_name
|
||||
|
||||
def layout_name(self):
|
||||
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
||||
|
||||
def core_name(self):
|
||||
"""The basic operation kind is prefixed with a letter indicating the accumulation type."""
|
||||
|
||||
intermediate_type = ""
|
||||
|
||||
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
|
||||
inst_shape = "%dx%dx%d" % tuple(
|
||||
self.tile_description.math_instruction.instruction_shape)
|
||||
if self.tile_description.math_instruction.element_a != self.A.element and \
|
||||
self.tile_description.math_instruction.element_a != self.accumulator_type():
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
else:
|
||||
inst_shape = ""
|
||||
|
||||
return "%s%s%s%s_%s" % (
|
||||
ShortDataTypeNames[self.accumulator_type()],
|
||||
inst_shape,
|
||||
intermediate_type,
|
||||
ConvKindNames[self.conv_kind],
|
||||
IteratorAlgorithmNames[self.iterator_algorithm]
|
||||
)
|
||||
|
||||
def is_complex(self):
|
||||
complex_operators = [
|
||||
MathOperation.multiply_add_complex,
|
||||
MathOperation.multiply_add_complex_gaussian,
|
||||
]
|
||||
return self.tile_description.math_instruction.math_operation in complex_operators
|
||||
|
||||
def accumulator_type(self):
|
||||
accum = self.tile_description.math_instruction.element_accumulator
|
||||
|
||||
if self.is_complex():
|
||||
return get_complex_from_real(accum)
|
||||
|
||||
return accum
|
||||
|
||||
def device_op(self):
|
||||
"""
|
||||
Returns a new Conv2dOperation object that is constructed with emission type
|
||||
``EmissionType.Device``.
|
||||
|
||||
:return: operation ready for device-level code emission
|
||||
:rtype: Conv2dOperation
|
||||
"""
|
||||
return Conv2dOperation(
|
||||
self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description,
|
||||
self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor,
|
||||
emission_type=EmissionType.Device)
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
|
||||
class EmitConv2dInstance:
|
||||
def __init__(self, operation_suffix=""):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = [
|
||||
"cutlass/cutlass.h",
|
||||
"cutlass/conv/kernel/default_conv2d_fprop.h",
|
||||
"cutlass/conv/kernel/default_conv2d_dgrad.h",
|
||||
"cutlass/conv/kernel/default_conv2d_wgrad.h",
|
||||
"cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
]
|
||||
self.template = """
|
||||
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
||||
${element_a},
|
||||
${layout_a},
|
||||
${element_b},
|
||||
${layout_b},
|
||||
${element_c},
|
||||
${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
${stride_support},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
|
||||
struct ${operation_name}${operation_suffix}:
|
||||
public ${operation_name}_base { };
|
||||
|
||||
"""
|
||||
|
||||
self.template_device = """
|
||||
// Conv2d operation ${operation_name}
|
||||
|
||||
using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
||||
${element_a},
|
||||
${layout_a},
|
||||
${element_b},
|
||||
${layout_b},
|
||||
${element_c},
|
||||
${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
${stride_support},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
|
||||
using DeviceKernel =
|
||||
typename cutlass::conv::device::ImplicitGemmConvolution<Conv2d${conv_kind_name}Kernel>;
|
||||
"""
|
||||
|
||||
def emit(self, operation):
|
||||
warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
|
||||
operation.tile_description.warp_count[idx]) for idx in range(3)]
|
||||
|
||||
epilogue_vector_length = int(min(
|
||||
operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
values = {
|
||||
"operation_name": operation.procedural_name(),
|
||||
"operation_suffix": self.operation_suffix,
|
||||
"conv_kind": ConvKindTag[operation.conv_kind],
|
||||
"conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
|
||||
"element_a": DataTypeTag[operation.A.element],
|
||||
"layout_a": LayoutTag[operation.A.layout],
|
||||
"element_b": DataTypeTag[operation.B.element],
|
||||
"layout_b": LayoutTag[operation.B.layout],
|
||||
"element_c": DataTypeTag[operation.C.element],
|
||||
"layout_c": LayoutTag[operation.C.layout],
|
||||
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
||||
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
||||
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
|
||||
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
|
||||
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
|
||||
"warp_shape_m": str(warp_shape[0]),
|
||||
"warp_shape_n": str(warp_shape[1]),
|
||||
"warp_shape_k": str(warp_shape[2]),
|
||||
"instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
"instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
"epilogue_vector_length": str(epilogue_vector_length),
|
||||
"epilogue_functor": operation.epilogue_functor.emit(),
|
||||
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
|
||||
"stages": str(operation.tile_description.stages),
|
||||
"iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
|
||||
"iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
||||
"stride_support": StrideSupportTag[operation.stride_support],
|
||||
"math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
||||
"align_a": str(operation.A.alignment),
|
||||
"align_b": str(operation.B.alignment),
|
||||
}
|
||||
|
||||
if operation.emission_type == EmissionType.Kernel:
|
||||
conv2d_template = self.template
|
||||
else:
|
||||
conv2d_template = self.template_device
|
||||
|
||||
return SubstituteTemplate(conv2d_template, values)
|
||||
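# Illustrative usage sketch (not part of this diff): the two templates differ in
# whether they stop at the kernel type or also alias a device-level DeviceKernel.
# `operation` is assumed to be a constructed Conv2dOperation.
kernel_decl = EmitConv2dInstance("_type").emit(operation)      # struct ${name}_type : ${name}_base
device_decl = EmitConv2dInstance().emit(operation.device_op()) # using DeviceKernel = ...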
541
python/cutlass_cppgen/backend/epilogue.py
Normal file
@ -0,0 +1,541 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_library import SubstituteTemplate
|
||||
import numpy as np
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag
|
||||
from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory
|
||||
from cutlass_cppgen.backend.frontend import NumpyFrontend
|
||||
from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
|
||||
dtype2ctype = {
|
||||
DataType.f16: ctypes.c_uint16,
|
||||
DataType.bf16: ctypes.c_uint16,
|
||||
DataType.f32: ctypes.c_float,
|
||||
DataType.f64: ctypes.c_double,
|
||||
DataType.s8: ctypes.c_int8,
|
||||
DataType.s32: ctypes.c_int32
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_scalar(value):
|
||||
"""
|
||||
Returns a scalar value from a container (e.g., np.ndarray)
|
||||
"""
|
||||
if is_numpy_tensor(value):
|
||||
if value.size != 1:
|
||||
raise Exception("Scalars used in epilogue must be of size 1")
|
||||
return value.reshape(-1)[0]
|
||||
elif is_torch_tensor(value):
|
||||
if value.numel() != 1:
|
||||
raise Exception("Scalars used in epilogue must be of size 1")
|
||||
return value.reshape(-1)[0]
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def to_ctype_value(value, dtype):
|
||||
"""
|
||||
Converts ``value`` to the corresponding storage needed for the ctype that
|
||||
will store ``value``.
|
||||
"""
|
||||
scalar = get_scalar(value)
|
||||
if dtype == DataType.f16:
|
||||
# Convert f16 value into an integer
|
||||
return int.from_bytes(np.float16(scalar).tobytes(), "little")
|
||||
else:
|
||||
return scalar
|
||||
|
||||
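# Illustrative sketch (not part of this diff): for f16 the scalar is stored via
# its raw 16-bit pattern, matching the ctypes.c_uint16 slot in dtype2ctype.
import numpy as np

bits = int.from_bytes(np.float16(1.0).tobytes(), "little")
assert bits == 0x3C00  # IEEE binary16 encoding of 1.0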
|
||||
#################################################################################################
|
||||
#
|
||||
# Epilogue Functors
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class EpilogueFunctorBase:
|
||||
"""
|
||||
Base class for thread-level epilogue functors
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def emit(self, tag, template_argument):
|
||||
template = """${tag}<${arguments}>"""
|
||||
arguments = ""
|
||||
for idx, arg in enumerate(template_argument):
|
||||
arguments += arg
|
||||
if idx < len(template_argument) - 1:
|
||||
arguments += ", "
|
||||
values = {
|
||||
"tag": tag,
|
||||
"arguments": arguments,
|
||||
}
|
||||
|
||||
return SubstituteTemplate(template, values)
|
||||
|
||||
|
||||
class LinearCombination(EpilogueFunctorBase):
|
||||
"""
|
||||
Apply a linear combination operator to an array of elements
|
||||
D = alpha * accumulator + beta * source
|
||||
|
||||
:param element_output: data type used to load and store tensors
|
||||
|
||||
:param epilogue_vector_length: number of elements computed per operation.
|
||||
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
|
||||
when there is not enough data to store
|
||||
|
||||
:param element_accumulator: Accumulator data type
|
||||
|
||||
:param element_epilogue: data type used to compute linear combination
|
||||
"""
|
||||
|
||||
tag = "cutlass::epilogue::thread::LinearCombination"
|
||||
|
||||
def __init__(
|
||||
self, element_output, epilogue_vector_length,
|
||||
element_accumulator=None, element_epilogue=None) -> None:
|
||||
super().__init__()
|
||||
|
||||
if element_accumulator is None:
|
||||
element_accumulator = element_output
|
||||
if element_epilogue is None:
|
||||
element_epilogue = element_output
|
||||
|
||||
self.element_output = element_output
|
||||
self.element_accumulator = element_accumulator
|
||||
self.element_epilogue = element_epilogue
|
||||
self.epilogue_vector_length = epilogue_vector_length
|
||||
|
||||
self.template_arguments = [
|
||||
DataTypeTag[element_output],
|
||||
str(epilogue_vector_length),
|
||||
DataTypeTag[element_accumulator],
|
||||
DataTypeTag[element_epilogue],
|
||||
]
|
||||
|
||||
c_element_epilogue = dtype2ctype[self.element_epilogue]
|
||||
element_epilogue = self.element_epilogue
|
||||
|
||||
class _EpilogueOutputOpParamsEVT(ctypes.Structure):
|
||||
"""
|
||||
Epilogue params when using the default linear combination of EVT, which
|
||||
does not currently use {alpha,beta}_ptr_array
|
||||
"""
|
||||
|
||||
stride_type = tuple_factory((0,0,1), "int64_t", [0])
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
("dalpha", stride_type),
|
||||
("dbeta", stride_type),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
("alpha_ptr_array", ctypes.c_void_p),
|
||||
("beta_ptr_array", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
|
||||
return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)
|
||||
|
||||
self.epilogue_type = _EpilogueOutputOpParams
|
||||
self.epilogue_type_evt = _EpilogueOutputOpParamsEVT
|
||||
|
||||
def emit(self):
|
||||
return super().emit(self.tag, self.template_arguments)
|
||||
|
||||
|
||||
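# Illustrative sketch (not part of this diff): emit() renders the functor as a
# C++ template instantiation string.
from cutlass_library import DataType

lc = LinearCombination(DataType.f32, 4)
assert lc.emit() == "cutlass::epilogue::thread::LinearCombination<float, 4, float, float>"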
class LinearCombinationClamp(LinearCombination):
|
||||
"""
|
||||
Applies a linear combination operator to an array of elements then clamps
|
||||
the output before converting to the output element type.
|
||||
|
||||
D = alpha * accumulator + beta * source + uniform
|
||||
|
||||
:param element_output: data type used to load and store tensors
|
||||
|
||||
:param epilogue_vector_length: number of elements computed per operation.
|
||||
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
|
||||
when there is not enough data to store
|
||||
|
||||
:param element_accumulator: Accumulator data type
|
||||
|
||||
:param element_epilogue: data type used to compute linear combination
|
||||
"""
|
||||
|
||||
tag = "cutlass::epilogue::thread::LinearCombinationClamp"
|
||||
|
||||
def __init__(
|
||||
self, element_output, epilogue_vector_length,
|
||||
element_accumulator=None, element_epilogue=None) -> None:
|
||||
# Base constructor
|
||||
super().__init__(
|
||||
element_output,
|
||||
epilogue_vector_length,
|
||||
element_accumulator,
|
||||
element_epilogue,
|
||||
)
|
||||
|
||||
c_element_epilogue = dtype2ctype[self.element_epilogue]
|
||||
element_epilogue = self.element_epilogue
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
self.epilogue_type = _EpilogueOutputOpParams
|
||||
|
||||
|
||||
class FastLinearCombinationClamp(EpilogueFunctorBase):
|
||||
"""
|
||||
Applies a linear combination operator to an array of elements then clamps
|
||||
the output before converting to the output element type.
|
||||
|
||||
D = alpha * accumulator + beta * source
|
||||
|
||||
Note: this functor applies only when problem_size_K <= 256 for signed int8 GEMM,
or problem_size_K <= 128 for unsigned int8 GEMM. Otherwise, use the default
LinearCombinationClamp above.
|
||||
|
||||
:param element_output: data type used to load and store tensors
|
||||
|
||||
:param epilogue_vector_length: number of elements computed per operation.
|
||||
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
|
||||
when there is not enough data to store
|
||||
"""
|
||||
|
||||
tag = "cutlass::epilogue::thread::FastLinearCombinationClamp"
|
||||
|
||||
def __init__(self, element_output, epilogue_vector_length, *args) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.template_arguments = [
|
||||
DataTypeTag[element_output], str(epilogue_vector_length)
|
||||
]
|
||||
|
||||
self.element_accumulator = DataType.s32
|
||||
self.element_epilogue = DataType.f32
|
||||
|
||||
# get epilogue output op
|
||||
c_element_epilogue = dtype2ctype[self.element_epilogue]
|
||||
element_epilogue = self.element_epilogue
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
self.epilogue_type = _EpilogueOutputOpParams
|
||||
|
||||
def emit(self):
|
||||
return super().emit(self.tag, self.template_arguments)
|
||||
|
||||
|
||||
class LinearCombinationGeneric(LinearCombination):
|
||||
"""
|
||||
Applies a linear combination operator followed by an activation function
|
||||
to an array of elements.
|
||||
|
||||
D = activation(alpha * accumulator + beta * source)
|
||||
|
||||
:param activation_functor: input activation functor
|
||||
|
||||
:param element_output: data type used to load and store tensors
|
||||
|
||||
:param epilogue_vector_length: number of elements computed per operation.
|
||||
Usually it is 128/sizeof_bits_v<ElementOutput_>, but we use 64 and 32 sometimes
|
||||
when there is not enough data to store
|
||||
|
||||
:param element_accumulator: Accumulator data type
|
||||
|
||||
:param element_epilogue: data type used to compute linear combination
|
||||
"""
|
||||
|
||||
tag = "cutlass::epilogue::thread::LinearCombinationGeneric"
|
||||
|
||||
def __init__(
|
||||
self, activation_functor,
|
||||
element_output, epilogue_vector_length,
|
||||
element_accumulator=None, element_epilogue=None) -> None:
|
||||
super().__init__(
|
||||
element_output,
|
||||
epilogue_vector_length,
|
||||
element_accumulator,
|
||||
element_epilogue,
|
||||
)
|
||||
|
||||
self.template_arguments = [
|
||||
activation_functor.emit()] + self.template_arguments
|
||||
|
||||
self.activation_functor = activation_functor
|
||||
self.element_epilogue = element_epilogue
|
||||
|
||||
# get epilogue output op
|
||||
self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue)
|
||||
|
||||
|
||||
class ActivationFunctor:
|
||||
"""
|
||||
Base class for frequently used activation functions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def numpy(x: np.ndarray):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def emit(cls):
|
||||
return ActivationOpTag[cls.binding_type]
|
||||
|
||||
@staticmethod
|
||||
def epilogue_output_op(element_epilogue):
|
||||
c_element_epilogue = dtype2ctype[element_epilogue]
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
return _EpilogueOutputOpParams
|
||||
|
||||
class ActivationMeta(type):
|
||||
@classmethod
|
||||
def __call__(cls, x, *args):
|
||||
if is_numpy_tensor(x):
|
||||
return cls.numpy(x, *args)
|
||||
elif is_torch_tensor(x):
|
||||
return cls.torch(x, *args)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported tensor type")
|
||||
|
||||
@classmethod
|
||||
def numpy(cls, *args):
|
||||
raise NotImplementedError(f"Numpy reference for {cls.__name__[:-4]} is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def torch(cls, *args):
|
||||
raise NotImplementedError(f"PyTorch reference for {cls.__name__[:-4]} is not implemented.")
|
||||
|
||||
##############################################################################
|
||||
# identity operator
|
||||
class identityMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return x
|
||||
|
||||
class identity(ActivationFunctor, metaclass=identityMeta):
|
||||
binding_type = ActivationOp.Identity
|
||||
|
||||
|
||||
##############################################################################
|
||||
# ReLu operator
|
||||
class reluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return np.where(x > 0, x, 0)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.relu(x)
|
||||
|
||||
class relu(ActivationFunctor, metaclass=reluMeta):
|
||||
binding_type = ActivationOp.ReLU
|
||||
|
||||
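# Illustrative sketch (not part of this diff): because ActivationMeta overrides
# __call__, "calling" the class dispatches on the tensor type instead of
# constructing an instance.
import numpy as np

out = relu(np.array([-1.0, 2.0]))  # routed to reluMeta.numpy
assert (out == np.array([0.0, 2.0])).all()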
|
||||
##############################################################################
|
||||
# Leaky ReLu operator
|
||||
class leakyReLUMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x, leaky_alpha):
|
||||
return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x, leaky_alpha):
|
||||
return F.leaky_relu(x, leaky_alpha)
|
||||
|
||||
class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta):
|
||||
binding_type = ActivationOp.LeakyReLU
|
||||
|
||||
@staticmethod
|
||||
def epilogue_output_op(element_epilogue):
|
||||
c_element_epilogue = dtype2ctype[element_epilogue]
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
("leaky_alpha", c_element_epilogue)
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
self.alpha_ptr = 0
|
||||
self.beta_ptr = 0
|
||||
self.leaky_alpha = to_ctype_value(leaky_alpha, element_epilogue)
|
||||
|
||||
return _EpilogueOutputOpParams
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Tanh operator
|
||||
class tanhMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return np.tanh(x)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return torch.tanh(x)
|
||||
|
||||
class tanh(ActivationFunctor, metaclass=tanhMeta):
|
||||
binding_type = ActivationOp.Tanh
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Sigmoid operator
|
||||
class sigmoidMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return 1.0 / (1.0 + np.exp(-x))
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.sigmoid(x)
|
||||
|
||||
class sigmoid(ActivationFunctor, metaclass=sigmoidMeta):
|
||||
binding_type = ActivationOp.Sigmoid
|
||||
|
||||
|
||||
##############################################################################
|
||||
# SiLu operator
|
||||
class siluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return x * sigmoidMeta.numpy(x)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.silu(x)
|
||||
|
||||
|
||||
class silu(ActivationFunctor, metaclass=siluMeta):
|
||||
binding_type = ActivationOp.SiLU
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Hardswish operator
|
||||
class hardswishMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0)
|
||||
return x * relu6 / 6.0
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.hardswish(x)
|
||||
|
||||
|
||||
class hardswish(ActivationFunctor, metaclass=hardswishMeta):
|
||||
binding_type = ActivationOp.HardSwish
|
||||
|
||||
|
||||
##############################################################################
|
||||
# GELU operator
|
||||
class geluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
from scipy.special import erf
|
||||
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.gelu(x)
|
||||
|
||||
|
||||
class gelu(ActivationFunctor, metaclass=geluMeta):
|
||||
binding_type = ActivationOp.Gelu
|
||||
34
python/cutlass_cppgen/backend/evt/__init__.py
Normal file
@ -0,0 +1,34 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
||||
38
python/cutlass_cppgen/backend/evt/backend/__init__.py
Normal file
@ -0,0 +1,38 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes
|
||||
from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
|
||||
from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes
|
||||
159
python/cutlass_cppgen/backend/evt/backend/emitter_base.py
Normal file
@ -0,0 +1,159 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base class for Epilogue Visitor Emitter
|
||||
"""
|
||||
|
||||
from cutlass_library import DataTypeTag
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
|
||||
|
||||
class FusionCallbacks:
|
||||
def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
|
||||
"""
|
||||
Emit the EVT fusion callbacks
|
||||
:param dag_ir: the DAG IR holding the epilogue visitor
|
||||
:param cc: compute capability
|
||||
:param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
|
||||
For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API
|
||||
so that their shared memory can be explicitly reused
|
||||
For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes.
|
||||
"""
|
||||
self.dag_ir = dag_ir
|
||||
self.emit_CD = emit_CD
|
||||
self.cc = cc
|
||||
self.evt_cc = 90 if cc >= 90 else cc
|
||||
if self.cc < 90:
|
||||
self.namespace = "threadblock"
|
||||
else:
|
||||
self.namespace = "fusion"
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
|
||||
def get_visitor_name(self, node: str):
|
||||
"""
|
||||
Get the visitor name
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
|
||||
return f"EVT{meta.name_camel}"
|
||||
else:
|
||||
return meta.name_camel
|
||||
|
||||
def emit(self):
|
||||
node_metas = self.dag_ir.node_metas_topological_order()
|
||||
epilogue_str = ""
|
||||
# Step 1: emit individual node type decl
|
||||
# emit the EVT & DAG connector
|
||||
for meta in node_metas:
|
||||
if not meta.disabled:
|
||||
epilogue_str += self.emit_node(meta)
|
||||
if not self.emit_CD and meta.name == "D":
|
||||
continue
|
||||
if isinstance(meta, TopoVisitorNode):
|
||||
epilogue_str += self.emit_dag(meta)
|
||||
else:
|
||||
epilogue_str += self.emit_evt(meta)
|
||||
|
||||
# Step 2: post-processing & get callback name
|
||||
if not self.emit_CD:
|
||||
if not self.dag_ir.has_node("C"):
|
||||
epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
|
||||
output_node = self.dag_ir.get_all_inputs("D")[0]
|
||||
# The callback is the src of node D
|
||||
callback_name = self.get_visitor_name(output_node)
|
||||
else:
|
||||
# The callback is the last node in the topological order
|
||||
callback_name = self.get_visitor_name(node_metas[-1].name)
|
||||
return epilogue_str, callback_name
|
||||
|
||||
def emit_evt(self, node):
|
||||
if self.dag_ir.in_degree(node.name) == 0:
|
||||
return ""
|
||||
|
||||
evt_tmp = f"""
|
||||
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
|
||||
{node.name_camel},
|
||||
"""
|
||||
sorted_children = self.dag_ir.get_all_inputs(node.name)
|
||||
evt_node_strs = [f" {self.get_visitor_name(child_name)}" for child_name in sorted_children]
|
||||
evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
|
||||
|
||||
return evt_tmp
|
||||
|
||||
def emit_dag(self, node):
|
||||
subgraph = node.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Emit the Edge Tuple
|
||||
edge_tuples = "cute::tuple<\n"
|
||||
for n in subgraph_nodes[:-1]:
|
||||
in_edges = subgraph.in_edges(n)
|
||||
edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
|
||||
sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
|
||||
edge_tuple = " cute::seq<"
|
||||
edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
|
||||
edge_tuple += ", ".join(edge_str) + ">,\n"
|
||||
|
||||
edge_tuples += edge_tuple
|
||||
edge_tuples += " >"
|
||||
|
||||
# Emit the node list
|
||||
dag_nodes = ""
|
||||
dag_node_strs = []
|
||||
for n in subgraph_nodes[:-1]:
|
||||
n_meta = subgraph.get_node_meta(n)
|
||||
if n_meta.disabled:
|
||||
dag_node_strs.append(f" {self.get_visitor_name(n)}")
|
||||
else:
|
||||
dag_node_strs.append(f" {n_meta.name_camel}")
|
||||
dag_nodes = ",\n".join(dag_node_strs)
|
||||
|
||||
return f"""
|
||||
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
|
||||
{DataTypeTag[node.subgraph.element_compute]},
|
||||
{edge_tuples},
|
||||
{dag_nodes}
|
||||
>;
|
||||
"""
|
||||
|
||||
def emit_node(self, node):
|
||||
if isinstance(node, TopoVisitorNode):
|
||||
emission = ""
|
||||
for node in node.subgraph.node_metas_topological_order():
|
||||
if not node.disabled:
|
||||
emission += self.emit_node(node)
|
||||
return emission
|
||||
else:
|
||||
return node.underlying_impl.type_decl
|
||||
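To make the nesting that `emit_evt` produces concrete, here is a minimal standalone sketch (all node names are hypothetical, not part of this commit). It builds the nested `Sm90EVT<...>` spelling recursively; the real emitter reaches the same result non-recursively by walking the DAG in topological order and referencing children via `get_all_inputs`:

# Toy model of the EVT type nesting emitted above; names are hypothetical.
deps = {"accum": [], "alpha": [], "mult": ["alpha", "accum"], "D": ["mult"]}

def emit_evt_name(name: str) -> str:
    if not deps[name]:
        return name  # leaf nodes (loads) are referenced directly
    children = ", ".join(emit_evt_name(c) for c in deps[name])
    return f"Sm90EVT<{name}, {children}>"

print(emit_evt_name("D"))  # Sm90EVT<D, Sm90EVT<mult, alpha, accum>>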
python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py
Normal file
@@ -0,0 +1,116 @@
#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
|
||||
Emitter for Sm100 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
|
||||
from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass_cppgen.backend.evt.ir.node import TupleEmitter
|
||||
|
||||
|
||||
class Sm100CollectiveEpilogue:
|
||||
def __init__(self, tile_description,
|
||||
kernel_schedule,
|
||||
epilogue_schedule,
|
||||
element_accumulator,
|
||||
element_d,
|
||||
fusion_callbacks) -> None:
|
||||
|
||||
self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule)
|
||||
self.element_accumulator = element_accumulator
|
||||
if fusion_callbacks.dag_ir.has_node("C"):
|
||||
self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element
|
||||
else:
|
||||
self.element_c = DataType.void
|
||||
self.element_d = element_d
|
||||
self.schedule = epilogue_schedule
|
||||
self.fusion_callbacks = fusion_callbacks
|
||||
self.opclass = tile_description.math_instruction.opcode_class
|
||||
|
||||
@property
|
||||
def CtaTileMNK(self) -> str:
|
||||
"""
|
||||
The threadblock shape
|
||||
"""
|
||||
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
||||
|
||||
@property
|
||||
def EpilogueTileType(self) -> str:
|
||||
"""
|
||||
The epilogue tile type
|
||||
"""
|
||||
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
||||
|
||||
@property
|
||||
def Schedule(self) -> str:
|
||||
return EpilogueScheduleTag[self.schedule]
|
||||
|
||||
def emit(self):
|
||||
tuple_emitter = TupleEmitter("int64_t")
|
||||
stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl
|
||||
stride_C_str = stride_D_str
|
||||
if self.fusion_callbacks.dag_ir.has_node("C"):
|
||||
stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl
|
||||
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return callback_name, f"""
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
|
||||
{OpcodeClassTag[self.opclass]},
|
||||
{self.CtaTileMNK}, {self.EpilogueTileType},
|
||||
{DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
||||
{self.Schedule}, {stride_C_str}, {stride_D_str},
|
||||
false /* IsPerColScaleSupported */,
|
||||
false /* IsBlockScaleSupported */
|
||||
>;
|
||||
{callback_decl}
|
||||
"""
|
||||
|
||||
|
||||
class Sm100Emitter:
|
||||
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
||||
fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
|
||||
|
||||
self.collective_epilogue = Sm100CollectiveEpilogue(
|
||||
tile_description=operation.tile_description,
|
||||
kernel_schedule=operation.tile_description.kernel_schedule,
|
||||
epilogue_schedule=operation.tile_description.epilogue_schedule,
|
||||
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
||||
element_d=fusion_callbacks.dag_ir.get_node_meta("D").element,
|
||||
fusion_callbacks=fusion_callbacks
|
||||
)
|
||||
|
||||
def emit(self):
|
||||
return self.collective_epilogue.emit()
|
||||
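As a small, runnable illustration of the `CtaTileMNK` property above, this snippet shows the `cute::Shape` spelling that ends up in the emitted C++; the (128, 128, 64) shape is an assumed example, not a value taken from this commit:

# Sketch of the CtaTileMNK string formatting with an assumed example shape.
cta_tile_mnk = (128, 128, 64)
shape_str = f"cute::Shape<_{cta_tile_mnk[0]}, _{cta_tile_mnk[1]}, _{cta_tile_mnk[2]}>"
print(shape_str)  # cute::Shape<_128, _128, _64>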
python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py
Normal file
@@ -0,0 +1,134 @@
#################################################################################################
#
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

from pycute import product

from cutlass_library import DataTypeSize, DataTypeTag

from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes

from cutlass_cppgen.backend.library import FloatRoundStyleTag


Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl


class Sm100AuxLoadImpl(AuxLoadImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Load
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)


class Sm100AuxStoreImpl(AuxStoreImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Store
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)

python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py
Normal file
@@ -0,0 +1,47 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
|
||||
Emitter for Sm80 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
|
||||
|
||||
class Sm80Emitter:
|
||||
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
||||
self.fusion_callbacks = FusionCallbacks(graph, cc=80)
|
||||
|
||||
def emit(self):
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return callback_name, callback_decl
|
||||
python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py
Normal file
@@ -0,0 +1,258 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

from cutlass_library import DataTypeSize, DataTypeTag

from cutlass_cppgen.backend.evt.ir import (
    # Load Node
    AccumulatorImpl,
    AuxLoadImpl,
    ColumnBroadcastImpl,
    LoadNode,
    LoadSrcImpl,
    RowBroadcastImpl,
    ScalarBroadcastImpl,
    # Compute Node
    ComputeImpl,
    # Store Node
    AuxStoreImpl,
    ColumnReductionImpl,
    RowReductionImpl,
    ScalarReductionImpl
)

from cutlass_cppgen.backend.library import (
    FloatRoundStyleTag,
    FunctionalOp,
    op_tag,
)


class Sm80AccumulatorImpl(AccumulatorImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
        return self._type_decl


class Sm80AuxLoadImpl(AuxLoadImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
    pass


class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl


class Sm80RowBroadcastImpl(RowBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ComputeImpl(ComputeImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl


class Sm80AuxStoreImpl(AuxStoreImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80StoreDImpl(Sm80AuxStoreImpl):
    pass


class Sm80ColumnReductionImpl(ColumnReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80RowReductionImpl(RowReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ScalarReductionImpl(ScalarReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl

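All of the `Sm80*Impl` classes above share the same memoized-property pattern for `type_decl`. A minimal standalone sketch of just that pattern (class and names hypothetical, not part of this commit):

# Minimal model of the cached type_decl property used throughout this file.
class TypeDeclExample:
    def __init__(self, name: str) -> None:
        self.name = name
        self._type_decl = None  # cache, filled on first access

    @property
    def type_decl(self) -> str:
        if self._type_decl is not None:
            return self._type_decl  # later accesses reuse the cached string
        self._type_decl = f"\nusing {self.name} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"
        return self._type_decl

print(TypeDeclExample("Accum").type_decl)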
python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py
Normal file
@@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
|
||||
Emitter for Sm90 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_library import DataTypeTag, EpilogueScheduleTag
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
class CollectiveEpilogue:
|
||||
def __init__(self, tile_description,
|
||||
schedule,
|
||||
element_c,
|
||||
element_d,
|
||||
fusion_callbacks) -> None:
|
||||
|
||||
self.cta_tile_mnk = tile_description.threadblock_shape
|
||||
self.element_c = element_c
|
||||
self.element_d = element_d
|
||||
self.schedule = schedule
|
||||
self.fusion_callbacks = fusion_callbacks
|
||||
|
||||
@property
|
||||
def CtaTileMNK(self) -> str:
|
||||
"""
|
||||
The threadblock shape
|
||||
"""
|
||||
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
||||
|
||||
@property
|
||||
def EpilogueTileType(self) -> str:
|
||||
"""
|
||||
The epilogue tile type
|
||||
"""
|
||||
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
||||
|
||||
@property
|
||||
def Schedule(self) -> str:
|
||||
return EpilogueScheduleTag[self.schedule]
|
||||
|
||||
def emit(self):
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return callback_name, f"""
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
{self.CtaTileMNK}, {self.EpilogueTileType},
|
||||
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
||||
{self.Schedule}
|
||||
>;
|
||||
{callback_decl}
|
||||
"""
|
||||
|
||||
|
||||
class Sm90Emitter:
|
||||
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
||||
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
|
||||
|
||||
self.collective_epilogue = CollectiveEpilogue(
|
||||
tile_description=operation.tile_description,
|
||||
schedule=operation.tile_description.epilogue_schedule,
|
||||
element_c=operation.C.element,
|
||||
element_d=operation.C.element,
|
||||
fusion_callbacks=fusion_callbacks
|
||||
)
|
||||
|
||||
def emit(self):
|
||||
return self.collective_epilogue.emit()
|
||||
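For intuition, the following sketch fills the `EpilogueDescriptor` template above with hypothetical arguments; the element and schedule spellings are assumptions for illustration, not output captured from this commit:

# Hypothetical fill-in of the EpilogueDescriptor emitted above.
cta = "cute::Shape<_128, _128, _64>"
decl = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    {cta}, cutlass::epilogue::collective::EpilogueTileAuto,
    cutlass::half_t, cutlass::half_t,
    cutlass::epilogue::TmaWarpSpecialized
>;
"""
print(decl)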
python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py
Normal file
@@ -0,0 +1,329 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

from pycute import product

from cutlass_library import DataTypeSize, DataTypeTag
from cutlass_cppgen.backend.evt.ir import (
    # Load Node
    AccumulatorImpl,
    AuxLoadImpl,
    ColumnBroadcastImpl,
    LoadNode,
    LoadSrcImpl,
    RowBroadcastImpl,
    ScalarBroadcastImpl,
    # Compute Node
    ComputeImpl,
    ComputeNode,
    # Store Node
    AuxStoreImpl,
    ColumnReductionImpl,
    RowReductionImpl,
    ScalarReductionImpl,
    StoreNode,
    StoreDImpl,
)
from cutlass_cppgen.backend.library import (
    FloatRoundStyleTag,
    FunctionalOp,
    op_tag,
)


class Sm90AccumulatorImpl(AccumulatorImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
        return self._type_decl


class Sm90LoadSrcImpl(LoadSrcImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
        return self._type_decl


class Sm90AuxLoadImpl(AuxLoadImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Load
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)


class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl


class Sm90RowBroadcastImpl(RowBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90ComputeImpl(ComputeImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl


class Sm90AuxStoreImpl(AuxStoreImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Store
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)


class Sm90StoreDImpl(StoreDImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        return f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""


class Sm90ColumnReductionImpl(ColumnReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90RowReductionImpl(RowReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90ScalarReductionImpl(ScalarReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
        return self._type_decl

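A worked example of the `get_smem_size` arithmetic above, with assumed values (a 16-bit element, 2 stages, and a 64x32 epilogue tile; none of these numbers come from the commit):

element_bits = 16                 # e.g. DataTypeSize for a half-precision element
stages = 2                        # stages_c (loads) or stages_d (stores)
epilogue_tile_mn = (64, 32)

tile_elems = epilogue_tile_mn[0] * epilogue_tile_mn[1]  # product(epilogue_tile_mn)
smem_bytes = element_bits * stages * tile_elems // 8
print(smem_bytes)  # 8192 bytes; the impl returns it with a 128-byte alignment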
python/cutlass_cppgen/backend/evt/epilogue.py
Normal file
@@ -0,0 +1,168 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
|
||||
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import DataType
|
||||
import numpy as np
|
||||
|
||||
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
|
||||
import cutlass_cppgen.backend.evt.backend
|
||||
from cutlass_cppgen.backend.frontend import TensorFrontend
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
||||
"""
|
||||
Apply an epilogue functor described by the epilogue EVT
|
||||
|
||||
:param cc: compute capability
|
||||
:param visitor_frontend: user-provide visitor frontend
|
||||
|
||||
"""
|
||||
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
|
||||
# Type of Emitter based on CC
|
||||
self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
|
||||
|
||||
# Visitor Types
|
||||
self.visitor = visitor
|
||||
self.graph = visitor.dag_ir
|
||||
|
||||
# Data types
|
||||
self.element_epilogue = element_compute # element compute
|
||||
self.element_output = self.graph.get_node_meta('D').underlying_impl.element
|
||||
|
||||
# Epilogue Thread Type
|
||||
epilogue_thread_type = self.visitor.epilogue_thread_type
|
||||
if cc_map[cc] in [90, 100]:
|
||||
self.arg_c_type = self.visitor.arg_c_type
|
||||
self.arg_d_type = self.visitor.arg_d_type
|
||||
output_names = self.visitor.return_names
|
||||
reduction_names = self.visitor.reduction_names
|
||||
|
||||
# Epilogue stages specialized for sm80 kernel
|
||||
if cc == 80:
|
||||
if hasattr(self.visitor, "epilogue_stages"):
|
||||
self.epilogue_stages = self.visitor.epilogue_stages
|
||||
assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"
|
||||
|
||||
# Epilogue Argument Type
|
||||
class _Arguments(ctypes.Structure):
|
||||
"""
|
||||
Concepts:
|
||||
class _EpilogueArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("epilogue", _Arguments), <- this class
|
||||
("ptr_C", ctypes.c_void_p),
|
||||
("stride_C", StrideBatched_),
|
||||
("ptr_D", ctypes.c_void_p),
|
||||
("stride_D", StrideBatched_)
|
||||
]
|
||||
"""
|
||||
_fields_ = [
|
||||
("output_op", epilogue_thread_type)
|
||||
]
|
||||
|
||||
def __init__(self, kwargs: dict) -> None:
|
||||
# The user-input kwargs is a dict of (name: tensors)
|
||||
# We first convert all of them to device pointers
|
||||
ptr_kwargs = {}
|
||||
for key in kwargs.keys():
|
||||
is_output = key in output_names and key not in reduction_names
|
||||
ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
|
||||
# Initialize the thread arguments
|
||||
self.output_op = epilogue_thread_type(ptr_kwargs)
|
||||
|
||||
def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
|
||||
"""
|
||||
Helper function for extracting device pointer
|
||||
"""
|
||||
# Skip the special tensors
|
||||
if cc in [90, 100]:
|
||||
if tensor_name in ["C", "D"]:
|
||||
return 0
|
||||
if tensor_name not in kwargs.keys():
|
||||
raise ValueError(f"Tensor {tensor_name} is not provided.")
|
||||
tensor = kwargs[tensor_name]
|
||||
|
||||
# For float scalar constant, directly return the value
|
||||
if isinstance(tensor, float):
|
||||
return tensor
|
||||
|
||||
# The tensor frontend returns a device buffer for np.ndarray
|
||||
# and device ptr for other frontends
|
||||
buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
|
||||
if is_numpy_tensor(tensor):
|
||||
# Remember the host tensor for later synchronization
|
||||
setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
|
||||
setattr(self, f"{tensor_name}_host", tensor)
|
||||
return int(buffer_or_ptr.ptr)
|
||||
else:
|
||||
return int(buffer_or_ptr)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
Synchronize the results from device to host
|
||||
"""
|
||||
for name in output_names:
|
||||
if hasattr(self, f"{name}_host"):
|
||||
host_tensor = getattr(self, f"{name}_host")
|
||||
tensor_ptr = getattr(self, f"{name}_buffer").ptr
|
||||
(err,) = cuda.cuMemcpyDtoH(
|
||||
host_tensor,
|
||||
tensor_ptr,
|
||||
host_tensor.size * host_tensor.itemsize,
|
||||
)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("CUDA Error %s" % str(err))
|
||||
|
||||
self.epilogue_type = _Arguments
|
||||
|
||||
def emit(self, operation):
|
||||
"""
|
||||
Emit the C++ code
|
||||
"""
|
||||
emitter = self.emit_cls(operation, self.graph)
|
||||
return emitter.emit()
|
||||
|
||||
def get_smem_size(self, tile_description):
|
||||
"""
|
||||
Get the shared memory size in bytes
|
||||
"""
|
||||
return self.visitor.get_smem_size(tile_description)
|
||||
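A hedged usage sketch of the argument flow above; the names and shapes are assumptions chosen for illustration, and compiling and launching the kernel itself is out of scope for this file:

import numpy as np

# Hypothetical name->tensor dict; keys must match the load/store node names
# in the traced EVT.
evt_kwargs = {
    "alpha": 1.5,                                    # float scalars pass through by value
    "aux": np.zeros((512, 256), dtype=np.float16),   # numpy tensors become device buffers
}
# args = epilogue_functor.epilogue_type(evt_kwargs)  # builds the ctypes structure
# ... launch the kernel, then copy numpy outputs back to host:
# args.sync()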
python/cutlass_cppgen/backend/evt/frontend/__init__.py
Normal file
@@ -0,0 +1,33 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend

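The package re-exports the Python AST frontend, which traces an epilogue written as a plain Python function. A minimal sketch of such a definition (hypothetical, shown only for orientation; the tracing entry points live elsewhere in the package):

# Hypothetical epilogue definition in the style the PythonASTFrontend parses.
def example_epilogue(accum, alpha, C):
    D = alpha * accum + C   # elementwise compute traced into the DAG IR
    return D                # "D" becomes the output store node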
python/cutlass_cppgen/backend/evt/frontend/frontend_base.py
Normal file
@@ -0,0 +1,272 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
|
||||
Base class for Python EVT Frontend
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from cutlass_library import DataType
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
ComputeNode,
|
||||
DAGIR,
|
||||
LayoutNode,
|
||||
LoadNode,
|
||||
StoreNode,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.passes import (
|
||||
EVTGraphDrawer,
|
||||
EVTPassManager,
|
||||
GetSmemSize,
|
||||
PassDAG2Tree,
|
||||
PassGetArgumentType,
|
||||
PassGetImpl,
|
||||
PassFixElementD,
|
||||
PassLayoutManipulateElimination,
|
||||
PassPreprocessRed,
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
from cutlass_cppgen.backend.utils import device_cc
|
||||
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
|
||||
from cutlass_cppgen.utils.datatypes import library_type
|
||||
|
||||
|
||||
class EVTFrontendBase:
|
||||
layout_fns = {
|
||||
"permute": permute,
|
||||
"reshape": reshape
|
||||
}
|
||||
|
||||
def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None:
|
||||
self.cc = cc
|
||||
self.element_compute = library_type(element_compute)
|
||||
self.dag_ir = DAGIR(self.cc, self.element_compute)
|
||||
self.compute_cnt = 0
|
||||
self.layout_cnt = 0
|
||||
self.imm_cnt = 0
|
||||
|
||||
self.pass_manager = EVTPassManager(
|
||||
self.dag_ir,
|
||||
[
|
||||
PassPreprocessRed,
|
||||
PassGetArgumentType,
|
||||
PassShapeTypePropagation,
|
||||
PassLayoutManipulateElimination,
|
||||
PassGetImpl,
|
||||
PassDAG2Tree,
|
||||
PassFixElementD
|
||||
] + additional_passes)
|
||||
|
||||
if self.cc == 80:
|
||||
self._epilogue_stages = 1
|
||||
else:
|
||||
self._epilogue_stages = None
|
||||
|
||||
@property
|
||||
def epilogue_stages(self):
|
||||
return self._epilogue_stages
|
||||
|
||||
@epilogue_stages.setter
|
||||
def epilogue_stages(self, stages):
|
||||
self._epilogue_stages = stages
|
||||
|
||||
|
||||
def parse(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"The 'parse' function must be overloaded in frontend class")
|
||||
|
||||
def trace(self, *args, **kwargs):
|
||||
# Parse the input
|
||||
self.parse(*args, **kwargs)
|
||||
|
||||
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
|
||||
if (self.cc >= 90):
|
||||
if (self.dag_ir.out_degree("D") != 0):
|
||||
raise RuntimeError(
|
||||
f"On SM90 or higher, D is expected to be a output node with 0 users to "
|
||||
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
|
||||
|
||||
# Run the passes
|
||||
self.pass_manager()
|
||||
# Set the epilogue type
|
||||
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
|
||||
if cc_map[self.cc] in [90, 100]:
|
||||
self.arg_c_type = self.dag_ir.arg_c_type
|
||||
self.arg_d_type = self.dag_ir.arg_d_type
|
||||
self.reduction_names = self.dag_ir.reduction_names
|
||||
|
||||
#
|
||||
# Helper functions for DAG IR manipulation
|
||||
#
|
||||
|
||||
def add_node(self, node):
|
||||
self.dag_ir.add_node(node)
|
||||
|
||||
def add_edge(self, src, tgt, weight=0):
|
||||
self.dag_ir.add_edge(src, tgt, weight=weight)
|
||||
|
||||
def set_tensor(self, node_name, example):
|
||||
"""
|
||||
Add an example tensor to node {node_name} in the DAG IR
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node_name)
|
||||
meta.tensor = {"tensor": example}
|
||||
|
||||
def set_store_tensor(self, node_name, example):
|
||||
"""
|
||||
Add an example tensor to node {node_name} in the DAG IR
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node_name)
|
||||
meta.store_tensor = {"tensor": example}
|
||||
|
||||
def mark_output(self, node_name):
|
||||
"""
|
||||
Mark a store node as output
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node_name)
|
||||
if not isinstance(meta, StoreNode):
|
||||
raise ValueError(
|
||||
f"Only StoreNodes can be marked as output. "
|
||||
f"Got {type(meta).__name__}: {node_name}")
|
||||
meta.is_output = True
|
||||
|
||||
# Add node with specific type
|
||||
|
||||
def add_load_node(self, name, example):
|
||||
"""
|
||||
Add a Load node to DAG IR
|
||||
:param name: name of the loaded variable
|
||||
:type name: str
|
||||
:param example: example input
|
||||
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
|
||||
"""
|
||||
if name is None:
|
||||
raise ValueError(f"Name is not provided.")
|
||||
if example is None:
|
||||
raise ValueError(f"Example input for {name} is not provided.")
|
||||
load_node = LoadNode(name)
|
||||
load_node.tensor = {"tensor": example}
|
||||
# Special logics for accumulator
|
||||
if name == "accum":
|
||||
if load_node.tensor.rank == 2:
|
||||
new_shape = tuple([1, ] + list(load_node.tensor.shape))
|
||||
load_node.tensor.broadcast(new_shape)
|
||||
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
|
||||
raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
|
||||
self.add_node(load_node)
|
||||
|
||||
def add_imm(self, value: Union[float,int]):
|
||||
"""
|
||||
Add an immediate scalar value to DAG IR
|
||||
:param value: the value of the immediate scalar
|
||||
:type value: float
|
||||
"""
|
||||
try:
|
||||
value = float(value)
|
||||
except:
|
||||
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
|
||||
|
||||
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
|
||||
self.imm_cnt += 1
|
||||
load_node = LoadNode(name)
|
||||
load_node.tensor = {"tensor": value, "is_constant": True}
|
||||
self.add_node(load_node)
|
||||
return name
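
        # Illustrative note (not part of the original source): on a fresh frontend,
        # add_imm(2.5) creates a constant LoadNode and returns "imm_2_5_k0", since
        # the '.' in the value is replaced with '_' and the counter suffix increments
        # with each immediate.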

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.
        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.
        :param op: the layout op
        :type op: evt_ops
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualizing the DAG IR
    #

    def visualize(self, name="dag_ir"):
        """
        Visualize the DAG IR as an SVG file
        :param name: the name of the graph
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        try:
            for name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{name}.svg")
        except Exception:
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            )

    #
    # Get shared memory size
    #

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue
        """
        smem_size = GetSmemSize(self.dag_ir)(tile_description)
        return smem_size

194
python/cutlass_cppgen/backend/evt/frontend/python_ast.py
Normal file
@@ -0,0 +1,194 @@
"""
|
||||
Python AST frontend that parses input into DAG IR
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
|
||||
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
||||
from cutlass_cppgen.backend.library import FunctionalOp
|
||||
|
||||
|
||||
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
||||
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
|
||||
super().__init__(cc, element_compute, **kwargs)
|
||||
# Flags
|
||||
# If this state is True, visit_Constant returns values without creating imm node
|
||||
self.no_imm = False
|
||||
self.visiting_return = False
|
||||
|
||||
def parse(self, example_inputs):
|
||||
self.example_inputs = example_inputs
|
||||
self.source = textwrap.dedent(inspect.getsource(self.__call__))
|
||||
self.ast = ast.parse(self.source)
|
||||
self.visit(self.ast)
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
@staticmethod
|
||||
def ast_op_to_bindings(op):
|
||||
mapping = {
|
||||
ast.Add: FunctionalOp.Plus,
|
||||
ast.Sub: FunctionalOp.Minus,
|
||||
ast.Mult: FunctionalOp.Multiplies,
|
||||
ast.Div: FunctionalOp.Divides,
|
||||
"maximum": FunctionalOp.Maximum,
|
||||
"minimum": FunctionalOp.Minimum,
|
||||
"identity": identity.binding_type,
|
||||
"relu": relu.binding_type,
|
||||
"tanh": tanh.binding_type,
|
||||
"sigmoid": sigmoid.binding_type,
|
||||
"silu": silu.binding_type,
|
||||
"hardswish": hardswish.binding_type,
|
||||
"gelu": gelu.binding_type,
|
||||
"multiply_add": FunctionalOp.MultiplyAdd,
|
||||
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
||||
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
|
||||
"exp": FunctionalOp.Exp
|
||||
}
|
||||
return mapping[op]
|
||||
|
||||
#
|
||||
# Visiting different node types
|
||||
#
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef):
|
||||
# Visit args and register load nodes
|
||||
for arg in node.args.args:
|
||||
self.visit(arg)
|
||||
for expr in node.body:
|
||||
self.visit(expr)
|
||||
|
||||
def visit_arg(self, node: ast.arg):
|
||||
# Name of the argument
|
||||
name = node.arg
|
||||
try:
|
||||
example_tensor = self.example_inputs[name]
|
||||
except:
|
||||
raise RuntimeError(f"Example input for {name} is not provided.")
|
||||
|
||||
self.add_load_node(name, example_tensor)
|
||||
|
||||
def visit_Name(self, node: ast.Name):
|
||||
return node.id
|
||||
|
||||
def visit_Constant(self, node: ast.Constant):
|
||||
if self.no_imm:
|
||||
return node.value
|
||||
else:
|
||||
name = self.add_imm(node.value)
|
||||
return name
|
||||
|
||||
def visit_Tuple(self, node: ast.Tuple):
|
||||
results = []
|
||||
for elt in node.elts:
|
||||
results.append(self.visit(elt))
|
||||
return tuple(results)
|
||||
|
||||
def visit_keyword(self, node: ast.keyword):
|
||||
return {node.arg: self.visit(node.value)}
|
||||
|
||||
def visit_BinOp(self, node: ast.BinOp):
|
||||
if self.visiting_return:
|
||||
raise SyntaxError("Return value cannot be an expression")
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
op = self.ast_op_to_bindings(type(node.op))
|
||||
name = self.add_compute_node(op)
|
||||
|
||||
# Add edges
|
||||
# The edge weights are used to sort the input args
|
||||
self.add_edge(lhs, name, weight=0)
|
||||
self.add_edge(rhs, name, weight=1)
|
||||
return name
|
||||
|
||||
def visit_Assign(self, node: ast.BinOp):
|
||||
target = self.visit(node.targets[0])
|
||||
value = self.visit(node.value)
|
||||
# Create the assign node
|
||||
self.add_store_node(target)
|
||||
|
||||
# Add edges
|
||||
self.add_edge(value, target)
|
||||
return target
|
||||
|
||||
def visit_Call(self, node: ast.Call):
|
||||
if self.visiting_return:
|
||||
raise SyntaxError("Return value cannot be an expression")
|
||||
func = self.visit(node.func)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if func in self.layout_fns.keys():
|
||||
# Parse kwargs
|
||||
# By default, visiting imm automatically creates a load node
|
||||
# However, in function call, keyword args are used to set
|
||||
# specific function attributes such as indices for permute
|
||||
# So no_imm is set to True temporarily
|
||||
self.no_imm = True
|
||||
kwargs = {}
|
||||
for kw in node.keywords:
|
||||
kwargs.update(self.visit(kw))
|
||||
self.no_imm = False
|
||||
op = self.layout_fns[func]
|
||||
name = self.add_layout_node(op, kwargs)
|
||||
else:
|
||||
op = self.ast_op_to_bindings(func)
|
||||
name = self.add_compute_node(op)
|
||||
|
||||
# Add edges
|
||||
for idx, arg in enumerate(args):
|
||||
self.add_edge(arg, name, weight=idx)
|
||||
return name
|
||||
|
||||
def visit_Return(self, node: ast.Return):
|
||||
self.visiting_return = True
|
||||
results = self.visit(node.value)
|
||||
self.visiting_return = False
|
||||
self.return_names = results
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
for rst in results:
|
||||
try:
|
||||
example_tensor = self.example_inputs[rst]
|
||||
except:
|
||||
raise RuntimeError(f"Example input for {rst} is not provided.")
|
||||
self.set_store_tensor(rst, example_tensor)
|
||||
self.mark_output(rst)
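
For orientation, here is a hedged sketch (not part of this commit) of the kind of input this frontend consumes; the exact tracing entry point around it is assumed, not shown in this diff:

# Sketch only: an epilogue expressed as plain Python for the AST frontend.
def example_epilogue(accum, alpha, C):
    D = alpha * accum + C   # BinOps map to ComputeNodes (Multiplies, then Plus)
    return D                # "D" becomes a StoreNode marked as output

# Each argument (accum, alpha, C) is registered as a LoadNode via visit_arg,
# the assignment creates the StoreNode "D", and visit_Return attaches the
# example output tensor and calls mark_output("D").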

53
python/cutlass_cppgen/backend/evt/ir/__init__.py
Normal file
@@ -0,0 +1,53 @@
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
from cutlass_cppgen.backend.evt.ir.load_nodes import (
    LoadNode,
    AccumulatorImpl,
    LoadSrcImpl,
    AuxLoadImpl,
    RowBroadcastImpl,
    ColumnBroadcastImpl,
    ScalarBroadcastImpl
)
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass_cppgen.backend.evt.ir.store_nodes import (
    StoreNode,
    StoreDImpl,
    AuxStoreImpl,
    ColumnReductionImpl,
    RowReductionImpl,
    ScalarReductionImpl
)

91
python/cutlass_cppgen/backend/evt/ir/compute_nodes.py
Normal file
@@ -0,0 +1,91 @@
"""
|
||||
Python registration for compute nodes in EVT
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyle
|
||||
|
||||
|
||||
class ComputeImplBase(ImplBase):
|
||||
"""
|
||||
Base class for compute implementation
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
|
||||
class ComputeImpl(ComputeImplBase):
|
||||
"""
|
||||
Implementation for Compute Node
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
self.fn = node.fn
|
||||
self.element_output = node.element_output
|
||||
self.element_compute = node.element_compute
|
||||
self.round_style = node.round_style
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
return True
|
||||
|
||||
|
||||
class ComputeNode(NodeBase):
|
||||
"""
|
||||
Compute Node in DAG IR
|
||||
"""
|
||||
possible_impls = [
|
||||
ComputeImpl
|
||||
]
|
||||
def __init__(
|
||||
self, name: str, fn, element_output,
|
||||
element_compute,
|
||||
round_style=FloatRoundStyle.ToNearest) -> None:
|
||||
super().__init__(name)
|
||||
self.op = "compute"
|
||||
self.fn = fn
|
||||
self.element_compute = element_compute
|
||||
self.round_style = round_style
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
||||
"""
|
||||
self.element = self.element_compute
|
||||
# In general, the compute nodes have element_output = element_compute
|
||||
# In certain cases like producer of D it is overwritten by other passes
|
||||
if not hasattr(self, "element_output"):
|
||||
self.element_output = self.element
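
A hedged construction sketch (illustrative only, not part of this commit); in practice the frontend's add_compute_node builds these nodes:

# Sketch: instantiating a ComputeNode by hand, as the frontend does.
from cutlass_library import DataType
from cutlass_cppgen.backend.library import FunctionalOp
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode

node = ComputeNode(
    name="compute_0",
    fn=FunctionalOp.Plus,
    element_output=DataType.f32,
    element_compute=DataType.f32)
# node.element_output is only pinned down later, during type_propagation.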

254
python/cutlass_cppgen/backend/evt/ir/dag_ir.py
Normal file
@@ -0,0 +1,254 @@
"""
|
||||
DAG IR used by Python EVT
|
||||
"""
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.library import ActivationOp
|
||||
from cutlass_cppgen.backend.utils import device_cc
|
||||
|
||||
|
||||
class DAGIR:
|
||||
"""
|
||||
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
|
||||
It consists of a series of ``Node`` s, each representing epilogue visitor nodes.
|
||||
|
||||
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
|
||||
"""
|
||||
def __init__(self, cc, element_compute=DataType.f32) -> None:
|
||||
# The EVT DAGIR is managed through the nextworkX Digraph class
|
||||
self._graph = nx.DiGraph()
|
||||
|
||||
self.element_compute = element_compute
|
||||
|
||||
self.reduction_names = []
|
||||
|
||||
self.cc = cc
|
||||
|
||||
self.identity_counter = 0
|
||||
|
||||
#
|
||||
# IR manipulator
|
||||
#
|
||||
|
||||
def add_node(self, meta: NodeBase):
|
||||
"""
|
||||
Add a node to dag ir
|
||||
"""
|
||||
if self.has_node(meta.name):
|
||||
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
|
||||
self._graph.add_node(meta.name, meta=meta)
|
||||
|
||||
def add_edge(self, src: str, dst: str, weight: int=0):
|
||||
"""
|
||||
Add an edge src -> dst to dag ir with weight
|
||||
"""
|
||||
if not self.has_node(src):
|
||||
raise SyntaxError(f"Variable '{src}' is undefined.")
|
||||
if not self.has_node(dst):
|
||||
raise SyntaxError(f"Variable '{dst}' is undefined.")
|
||||
|
||||
if self._graph.has_edge(src, dst):
|
||||
# The DiGraph doesn't support multiple edges between two nodes
|
||||
# We insert an identity node in such case as a workaround
|
||||
identity_name = f"autogen_identity_{self.identity_counter}"
|
||||
self.identity_counter += 1
|
||||
compute_node = ComputeNode(
|
||||
name=identity_name, fn=ActivationOp.Identity,
|
||||
element_output=self.element_compute,
|
||||
element_compute=self.element_compute)
|
||||
self.add_node(compute_node)
|
||||
self.add_edge(src, identity_name, 0)
|
||||
self.add_edge(identity_name, dst, weight)
|
||||
else:
|
||||
self._graph.add_edge(src, dst, weight=weight)
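
    # Illustrative behavior note (not part of the original source): because
    # nx.DiGraph stores at most one edge per (src, dst) pair, calling
    # add_edge("x", "y", 0) followed by add_edge("x", "y", 1) leaves the graph
    # with "x" -> "autogen_identity_0" -> "y" for the second edge, where the
    # identity node is a ComputeNode applying ActivationOp.Identity.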

    def remove_node(self, node: str):
        """
        Remove a node from the DAG IR
        """
        self._graph.remove_node(node)

    def remove_edge(self, src: str, dst: str):
        """
        Remove the edge src -> dst
        """
        self._graph.remove_edge(src, dst)

    #
    # Helper functions for getting attrs
    #

    def has_node(self, node: str) -> bool:
        """
        Check if the node is in the graph
        """
        return self._graph.has_node(node)

    def in_degree(self, node: str):
        """
        Get the input degree of the node
        """
        return self._graph.in_degree(node)

    def in_edges(self, node: str):
        """
        Get the input edges of the node
        """
        return [edge for edge in self._graph.in_edges(node)]

    def out_degree(self, node: str):
        """
        Get the output degree of the node
        """
        return self._graph.out_degree(node)

    def out_edges(self, node: str):
        """
        Get the output edges of the node
        """
        return [edge for edge in self._graph.out_edges(node)]

    def get_node_meta(self, node: str):
        """
        Get the metadata of the node
        """
        return self._graph.nodes[node]["meta"]

    def get_edge_weight(self, src, dst):
        """
        Get the weight of the edge src -> dst
        """
        return self._graph.get_edge_data(src, dst)["weight"]

    #
    # High-level helper functions
    #

    def all_reachable_nodes(self, node: str):
        """
        Get all the nodes reachable from the current node (including the node itself)
        """
        return list(nx.dfs_preorder_nodes(self._graph, source=node))

    def get_users(self, node: str):
        """
        Get all users of the current node
        """
        return [edge[1] for edge in self.out_edges(node)]

    def get_all_inputs(self, node: str):
        """
        Get all the input nodes sorted by edge weight
        """
        in_edges = self.in_edges(node)
        edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
        return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]

    def get_all_inputs_meta(self, node: str):
        """
        Get all the input node metas sorted by edge weight
        """
        return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]

    def replace_all_uses_with(self, node1, node2):
        """
        Replace all uses of node1 with node2
        """
        for edge in self.out_edges(node1):
            weight = self.get_edge_weight(*edge)
            user = edge[1]
            self.add_edge(node2, user, weight)
            self.remove_edge(node1, user)
        self.remove_node(node1)

    #
    # Node accessors
    #
    def nodes_topological_order(self):
        """
        Get the nodes in a unique lexicographical topological order.
        It generates a unique ordering of nodes by first sorting topologically
        and then additionally by sorting lexicographically.

        Although topological_sort alone also works, this generates a unique key
        for each epilogue visitor pattern and ensures the compilation cache can be reused.
        :return: list[str]
        """
        return list(nx.lexicographical_topological_sort(self._graph))

    def node_metas_topological_order(self):
        """
        Get the node metas in topological order
        :return: list[NodeBase]
        """
        return [self.get_node_meta(node) for node in self.nodes_topological_order()]

    @property
    def nodes(self):
        """
        Get all nodes
        :return: list[str]
        """
        return list(self._graph.nodes)

    @property
    def nodes_meta(self):
        """
        Get all node metas
        :return: list[NodeBase]
        """
        return [data[1]['meta'] for data in self._graph.nodes.data()]

    @property
    def edges(self):
        """
        Get all edges
        :return: list[(str, str)]
        """
        return list(self._graph.edges)

    #
    # Path
    #
    def has_path(self, src: str, target: str) -> bool:
        """
        Return True if a path exists from src to target
        """
        return nx.has_path(self._graph, src, target)
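
A hedged usage sketch of the container (illustrative only; the node classes come from the sibling ir modules in this commit, and a frontend would normally do this wiring):

# Sketch: building a tiny DAG IR by hand.
from cutlass_library import DataType
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.load_nodes import LoadNode
from cutlass_cppgen.backend.evt.ir.store_nodes import StoreNode

dag_ir = DAGIR(cc=90, element_compute=DataType.f32)
dag_ir.add_node(LoadNode("accum"))
dag_ir.add_node(StoreNode("D"))
dag_ir.add_edge("accum", "D", weight=0)
assert dag_ir.has_path("accum", "D")
assert dag_ir.get_users("accum") == ["D"]
assert dag_ir.nodes_topological_order() == ["accum", "D"]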

324
python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py
Normal file
@@ -0,0 +1,324 @@
"""
|
||||
Layout algebras
|
||||
"""
|
||||
|
||||
from pycute import Layout, composition, make_layout, flatten, product
|
||||
|
||||
|
||||
def _infer_split(old_shape, new_shape):
|
||||
old_shape = _tuple_to_list(old_shape)
|
||||
new_shape = _tuple_to_list(new_shape)
|
||||
if len(old_shape) == 0 and len(new_shape) == 0:
|
||||
return []
|
||||
if len(old_shape) == 0:
|
||||
if product(tuple(new_shape)) != 1:
|
||||
raise ValueError("Invalid reshape size")
|
||||
else:
|
||||
return new_shape
|
||||
if len(new_shape) == 0:
|
||||
if product(tuple(old_shape)) != 1:
|
||||
raise ValueError("Invalid reshape size")
|
||||
else:
|
||||
return old_shape
|
||||
# This is done recursively by only process the last dimension at each time
|
||||
old_dim = old_shape[-1]
|
||||
new_dim = new_shape[-1]
|
||||
# Exact match
|
||||
if old_dim == new_dim:
|
||||
return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
|
||||
# Needs split
|
||||
if old_dim > new_dim and old_dim % new_dim == 0:
|
||||
residual = old_dim // new_dim
|
||||
return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
|
||||
# Needs merge
|
||||
if old_dim < new_dim and new_dim % old_dim == 0:
|
||||
residual = new_dim // old_dim
|
||||
return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]
|
||||
|
||||
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
|
||||
|
||||
def _infer_merge(flatten_shape, shape):
|
||||
flatten_shape = _tuple_to_list(flatten_shape)
|
||||
shape = _tuple_to_list(shape)
|
||||
idx_flat = 0
|
||||
merged_shape = []
|
||||
for dim in shape:
|
||||
# Exact match
|
||||
if dim == flatten_shape[idx_flat]:
|
||||
merged_shape.append(dim)
|
||||
idx_flat += 1
|
||||
# Need group
|
||||
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
|
||||
residual = dim
|
||||
group = []
|
||||
while(residual > 1):
|
||||
group.append(flatten_shape[idx_flat])
|
||||
residual = residual // flatten_shape[idx_flat]
|
||||
idx_flat += 1
|
||||
merged_shape.append(group)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
|
||||
|
||||
return merged_shape
|
||||
|
||||
def _list_to_tuple(nested_list):
|
||||
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
|
||||
return tuple(_list_to_tuple(item) for item in nested_list)
|
||||
return nested_list
|
||||
|
||||
def _tuple_to_list(nested_tuple):
|
||||
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
|
||||
return list(_tuple_to_list(item) for item in nested_tuple)
|
||||
return nested_tuple
|
||||
|
||||
def _reverse_tuple(nested_tuple: tuple):
|
||||
if isinstance(nested_tuple, tuple):
|
||||
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
|
||||
return nested_tuple
|
||||
|
||||
def _get_first_lhs_nonzero_stride(stride_list, idx):
|
||||
for i in reversed(range(idx)):
|
||||
if stride_list[i] != 0:
|
||||
return i
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_first_rhs_nonzero_stride(stride_list, idx):
|
||||
for i in range(idx+1, len(stride_list)):
|
||||
if stride_list[i] != 0:
|
||||
return i
|
||||
else:
|
||||
return None
|
||||
|
||||
def reshape(layout, new_shape):
|
||||
"""
|
||||
General reshape of input layout.
|
||||
It takes two steps:
|
||||
1. split the dimensions of the old layout
|
||||
2. merge the splitted dimensions according to the new shape
|
||||
"""
|
||||
#
|
||||
# Step 1: Split the dimensions of the old layout
|
||||
#
|
||||
# 1.1 Flat old and new shape
|
||||
old_flatten_shape = list(flatten(layout.shape))
|
||||
new_flatten_shape = list(flatten(new_shape))
|
||||
|
||||
# 1.2 Infer the flatten splitted shape
|
||||
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
|
||||
|
||||
# 1.3 Unflat the splitted shape based on the old shape
|
||||
splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
|
||||
|
||||
# 1.4 Infer the type of each split
|
||||
# If the split type is in row-major (R), the dimension list is reversed because
|
||||
# the cute::composition only support column-major split
|
||||
split_type = [] # the type of each split (ColumnMajor or RowMajor)
|
||||
permuted_splitted_shape = []
|
||||
old_flatten_stride = list(flatten(layout.stride))
|
||||
for idx, dim in enumerate(splited_shape):
|
||||
if not isinstance(dim, list):
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
|
||||
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
|
||||
# Special case for single tuple
|
||||
# Use column-major by default
|
||||
if lhs_stride is None and rhs_stride is None:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
if lhs_stride is not None and rhs_stride is not None:
|
||||
# We consider shape[idx]:stride[idx]
|
||||
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
|
||||
if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
|
||||
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
|
||||
elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
|
||||
if lhs_stride >= rhs_stride:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
|
||||
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
|
||||
if lhs_stride >= rhs_stride:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
elif lhs_stride is None:
|
||||
# Case 1: dim's stride < dim+1's stride, expand in column major
|
||||
if old_flatten_stride[idx] > rhs_stride:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
else:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
# Case 1: dim's stride > dim-1's stride
|
||||
if old_flatten_stride[idx] < lhs_stride:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
else:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
|
||||
# 1.4 Generate the splitted layout
|
||||
permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
|
||||
|
||||
# 1.5 Reverse the permutation in 1.4 before merge
|
||||
splitted_shape = []
|
||||
splitted_stride = []
|
||||
for shape_dim, stride_dim, type in zip(
|
||||
permuted_splitted_layout.shape,
|
||||
permuted_splitted_layout.stride,
|
||||
split_type):
|
||||
if type == "C":
|
||||
splitted_shape.append(shape_dim)
|
||||
splitted_stride.append(stride_dim)
|
||||
else:
|
||||
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
|
||||
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
|
||||
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
|
||||
|
||||
|
||||
#
|
||||
# Step 2: Merge the splitted dimensions according to the new shape
|
||||
#
|
||||
# 2.1 Merge layout
|
||||
merged_layout = composition(splitted_layout, Layout(new_shape))
|
||||
|
||||
# 2.2 Cleaning up
|
||||
output_layout = composition(merged_layout, Layout(new_shape))
|
||||
return output_layout
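
# Worked example (illustrative, not in the original source): reshaping a
# (6, 4) shape into (3, 8) first splits both into a common flat shape
#   _infer_split([6, 4], [3, 8])      -> [3, 2, 4]
# and then regroups it against each side:
#   _infer_merge([3, 2, 4], [6, 4])   -> [[3, 2], 4]
#   _infer_merge([3, 2, 4], [3, 8])   -> [3, [2, 4]]
# so the (6,) mode is split as (3, 2) while the (2, 4) pair merges into (8,).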


def permutation(layout, permutation):
    """
    Permute the layout
    """
    new_shape = tuple([layout.shape[idx] for idx in permutation])
    new_stride = tuple([layout.stride[idx] for idx in permutation])
    return Layout(new_shape, new_stride)


def _broadcast(layout, new_shape):
    if len(layout) == 1 and isinstance(new_shape, int):
        old_dim = layout.shape
        old_stride = layout.stride
        new_dim = new_shape
        if old_dim == new_dim:
            return Layout(old_dim, old_stride)
        elif old_dim == 1:
            return Layout(new_dim, 0)
        else:
            raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")

    # Align the dimensions
    old_shape = layout.shape
    if isinstance(old_shape, int):
        old_shape = (old_shape,)
        sub_layouts = [layout,]
    else:
        sub_layouts = [sub_layout for sub_layout in layout]
    rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
    # Get the broadcast layout
    broadcast_layouts = []
    try:
        layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    except NotImplementedError:
        layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    return make_layout(*broadcast_layouts)


def broadcast(layout, new_shape):
    """
    Broadcast the input layout to the new shape.
    The broadcast shape equals the new shape;
    the stride of each broadcast dimension is 0.
    """
    return _broadcast(layout, new_shape)
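
# Illustrative example (not in the original source): broadcasting a row
# vector layout to a rank-2 shape gives every broadcast mode a stride of 0,
# e.g.
#   broadcast(make_layout(Layout(1, 0), Layout(4, 1)), (3, 4))
# yields a layout with shape (3, 4) whose first mode keeps stride 0 while
# the second mode keeps its original stride of 1.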


def debroadcast(layout, dims):
    """
    Squeeze out the given 0-stride dimensions
    """
    for dim in dims:
        if layout.stride[dim] != 0:
            raise ValueError(f"Dim {dim} cannot be debroadcast as it has stride {layout.stride[dim]}")
    new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
    new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
    return Layout(new_shape, new_stride)


def canonicalization_(shapes, strides):
    if isinstance(shapes, tuple):
        c_shapes = []
        c_strides = []
        for shape, stride in zip(shapes, strides):
            c_shape, c_stride = canonicalization_(shape, stride)
            c_shapes.append(c_shape)
            c_strides.append(c_stride)
        return tuple(c_shapes), tuple(c_strides)
    else:
        if shapes == 1:
            return 1, 0
        else:
            return shapes, strides


def canonicalization(layout):
    """
    Canonicalize the input layout:
    1. set the stride of each size-1 mode to 0
    """
    new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
    return Layout(new_shape, new_stride)
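
A hedged sketch exercising the helpers above (illustrative only; assumes pycute's Layout constructor accepts an explicit stride tuple):

# Sketch: layout algebra helpers in action.
from pycute import Layout
from cutlass_cppgen.backend.evt.ir.layout_algorithm import canonicalization, permutation

l = Layout((2, 3, 4), (1, 2, 6))   # column-major rank-3 layout
lp = permutation(l, [2, 0, 1])     # shape (4, 2, 3), stride (6, 1, 2)
lc = canonicalization(Layout((3, 1), (1, 1)))
# canonicalization rewrites size-1 modes to stride 0: Layout((3, 1), (1, 0))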

336
python/cutlass_cppgen/backend/evt/ir/layout_nodes.py
Normal file
@@ -0,0 +1,336 @@
"""
|
||||
Layout manipulation nodes and implementations
|
||||
|
||||
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
from pycute import product, flatten
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class PermutationImpl:
|
||||
"""
|
||||
Detailed implementation and helper functions for permutation
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
assert "indices" in node.kwargs.keys()
|
||||
self.indices = list(node.kwargs["indices"])
|
||||
self.inverse_indices = self.get_inverse_indices(self.indices)
|
||||
|
||||
def get_inverse_impl(self):
|
||||
inverse_impl = deepcopy(self)
|
||||
inverse_impl.indices = self.inverse_indices
|
||||
inverse_impl.inverse_indices = self.indices
|
||||
return inverse_impl
|
||||
|
||||
def update(self, shape):
|
||||
num_dim = len(shape)
|
||||
indices = self.indices
|
||||
num_old_dim = len(indices)
|
||||
# Add offset
|
||||
for i, idx in enumerate(indices):
|
||||
indices[i] = idx + num_dim - num_old_dim
|
||||
# Add broadcast dims
|
||||
for i in range(num_dim - num_old_dim):
|
||||
indices = [i,] + indices
|
||||
|
||||
self.indices = indices
|
||||
self.inverse_indices = self.get_inverse_indices(self.indices)
|
||||
|
||||
def get_inverse_indices(self, indices):
|
||||
"""
|
||||
Get the indices for inverse permutation
|
||||
"""
|
||||
num_dim = len(indices)
|
||||
inverse_indices = [0] * num_dim
|
||||
for i in range(num_dim):
|
||||
inverse_indices[indices[i]] = i
|
||||
return inverse_indices
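
    # Worked example (illustrative, not in the original source): for
    # indices = [2, 0, 1], get_inverse_indices returns [1, 2, 0], since
    # applying [2, 0, 1] and then [1, 2, 0] restores the original order.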

    def shape_propagation(self, input_node_meta):
        input_shape = input_node_meta.tensor.shape
        output_shape = tuple([input_shape[idx] for idx in self.indices])
        return output_shape

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on the current shape
        """
        self.update(shape)
        inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the permutation to the users of the current node
        """
        usr_meta.tensor.permute(self.inverse_indices)
        if hasattr(usr_meta, "store_tensor"):
            if usr_meta.store_tensor is not None:
                usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the permutation to the inputs of the current node
        """
        input_meta.tensor.permute(self.indices)
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.permute(self.indices)


class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape
    """
    def __init__(self, node) -> None:
        self.node = node
        assert "new_shape" in node.kwargs.keys()
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        inverse_impl = deepcopy(self)
        inverse_impl.output_shape = self.input_shape
        inverse_impl.input_shape = self.output_shape
        return inverse_impl

    def shape_propagation(self, input_node_meta):
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on the current shape.
        """
        # Step 1: infer the split
        flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
        split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)

        # broadcast shape -> split_output_shape -> flatten_split_shape
        if len(shape) - len(split_output_shape) > 0:
            for _ in range(len(shape) - len(split_output_shape)):
                split_output_shape = [1,] + split_output_shape
                flatten_split_shape = [1,] + flatten_split_shape
                split_input_shape = [1,] + split_input_shape
        broadcast_factor = []
        for dim, old_dim in zip(shape, split_output_shape):
            if not isinstance(dim, list):
                dim = [dim,]
            if not isinstance(old_dim, list):
                old_dim = [old_dim,]
            if product(tuple(dim)) == product(tuple(old_dim)):
                broadcast_factor += [1] * len(old_dim)
            elif product(tuple(old_dim)) == 1:
                assert len(dim) == 1
                broadcast_factor.append(dim[0])
            else:
                raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")

        # flatten_split_shape -> split_input_shape
        factor_idx = 0
        broadcast_split_input_shape = []
        for dim in split_input_shape:
            if isinstance(dim, list):
                new_dim = []
                for d in dim:
                    new_dim.append(d * broadcast_factor[factor_idx])
                    factor_idx += 1
                broadcast_split_input_shape.append(new_dim)
            else:
                broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
                factor_idx += 1
        broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
        node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
        node_meta.tensor.broadcast(broadcast_split_input_shape)
        # A last reshape op to clean up
        broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
        node_meta.tensor.reshape(broadcast_input_shape)
        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)

    def apply_to_user(self, user_meta: NodeBase):
        """
        Propagate the reshape to user nodes
        """
        user_meta.tensor.reshape(tuple(self.input_shape))
        if hasattr(user_meta, "store_tensor"):
            if user_meta.store_tensor is not None:
                user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the reshape to input nodes
        """
        input_meta.tensor.reshape(tuple(self.output_shape))
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flattened split shape that can be merged to both input_shape and output_shape
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return input_shape
        # This is done recursively by processing only the last dimension at a time
        old_dim = input_shape[-1]
        new_dim = output_shape[-1]
        # Exact match
        if old_dim == new_dim:
            return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
        # Needs split
        if old_dim > new_dim and old_dim % new_dim == 0:
            residual = old_dim // new_dim
            return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
        # Needs merge
        if old_dim < new_dim and new_dim % old_dim == 0:
            residual = new_dim // old_dim
            return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]

        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        idx_flat = len(flatten_shape) - 1
        merged_shape = []
        for dim in reversed(shape):
            # Exact match
            if dim == flatten_shape[idx_flat]:
                merged_shape.append(dim)
                idx_flat -= 1
            # Needs grouping
            elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
                residual = dim
                group = []
                while residual > 1:
                    group.append(flatten_shape[idx_flat])
                    residual = residual // flatten_shape[idx_flat]
                    idx_flat -= 1
                merged_shape.append(group[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

        return merged_shape[::-1]


class LayoutNode(NodeBase):
    """
    Layout manipulation node
    """
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }
    def __init__(self, name: str, fn, kwargs: dict) -> None:
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        self.kwargs = kwargs
        self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)

    def get_inverse_node(self):
        inverse_node = deepcopy(self)
        inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse_node

    def shape_propagation(self, input_node_metas):
        if self._tensor is not None:
            return
        assert len(input_node_metas) == 1, "Layout node can only have one input node"

        output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])

        self._tensor = Tensor(
            element=self.element_output,
            shape=output_shape, layout_tag=LayoutType.RowMajor
        )

        return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Layout nodes have element_output = element_input
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in reversed topological order
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        shape = self.tensor.shape

        for child in input_node_metas:
            self.underlying_impl.broadcast(shape, child)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the layout manipulation to user nodes
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the layout manipulation to input nodes
        """
        self.underlying_impl.apply_to_input(input_meta)

294
python/cutlass_cppgen/backend/evt/ir/load_nodes.py
Normal file
@@ -0,0 +1,294 @@
"""
|
||||
Load nodes and implementations
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_cppgen.backend.c_types import tuple_factory
|
||||
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
||||
|
||||
|
||||
class LoadImplBase(ImplBase):
|
||||
"""
|
||||
Base class for load node implementations
|
||||
"""
|
||||
reserved_names = ["accum", "C"]
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.element = node.element
|
||||
self.element_output = node.element_output
|
||||
self.stride = node.tensor.stride
|
||||
|
||||
|
||||
class AccumulatorImpl(LoadImplBase):
|
||||
"""
|
||||
Accumulator node implementation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
return node.name == "accum" and node.tensor.shape == problem_size
|
||||
|
||||
|
||||
class LoadSrcImpl(LoadImplBase):
|
||||
"""
|
||||
Load C implementation
|
||||
"""
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
return "TensorC"
|
||||
|
||||
@property
|
||||
def argument_type_c(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_C", ctypes.c_void_p),
|
||||
("stride_C", tuple_type)
|
||||
]
|
||||
def __init__(self, ptr) -> None:
|
||||
self.ptr_C = ptr
|
||||
self.stride_C = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
return node.name == "C" and node.tensor.shape == problem_size
|
||||
|
||||
|
||||
class AuxLoadImpl(LoadImplBase):
|
||||
"""
|
||||
Load arbitrary tensor
|
||||
"""
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_aux", ctypes.c_void_p),
|
||||
("null_default", dtype2ctype[element_type]),
|
||||
("dAux", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_aux = ptr
|
||||
self.null_default = to_ctype_value(0, element_type)
|
||||
self.dAux = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if (strideMN[0] == 1 and strideMN[1] != 0 or
|
||||
strideMN[0] != 0 and strideMN[1] == 1 ):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class RowBroadcastImpl(LoadImplBase):
|
||||
"""
|
||||
Broadcast a row vector
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.stride_dtype = "int"
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_row", ctypes.c_void_p),
|
||||
("null_default", dtype2ctype[element_type]),
|
||||
("dRow", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_row = ptr
|
||||
self.null_default = to_ctype_value(0, element_type)
|
||||
self.dRow = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if strideMN == (0, 1):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ColumnBroadcastImpl(LoadImplBase):
|
||||
"""
|
||||
Broadcast a column vector
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.stride_dtype = "int"
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_col", ctypes.c_void_p),
|
||||
("null_default", dtype2ctype[element_type]),
|
||||
("dCol", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_col = int(ptr)
|
||||
self.null_default = to_ctype_value(0, element_type)
|
||||
self.dCol = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if strideMN == (1, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ScalarBroadcastImpl(LoadImplBase):
|
||||
"""
|
||||
Broadcast a scalar
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.stride_dtype = "int"
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
|
||||
if self.tensor.is_constant:
|
||||
value = self.tensor.value
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("scalars", dtype2ctype[element_type]),
|
||||
("scalar_ptrs", ctypes.c_void_p),
|
||||
("dScalar", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
self.scalars = to_ctype_value(value, element_type)
|
||||
self.scalar_ptrs = 0
|
||||
self.dScalar = tuple_type(stride_mnl)
|
||||
|
||||
else:
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("scalars", dtype2ctype[element_type]),
|
||||
("scalar_ptrs", ctypes.c_void_p),
|
||||
("dScalar", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
scalar_or_ptr = kwargs[name]
|
||||
if isinstance(scalar_or_ptr, float):
|
||||
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
|
||||
self.scalar_ptrs = 0
|
||||
else:
|
||||
self.scalar_ptrs = int(scalar_or_ptr)
|
||||
|
||||
self.dScalar = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if strideMN == (0, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class LoadNode(NodeBase):
|
||||
"""
|
||||
Load Node
|
||||
"""
|
||||
cnt = 0
|
||||
possible_impls = [
|
||||
AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
|
||||
RowBroadcastImpl, ColumnBroadcastImpl,
|
||||
ScalarBroadcastImpl
|
||||
]
|
||||
def __init__(self, name: str) -> None:
|
||||
if name is None:
|
||||
name = f"load{LoadNode.cnt}"
|
||||
LoadNode.cnt += 1
|
||||
super().__init__(name)
|
||||
self.op = "load"
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
||||
|
||||
self.element = self.tensor.element
|
||||
self.element_output = self.tensor.element
|
||||
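The load impls above are selected purely by the trailing (M, N) strides of each node's tensor. A standalone sketch exercising the match() predicates in this module (the FakeNode/FakeTensor stand-ins are hypothetical, not part of this file):

# Hypothetical stand-ins, just enough structure to exercise the predicates:
class FakeTensor:
    def __init__(self, shape, stride):
        self.shape, self.stride = shape, stride

class FakeNode:
    def __init__(self, name, shape, stride):
        self.name, self.tensor = name, FakeTensor(shape, stride)

problem_size = (16, 256, 512)
bias = FakeNode("bias", (1, 1, 512), (0, 0, 1))            # row vector
alpha = FakeNode("alpha", (1, 1, 1), (0, 0, 0))            # scalar
aux = FakeNode("aux", (16, 256, 512), (131072, 512, 1))    # full tensor

assert RowBroadcastImpl.match(bias, problem_size)      # strideMN == (0, 1)
assert ScalarBroadcastImpl.match(alpha, problem_size)  # strideMN == (0, 0)
assert AuxLoadImpl.match(aux, problem_size)            # contiguous in N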
306
python/cutlass_cppgen/backend/evt/ir/node.py
Normal file
@ -0,0 +1,306 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

"""
|
||||
Base & visitor classes of DAGIR Nodes
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from re import sub
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class TupleEmitter:
|
||||
"""
|
||||
Emit the cute tuple to C++ code
|
||||
"""
|
||||
def __init__(self, stride_dtype):
|
||||
self.stride_dtype = stride_dtype
|
||||
|
||||
def emit(self, py_tuple):
|
||||
if isinstance(py_tuple, int):
|
||||
if py_tuple in [0, 1]:
|
||||
return f"cute::Int<{py_tuple}>"
|
||||
else:
|
||||
return f"{self.stride_dtype}"
|
||||
elif isinstance(py_tuple, tuple):
|
||||
decl = "cute::Stride<"
|
||||
for item in py_tuple:
|
||||
decl += self.emit(item) + ", "
|
||||
return decl[:-2] + ">"
|
||||
else:
|
||||
raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
|
||||
|
||||
|
||||
class ImplBase:
|
||||
"""
|
||||
Base class for Node Implementation
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
self.node = node
|
||||
self.name = node.name
|
||||
self.tensor = node.tensor
|
||||
self._type_decl = None
|
||||
self.tuple_emitter = TupleEmitter("int64_t")
|
||||
|
||||
@property
|
||||
def stride_dtype(self):
|
||||
return self.tuple_emitter.stride_dtype
|
||||
|
||||
@stride_dtype.setter
|
||||
def stride_dtype(self, stride_dtype):
|
||||
self.tuple_emitter.stride_dtype = stride_dtype
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
"""
|
||||
Match function used in get_underlying_impl
|
||||
"""
|
||||
raise NotImplementedError(f"The `match` function is not defined.")
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
"""
|
||||
Default class for Argument Type
|
||||
"""
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = []
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
return _Argument
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
"""
|
||||
Return the CamelCase name.
|
||||
"""
|
||||
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
|
||||
|
||||
@property
|
||||
def stride_mnl(self):
|
||||
"""
|
||||
Typename StrideMNL
|
||||
"""
|
||||
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
||||
return self.tuple_emitter.emit(stride)
|
||||
|
||||
def get_non_constant_stride(self, py_tuple):
|
||||
if isinstance(py_tuple, int):
|
||||
if py_tuple not in [0, 1]:
|
||||
return py_tuple
|
||||
else:
|
||||
return None
|
||||
non_constant_stride = []
|
||||
for item in py_tuple:
|
||||
item_out = self.get_non_constant_stride(item)
|
||||
if item_out:
|
||||
non_constant_stride.append(item_out)
|
||||
return tuple(non_constant_stride)
|
||||
|
||||
def get_stride_mnl(self):
|
||||
"""
|
||||
Get the non-zero stride mnl. This is used in argument construction
|
||||
"""
|
||||
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
||||
return stride
|
||||
|
||||
def get_smem_size(self, *args, **kwargs):
|
||||
"""
|
||||
Get the shared memory size and alignment of current node
|
||||
"""
|
||||
return (0, 1)
|
||||
|
||||
|
||||
class NoOpImpl(ImplBase):
|
||||
"""
|
||||
The NoOpImpl does nothing but forward its input to users
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.op == "store":
|
||||
# Store that is not output is a No OP
|
||||
return not node.is_output
|
||||
|
||||
|
||||
class NodeBase:
|
||||
"""
|
||||
Base class of DAG Node
|
||||
"""
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
self.underlying_impl = None
|
||||
|
||||
self._tensor = None
|
||||
|
||||
# Whether the node is disabled for emit
|
||||
self.disabled = False
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
"""
|
||||
Return the CamelCase name.
|
||||
"""
|
||||
return self.underlying_impl.name_camel
|
||||
|
||||
@property
|
||||
def tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
||||
"""
|
||||
return self._tensor
|
||||
|
||||
@tensor.setter
|
||||
def tensor(self, kwargs):
|
||||
"""
|
||||
Setting the tensor
|
||||
"""
|
||||
self._tensor = Tensor(**kwargs)
|
||||
|
||||
#
|
||||
# Helper functions for type/shape propagation
|
||||
#
|
||||
|
||||
def shape_propagation(self, input_node_metas):
|
||||
"""
|
||||
Infer shape from input nodes
|
||||
General Broadcasting Rules from NumPy
|
||||
When operating on two arrays, we compare their shapes element-wise.
|
||||
It starts with the trailing (i.e. rightmost) dimension and works its
|
||||
way left. Two dimensions are compatible when
|
||||
1. they are equal
|
||||
2. one of them is 1
|
||||
"""
|
||||
if self._tensor is not None:
|
||||
return
|
||||
|
||||
shape = None
|
||||
for src in input_node_metas:
|
||||
src_shape = src.tensor.shape
|
||||
if shape is None:
|
||||
shape = src_shape
|
||||
else:
|
||||
len_difference = len(shape) - len(src_shape)
|
||||
if len_difference > 0:
|
||||
for _ in range(len_difference):
|
||||
src_shape = [1, ] + list(src_shape)
|
||||
elif len_difference < 0:
|
||||
for _ in range(-len_difference):
|
||||
shape = [1, ] + list(shape)
|
||||
broadcasted_shape = []
|
||||
# Infer broadcast shape
|
||||
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
|
||||
if shape_dim == 1:
|
||||
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
|
||||
elif src_dim == 1:
|
||||
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
||||
elif shape_dim == src_dim:
|
||||
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
||||
else:
|
||||
error_msg = "Dimension mismatch between "
|
||||
for src_ in input_node_metas:
|
||||
error_msg += f"{src_.name}{src_.tensor.shape}, "
|
||||
error_msg = error_msg[:-2] + "."
|
||||
raise RuntimeError(error_msg)
|
||||
shape = tuple(broadcasted_shape)
|
||||
|
||||
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Each node is associated with two data types: `element` and `element_output`.
|
||||
The `element_output` is the type of return array of the node. The `element`
|
||||
has specific meaning for different node types.
|
||||
* Load Node: data type of tensor in gmem
|
||||
* Compute Node: element compute
|
||||
* Store Node: data type of tensor in gmem
|
||||
This function must be overloaded in the derived classes
|
||||
"""
|
||||
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
|
||||
|
||||
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
||||
"""
|
||||
Propagate the broadcast in the reversed topological order.
|
||||
For example:
|
||||
C[l, m, n] = A[m, 1] + B[l, m, n]
|
||||
After the broadcast propagation, it will be come
|
||||
C[l, m, n] = A[l, m, n] + B[l, m, n]
|
||||
and each tensor will have a proper stride accessing the underlying tensor
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
||||
for child in input_node_metas:
|
||||
child.tensor.broadcast(self.tensor.shape)
|
||||
|
||||
def get_underlying_impl(self, problem_size: tuple):
|
||||
"""
|
||||
Get the underlying implementation of the current node.
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
|
||||
|
||||
for impl in self.possible_impls:
|
||||
if impl.match(self, problem_size):
|
||||
self.underlying_impl = impl(self)
|
||||
break
|
||||
|
||||
if self.underlying_impl is None:
|
||||
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
|
||||
|
||||
#
|
||||
# Visitor Nodes & Impls
|
||||
#
|
||||
|
||||
class TopoVisitorImpl(ImplBase):
|
||||
"""
|
||||
Impl for topological visitor
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node.output_node)
|
||||
self.name = node.name
|
||||
self.element_output = node.output_node.element_output
|
||||
|
||||
class TopoVisitorNode(NodeBase):
|
||||
def __init__(self, name: str, subgraph, output_node) -> None:
|
||||
super().__init__(name)
|
||||
self.subgraph = subgraph
|
||||
self.output_node = output_node
|
||||
self.op = "dag"
|
||||
self.underlying_impl = TopoVisitorImpl(self)
|
||||
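As a quick illustration of TupleEmitter above: compile-time strides 0 and 1 become cute::Int<> constants, while anything dynamic collapses to the configured stride dtype. A minimal sketch (the values are illustrative only):

emitter = TupleEmitter("int64_t")
# A column-broadcast MNL stride with a dynamic batch stride:
print(emitter.emit((1, 0, (4096,))))
# -> cute::Stride<cute::Int<1>, cute::Int<0>, cute::Stride<int64_t>>

Note that only the static 0/1 strides keep their value in the emitted type; 4096 is emitted as the dtype name, not the number, since it is a runtime quantity.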
277
python/cutlass_cppgen/backend/evt/ir/store_nodes.py
Normal file
@ -0,0 +1,277 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

"""
|
||||
Store node and implementations
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass_cppgen.backend.c_types import tuple_factory
|
||||
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
|
||||
|
||||
|
||||
class StoreImplBase(ImplBase):
|
||||
"""
|
||||
Base class for store node implementation
|
||||
"""
|
||||
reserved_names = ["D"]
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.element = node.element
|
||||
self.element_output = node.element_output
|
||||
self.stride = node.store_tensor.stride
|
||||
|
||||
|
||||
class StoreDImpl(StoreImplBase):
|
||||
"""
|
||||
Store D implementation
|
||||
"""
|
||||
|
||||
@property
|
||||
def argument_type_d(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_D", ctypes.c_void_p),
|
||||
("stride_D", tuple_type)
|
||||
]
|
||||
def __init__(self, ptr: int) -> None:
|
||||
self.ptr_D = ptr
|
||||
self.stride_D = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name == "D" and node.store_tensor.shape == problem_size:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AuxStoreImpl(StoreImplBase):
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.round_style = FloatRoundStyle.ToNearest
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_aux", ctypes.c_void_p),
|
||||
("dAux", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_aux = ptr
|
||||
self.dAux = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if (strideMN[0] == 1 and strideMN[1] != 0 or
|
||||
strideMN[0] != 0 and strideMN[1] == 1 ):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ReductionImplBase(StoreImplBase):
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.element = node.store_tensor.element
|
||||
self.element_compute = node.element_compute
|
||||
self.reg_reduce_fn = self.node.reg_reduce_fn
|
||||
self.gmem_reduce_fn = self.node.gmem_reduce_fn
|
||||
self.round_style = node.round_style
|
||||
self.stride_dtype = "int"
|
||||
|
||||
def get_reduce_identity(self):
|
||||
"""
|
||||
Return the reduction identity of the current reduce_fn
|
||||
"""
|
||||
maxes = {
|
||||
DataType.f32: (2 ** 31) - 1,
|
||||
DataType.f16: (2 ** 15),
|
||||
DataType.s32: (2 ** 31) - 1,
|
||||
DataType.s8: (2 ** 7) - 1
|
||||
}
|
||||
mins = {
|
||||
DataType.f32: -maxes[DataType.f32],
|
||||
DataType.f16: -maxes[DataType.f16],
|
||||
DataType.s32: -maxes[DataType.s32],
|
||||
DataType.s8: -maxes[DataType.s8]
|
||||
}
|
||||
if self.reg_reduce_fn == FunctionalOp.Maximum:
|
||||
if self.element_compute not in mins:
|
||||
raise Exception(f"No min entry for data type {self.element_compute}")
|
||||
return to_ctype_value(mins[self.element_compute], self.element_compute)
|
||||
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
|
||||
return to_ctype_value(1., self.element_compute)
|
||||
elif self.reg_reduce_fn == FunctionalOp.Minimum:
|
||||
if self.element_compute not in maxes:
|
||||
raise Exception(f"No max entry for data type {self.element_compute}")
|
||||
return to_ctype_value(maxes[self.element_compute], self.element_compute)
|
||||
else:
|
||||
return to_ctype_value(0., self.element_compute)
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
self.get_reduce_identity()
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_compute = self.element_compute
|
||||
reduce_identity = self.get_reduce_identity()
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr", ctypes.c_void_p),
|
||||
("reduce_identity", dtype2ctype[element_compute]),
|
||||
("dMNL", tuple_type)
|
||||
]
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr = ptr
|
||||
self.reduce_identity = reduce_identity
|
||||
self.dMNL = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
|
||||
class ColumnReductionImpl(ReductionImplBase):
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if strideMN == (1, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class RowReductionImpl(ReductionImplBase):
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if strideMN == (0, 1):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ScalarReductionImpl(ReductionImplBase):
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if strideMN == (0, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class StoreNode(NodeBase):
|
||||
"""
|
||||
Store node
|
||||
"""
|
||||
possible_impls = [
|
||||
AuxStoreImpl, RowReductionImpl,
|
||||
ColumnReductionImpl, ScalarReductionImpl,
|
||||
NoOpImpl, StoreDImpl
|
||||
]
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(name)
|
||||
self.op = "store"
|
||||
self.is_output = False
|
||||
self._store_tensor = None
|
||||
|
||||
@property
|
||||
def store_tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
||||
"""
|
||||
return self._store_tensor
|
||||
|
||||
@store_tensor.setter
|
||||
def store_tensor(self, kwargs):
|
||||
"""
|
||||
Setting the tensor
|
||||
"""
|
||||
self._store_tensor = Tensor(**kwargs)
|
||||
|
||||
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
|
||||
"""
|
||||
The store nodes has element_output = element_input
|
||||
"""
|
||||
if self.is_output:
|
||||
if self.store_tensor is None:
|
||||
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
|
||||
self.element = self.store_tensor.element
|
||||
assert len(input_node_metas) == 1, "Store node can only have one input node"
|
||||
self.element_output = input_node_metas[0].element_output
|
||||
|
||||
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
||||
super().broadcast_propagation(input_node_metas)
|
||||
if self.is_output:
|
||||
self._store_tensor.broadcast(self.tensor.shape)
|
||||
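The identity returned by get_reduce_identity above is what seeds every partial reduction with a neutral value, so that max, min, product, and sum reductions all start from a state any real input can overwrite. A hedged summary sketch (FunctionalOp.Plus is assumed here to name the default sum branch; the numeric values mirror the maxes/mins tables above for f32):

from cutlass_cppgen.backend.library import FunctionalOp

# Neutral seed per reduction op:
identity_for = {
    FunctionalOp.Plus: 0.0,                # additive identity
    FunctionalOp.Multiplies: 1.0,          # multiplicative identity
    FunctionalOp.Maximum: -(2 ** 31 - 1),  # mins[DataType.f32]: any input wins
    FunctionalOp.Minimum: (2 ** 31 - 1),   # maxes[DataType.f32]: any input wins
}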
137
python/cutlass_cppgen/backend/evt/ir/tensor.py
Normal file
@ -0,0 +1,137 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

"""
|
||||
High-level class for tensor
|
||||
"""
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
|
||||
Layout,
|
||||
broadcast,
|
||||
canonicalization,
|
||||
permutation,
|
||||
reshape,
|
||||
_reverse_tuple
|
||||
)
|
||||
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
||||
|
||||
|
||||
class Tensor:
|
||||
"""
|
||||
The tensor abstracts the data type
|
||||
"""
|
||||
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
|
||||
if element is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both element and tensor")
|
||||
elif shape is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both shape and tensor")
|
||||
elif layout_tag is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both layout_tag and tensor")
|
||||
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
|
||||
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
|
||||
elif stride is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both stride and tensor")
|
||||
elif stride is not None and layout_tag is not None:
|
||||
raise Exception(f"Must not specify layout_tag when stride is provided")
|
||||
|
||||
if isinstance(tensor, Tensor):
|
||||
# Directly copy all the attributes
|
||||
self.__dict__.update(vars(tensor))
|
||||
else:
|
||||
if tensor is None:
|
||||
self.element = library_type(element)
|
||||
else:
|
||||
self.element, layout_tag = get_datatype_and_layout(tensor)
|
||||
shape = get_tensor_shape(tensor)
|
||||
if stride is not None:
|
||||
self.layout = Layout(shape[::-1], stride[::-1])
|
||||
else:
|
||||
if layout_tag == LayoutType.RowMajor:
|
||||
self.layout = Layout(shape[::-1])
|
||||
elif layout_tag == LayoutType.ColumnMajor:
|
||||
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
|
||||
self.layout = canonicalization(self.layout)
|
||||
|
||||
self.is_constant = is_constant
|
||||
# Save the tensor value if it is constant
|
||||
if is_constant and tensor is not None:
|
||||
self.value = tensor
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""
|
||||
Returns the RowMajor layout shape
|
||||
"""
|
||||
return _reverse_tuple(self.layout.shape)
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
"""
|
||||
Returns the RowMajor layout stride
|
||||
"""
|
||||
return _reverse_tuple(self.layout.stride)
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
"""
|
||||
Returns the rank of the tensor
|
||||
"""
|
||||
return len(self.shape)
|
||||
|
||||
#
|
||||
# Layout Algorithms
|
||||
#
|
||||
|
||||
def broadcast(self, shape):
|
||||
"""
|
||||
Broadcast self.layout to shape
|
||||
"""
|
||||
assert isinstance(shape, tuple)
|
||||
self.layout = broadcast(self.layout, _reverse_tuple(shape))
|
||||
|
||||
def reshape(self, shape):
|
||||
"""
|
||||
Reshape self.layout to shape
|
||||
"""
|
||||
assert isinstance(shape, tuple)
|
||||
reverse_shape = _reverse_tuple(shape)
|
||||
self.layout = reshape(self.layout, reverse_shape)
|
||||
|
||||
def permute(self, indices):
|
||||
"""
|
||||
Permute self.layout according to indices
|
||||
"""
|
||||
length = len(indices)
|
||||
indices = [length - idx - 1 for idx in indices]
|
||||
self.layout = permutation(self.layout, indices[::-1])
|
||||
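Tensor keeps its layout in the reversed (CuTe-style) convention internally, while the shape/stride properties expose the familiar row-major view. A hedged usage sketch, assuming library_type and the Layout helpers accept these inputs as their names suggest:

from cutlass_library import DataType, LayoutType

# A row-major (L, M, N) tensor described by element/shape/layout_tag:
t = Tensor(element=DataType.f32, shape=(4, 128, 256), layout_tag=LayoutType.RowMajor)
# t.shape and t.stride present the row-major view of the internal reversed layout.

# Broadcasting a row vector up to the full problem shape gives it zero strides
# in the broadcast dimensions:
row = Tensor(element=DataType.f32, shape=(1, 1, 256), layout_tag=LayoutType.RowMajor)
row.broadcast((4, 128, 256))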
42
python/cutlass_cppgen/backend/evt/passes/__init__.py
Normal file
@ -0,0 +1,42 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
143
python/cutlass_cppgen/backend/evt/passes/graph_drawer.py
Normal file
@ -0,0 +1,143 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################
from __future__ import annotations

import subprocess

from cutlass_library import DataTypeTag

from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR


_COLOR_MAP = {
    "load": '"AliceBlue"',
    "compute": "LemonChiffon1",
    "accumulator": "LightGrey",
    "store": "PowderBlue",
    "layout": "lightseagreen",
    "dag": "darkorange"
}


class EVTGraphDrawer:
    """
    Visualize an EVT DAGIR with graphviz
    """
    def __init__(
        self,
        graph: DAGIR,
        name: str
    ):
        self._name = name
        self._dot_graphs = {}

        self._dot_graphs[name] = self._to_dot(graph, name)

    def _get_node_style(self, node):
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op in _COLOR_MAP:
            template["fillcolor"] = _COLOR_MAP[node.op]
        else:
            raise NotImplementedError("unknown node op")
        if node.disabled:
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        label = "{" + f"name={node.name}|op={node.op}"
        if node.op == "layout":
            label += f"|fn={node.fn.__name__}"
            for key in node.kwargs:
                label += f"|{key}={node.kwargs[key]}"
        if node.underlying_impl is not None:
            label += f"|impl={type(node.underlying_impl).__name__}"
            if node.op == "load":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
            elif node.op == "compute":
                label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "store":
                label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "dag":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
        if node.tensor is not None:
            shape = node.tensor.shape
            stride = node.tensor.stride
            label += f"|shape={shape}|stride={stride}"

        if hasattr(node, "store_tensor"):
            if node.store_tensor is not None:
                store_shape = node.store_tensor.shape
                store_stride = node.store_tensor.stride
                label += f"|store_shape={store_shape}|store_stride={store_stride}"

        label += "}"
        return label

    def _to_dot(
        self,
        graph: DAGIR,
        name: str
    ):
        import pydot
        dot_graph = pydot.Dot(name, rankdir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_node = pydot.Node(
                node.name, label=label, **style
            )
            dot_graph.add_node(dot_node)
            if node.op == "dag":
                dot_subgraph = self._to_dot(node.subgraph, name=node.name)
                self._dot_graphs[node.name] = dot_subgraph

        # Add edges
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))

        return dot_graph

    def get_dot_graph(self) -> list[tuple[str, pydot.Dot]]:
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]

    def get_dot_graph_by_name(self, name) -> pydot.Dot:
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> pydot.Dot:
        return self._dot_graphs[self._name]
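A hedged usage sketch for the drawer (dag_ir here is an assumed, already-built DAGIR instance; write_png/write_svg are standard pydot output methods):

drawer = EVTGraphDrawer(dag_ir, "epilogue")
drawer.get_main_dot_graph().write_png("epilogue.png")
# Subgraphs generated for "dag" nodes are retrievable as well:
for name, dot in drawer.get_dot_graph():
    dot.write_svg(f"{name}.svg")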
120
python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py
Normal file
@ -0,0 +1,120 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

"""
|
||||
Construct the epilogue visitor argument type
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.c_types import visitor_factory
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetArgumentType(EVTPassBase):
|
||||
"""
|
||||
Construct the epilogue visitor argument type
|
||||
"""
|
||||
dependencies = [
|
||||
PassShapeTypePropagation, # The Layout of all nodes must be set
|
||||
PassDAG2Tree, # The type of each node must be set
|
||||
PassGetImpl # The DAG subgraphs must be set
|
||||
]
|
||||
|
||||
def requires(self) -> None:
|
||||
# Check "D" is in the node list
|
||||
if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
|
||||
raise SyntaxError(
|
||||
"Sm90+ EVT requires the epilogue to have a returned tensor D, "
|
||||
"but the variable 'D' is not found in the return values.")
|
||||
|
||||
def call(self):
|
||||
nodes = self.dag_ir.nodes_topological_order()
|
||||
self.argument_types = {}
|
||||
for node in nodes:
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
if not meta.disabled:
|
||||
self.argument_types[node] = meta.underlying_impl.argument_type
|
||||
if node == "D" and cc_map[self.cc] in [90, 100]:
|
||||
continue
|
||||
if isinstance(meta, TopoVisitorNode):
|
||||
self.get_dag_argument_type(node)
|
||||
else:
|
||||
self.get_evt_argument_type(node)
|
||||
|
||||
self.cc_specific_method(self.set_argument_type)()
|
||||
|
||||
def get_evt_argument_type(self, node):
|
||||
# Sort the input nodes by edge weight
|
||||
input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
|
||||
if len(input_types) > 0:
|
||||
self.argument_types[node] = visitor_factory(
|
||||
input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])
|
||||
|
||||
def get_dag_argument_type(self, node):
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
subgraph = meta.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Visit the unvisited nodes in subgraph
|
||||
for n in subgraph_nodes:
|
||||
m = subgraph.get_node_meta(n)
|
||||
if m.disabled:
|
||||
continue
|
||||
else:
|
||||
self.argument_types[n] = m.underlying_impl.argument_type
|
||||
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
|
||||
if len(input_types) > 0:
|
||||
self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])
|
||||
|
||||
def set_argument_type(self):
|
||||
pass
|
||||
|
||||
def sm90_set_argument_type(self):
|
||||
self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
|
||||
# Get the tensorD argument type
|
||||
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
|
||||
|
||||
# Get the tensorC argument type
|
||||
if self.dag_ir.has_node("C"):
|
||||
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
|
||||
else:
|
||||
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
|
||||
|
||||
def sm100_set_argument_type(self):
|
||||
self.sm90_set_argument_type()
|
||||
|
||||
def sm80_set_argument_type(self):
|
||||
nodes = self.dag_ir.nodes_topological_order()
|
||||
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
|
||||
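For orientation, the argument structures assembled by this pass are ultimately instantiated from a single name-to-value dict, matching the _Argument(kwargs) pattern in the load/store impls earlier in this commit. A hypothetical sketch, assuming the factory-produced structure forwards the same kwargs dict to each child argument (dag_ir and bias_ptr are assumed to exist):

# Hypothetical: build the epilogue arguments from a name -> value dict.
epilogue_args = dag_ir.epilogue_thread_type({
    "bias": bias_ptr,   # e.g. consumed by a RowBroadcastImpl argument
    "alpha": 1.0,       # e.g. consumed by a ScalarBroadcastImpl argument
})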
169
python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py
Normal file
@ -0,0 +1,169 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

"""
|
||||
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
|
||||
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
|
||||
|
||||
class PassDAG2Tree(EVTPassBase):
|
||||
"""
|
||||
Convert the DAG IR to Tree by fusing subgraphs
|
||||
"""
|
||||
dependencies = [
|
||||
PassShapeTypePropagation,
|
||||
PassGetImpl
|
||||
]
|
||||
|
||||
def call(self):
|
||||
# Step 1: find the nodes that have multiple parents
|
||||
multi_parent_nodes = []
|
||||
|
||||
for node in self.dag_ir.nodes_topological_order():
|
||||
if self.dag_ir.out_degree(node) > 1:
|
||||
multi_parent_nodes.append(node)
|
||||
# Step 2: find the lowest common ancestor (LCA) of all its parents
|
||||
for node in multi_parent_nodes:
|
||||
# A multi-parent node could be already fused by the previous node
|
||||
if not self.dag_ir.has_node(node):
|
||||
continue
|
||||
# A node uncovered by the previous fusions can have out degree change
|
||||
# Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
|
||||
# Case 2: it has more than one edges to the previously fused subgraph, degree drops
|
||||
if self.dag_ir.out_degree(node) <= 1:
|
||||
continue
|
||||
|
||||
# Otherwise, the node still
|
||||
reachable_nodes = []
|
||||
# Complexity: O(Dout*N)
|
||||
for parent in self.dag_ir.get_users(node):
|
||||
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
|
||||
# get the common reachable objects
|
||||
common_items = set.intersection(*reachable_nodes)
|
||||
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
|
||||
|
||||
lca = None
|
||||
# If common ancestor exists, find the lowest one
|
||||
if len(common_items) > 0:
|
||||
topo_order = self.dag_ir.nodes_topological_order()
|
||||
topo_idx = -1
|
||||
for item in common_items:
|
||||
if lca is None:
|
||||
lca = item
|
||||
topo_idx = topo_order.index(item)
|
||||
else:
|
||||
if topo_idx > topo_order.index(item):
|
||||
lca = item
|
||||
topo_idx = topo_order.index(item)
|
||||
else:
|
||||
# there is no common ancestor for all the parents, we pack all the reachable
|
||||
# nodes into a single DAG node as a fallback. The lca should be the input node of
|
||||
# one of the output nodes with out_degree = 0
|
||||
potential_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
if self.dag_ir.out_degree(node) == 0:
|
||||
potential_output_nodes.append(node)
|
||||
if len(potential_output_nodes) == 0:
|
||||
raise RuntimeError(f"No output node with out degree = 0 found.")
|
||||
|
||||
output_node = None
|
||||
if (self.dag_ir.cc >= 90):
|
||||
# For SM90+, the lca should be the input node of D
|
||||
if (not self.dag_ir.has_node("D")):
|
||||
raise RuntimeError(f"D is not a node in the DAG IR.")
|
||||
output_node = "D"
|
||||
else:
|
||||
output_node = potential_output_nodes[0]
|
||||
|
||||
if (output_node is None):
|
||||
raise RuntimeError(f"No output node found.")
|
||||
lca = self.dag_ir.get_all_inputs(output_node)[0]
|
||||
node_to_fuse.remove(output_node)
|
||||
|
||||
# The lca is the output node of the DAG node
|
||||
# Get the nodes to be fused
|
||||
node_to_fuse.add(lca)
|
||||
# Get all the input nodes
|
||||
all_input_nodes = []
|
||||
all_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
|
||||
all_output_nodes.append(set(self.dag_ir.get_users(node)))
|
||||
all_input_nodes = set.union(*all_input_nodes)
|
||||
all_output_nodes = set.union(*all_output_nodes)
|
||||
|
||||
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
|
||||
|
||||
# Create the subgraph
|
||||
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
||||
subgraph = DAGIR(self.dag_ir.cc)
|
||||
for node in subgraph_.nodes:
|
||||
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
||||
if node not in node_to_fuse:
|
||||
meta.disabled = True
|
||||
subgraph.add_node(meta)
|
||||
for edge in subgraph_.edges:
|
||||
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
|
||||
|
||||
|
||||
# Create the fused node
|
||||
dag_node = TopoVisitorNode(
|
||||
name=f"dag_{lca}", subgraph=subgraph,
|
||||
output_node=self.dag_ir.get_node_meta(lca))
|
||||
self.dag_ir.add_node(dag_node)
|
||||
|
||||
# Add input edges
|
||||
for idx, node in enumerate(all_input_nodes):
|
||||
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
|
||||
|
||||
# Replace all uses with DAG node (only 1 output node)
|
||||
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
|
||||
|
||||
# Remove all fused nodes
|
||||
node_to_fuse.remove(lca)
|
||||
for node in node_to_fuse:
|
||||
self.dag_ir.remove_node(node)
|
||||
|
||||
def ensures(self) -> None:
|
||||
# Ensure that after the pass, the resulting DAG becomes a tree
|
||||
for node in self.dag_ir.nodes:
|
||||
out_degree = self.dag_ir.out_degree(node)
|
||||
if out_degree > 1:
|
||||
raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")
|
||||
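The LCA search above reduces to one graph fact: the nearest node reachable from every parent is the earliest such node in topological order. A standalone sketch of just that step (illustrative helper, not part of the pass):

def nearest_common_reachable(topo_order, reachable_sets):
    """Return the first node in topological order reachable from every parent."""
    common = set.intersection(*reachable_sets)
    return min(common, key=topo_order.index) if common else None

# Example: both parents of a shared node can reach "add_1" and "D";
# "add_1" comes first topologically, so it closes the diamond.
topo = ["accum", "bias", "add_0", "mul_0", "add_1", "D"]
print(nearest_common_reachable(topo, [{"add_1", "D"}, {"add_1", "D"}]))  # -> add_1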
@ -0,0 +1,64 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause text, identical to the license header earlier in this commit)
#
#################################################################################################

"""
|
||||
Fix the element_output of producer of D.
|
||||
|
||||
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
||||
element converter, so the compute node producing D must have element_output = type(D).
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassFixElementD(EVTPassBase):
|
||||
"""
|
||||
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
||||
element converter, so the compute node producing D must have
|
||||
element_output = type(D)
|
||||
"""
|
||||
dependencies = [
|
||||
PassLayoutManipulateElimination
|
||||
]
|
||||
def get_producer(self, node, element_D):
|
||||
node_meta = self.dag_ir.get_node_meta(node)
|
||||
if node_meta.op == "compute":
|
||||
node_meta.element_output = element_D
|
||||
elif node_meta.op == "store":
|
||||
self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)
|
||||
|
||||
def call(self):
|
||||
if self.dag_ir.has_node("D"):
|
||||
node_d_meta = self.dag_ir.get_node_meta("D")
|
||||
element_D = node_d_meta.store_tensor.element
|
||||
self.get_producer("D", element_D)
90
python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py
Normal file
@ -0,0 +1,90 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Infer the underlying implementation of each node.

While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have a different underlying implementation based
on its layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying impl of each node.
"""

import cutlass_cppgen.backend.evt.backend as evt_backend
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.util import cc_map


class PassGetImpl(EVTPassBase):
    """
    While the frontend only distinguishes between Load/Store/Compute nodes,
    each of these nodes can have a different underlying implementation based
    on its layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
    This pass infers the underlying impl of each node.
    """
    dependencies = [
        PassShapeTypePropagation,  # The shape and type info are required for inference
        PassFixElementD
    ]

    def __init__(self, dag_ir: DAGIR) -> None:
        super().__init__(dag_ir)
        self.no_op_elimination = PassNoOpElimination(dag_ir)

    def requires(self) -> None:
        # Verify "accum" is in the arg list
        if not self.dag_ir.has_node("accum"):
            raise SyntaxError("Cannot find 'accum' in the argument list.")

    def call(self):
        # The loop structure of the epilogue is determined by the
        # accumulator shape
        accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
        problem_size = accumulator.tensor.shape

        for node_meta in self.dag_ir.node_metas_topological_order():
            node_meta.get_underlying_impl(problem_size)

    def ensures(self) -> None:
        # Some nodes will be lowered to NoOp; eliminate them
        self.no_op_elimination()
        # Lower to cc-specific impl
        for node_meta in self.dag_ir.nodes_meta:
            node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
            node_meta.underlying_impl = getattr(
                node_impl_ccs,
                f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
            )(node_meta)
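
The cc-specific lowering in `ensures` above is plain attribute lookup: the mapped cc is spliced into both the backend module name and the impl class name. A minimal sketch with stand-in classes (every name below is hypothetical, not the real evt_backend contents):

class AuxLoadImpl:
    pass

class Sm90AuxLoadImpl:
    def __init__(self, node_meta):
        self.node_meta = node_meta

class FakeSm90NodeModule:
    Sm90AuxLoadImpl = Sm90AuxLoadImpl

cc = 90
generic_impl = AuxLoadImpl()
# "AuxLoadImpl" becomes "Sm90AuxLoadImpl", looked up on the cc-specific module
lowered_cls = getattr(FakeSm90NodeModule, f"Sm{cc}" + generic_impl.__class__.__name__)
assert lowered_cls is Sm90AuxLoadImpl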
@ -0,0 +1,217 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Eliminate layout manipulation nodes
"""

from copy import deepcopy

from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation


class PassLayoutManipulateElimination(EVTPassBase):
    """
    Eliminate layout manipulation nodes
    """
    dependencies = [PassShapeTypePropagation]

    def __init__(self, dag_ir: DAGIR) -> None:
        super().__init__(dag_ir)
        self.copy_cnt = 0

    def call(self):
        self.layout_nodes_worklist = self.get_all_layout_nodes()
        # Run the while loop until all layout nodes are eliminated
        while len(self.layout_nodes_worklist) > 0:
            node = self.layout_nodes_worklist.pop(0)
            # Step 1: get the propagation direction
            direction = self.get_propagation_direction(node)
            self.visited = []
            getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
            # Eliminate the current node
            input_node = self.dag_ir.get_all_inputs(node)[0]
            self.dag_ir.replace_all_uses_with(node, input_node)

    def get_all_layout_nodes(self):
        layout_nodes = []
        for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
            if isinstance(node_meta, LayoutNode):
                layout_nodes.append(node_meta.name)
        return layout_nodes

    def get_propagation_direction(self, node: str):
        """
        The logic is to propagate all layout nodes away from the accumulator node.
        """
        self.visited = []
        self.get_influenced_users(node)
        nodes_influenced_dir_users = self.visited
        self.visited = []
        self.get_influenced_inputs(node)
        nodes_influenced_dir_inputs = self.visited

        if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
            return "inputs"
        elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
            return "users"
        else:
            raise RuntimeError("Unresolved propagation direction")

    # Get all influenced nodes if we propagate along the user direction
    def get_influenced_users(self, node: str):
        if node in self.visited:
            return
        self.visited.append(node)

        users = self.dag_ir.get_users(node)
        for user in users:
            self.get_influenced_users(user)
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            for input in user_inputs:
                self.get_influenced_inputs(input)

    # Get all influenced nodes if we propagate along the input direction
    def get_influenced_inputs(self, node: str):
        if node in self.visited:
            return
        self.visited.append(node)

        inputs = self.dag_ir.get_all_inputs(node)
        for input in inputs:
            self.get_influenced_inputs(input)
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            for user in input_users:
                self.get_influenced_users(user)

    def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
        copied_node_meta = deepcopy(layout_node_meta)
        copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
        self.copy_cnt += 1
        copied_node_meta.name = copied_node
        self.dag_ir.add_node(copied_node_meta)
        # Add edges
        target_inputs = self.dag_ir.get_all_inputs(target)
        for src in target_inputs:
            self.dag_ir.remove_edge(src, target)
            self.dag_ir.add_edge(src, copied_node)
        self.dag_ir.add_edge(copied_node, target)
        self.layout_nodes_worklist.append(copied_node)

    def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
        copied_node_meta = deepcopy(layout_node_meta)
        copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
        self.copy_cnt += 1
        copied_node_meta.name = copied_node
        self.dag_ir.add_node(copied_node_meta)
        # Add edges
        users = self.dag_ir.get_users(target)
        for user in users:
            self.dag_ir.remove_edge(target, user)
            self.dag_ir.add_edge(copied_node, user)
        self.dag_ir.add_edge(target, copied_node)
        self.layout_nodes_worklist.append(copied_node)

    # Propagate the layout `node` along the user direction
    def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate the layout node to users
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)

        node_meta = self.dag_ir.get_node_meta(node)
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # A layout node is not transparent to another layout node
                self.add_copy_before(layout_node_meta, node)
                return
            else:
                layout_node_meta.apply_to_user(node_meta)

        users = self.dag_ir.get_users(node)
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        for user in users:
            self.propagate_to_users(layout_node_meta, user)
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            for input in user_inputs:
                self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)

    # Propagate the layout `node` along the input direction
    def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate the layout node to inputs
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)

        node_meta = self.dag_ir.get_node_meta(node)
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # A layout node is not transparent to another layout node
                self.add_copy_after(layout_node_meta, node)
                return
            else:
                layout_node_meta.apply_to_input(node_meta)

        inputs = self.dag_ir.get_all_inputs(node)
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        for input in inputs:
            self.propagate_to_inputs(layout_node_meta, input)
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            for user in input_users:
                self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
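
The reason a layout node can be pushed through its neighbors at all is that pure layout manipulations commute with elementwise compute nodes; propagating in the opposite direction applies the inverse manipulation, which is what `get_inverse_node` supplies. A minimal numpy sketch of the commuting identity (numpy only, no DAGIR involved):

import numpy as np

x = np.random.randn(4, 8)
relu = lambda t: np.maximum(t, 0)

# permute(relu(x)) == relu(permute(x)): the transpose can move past the compute node
assert np.array_equal(relu(x.T), relu(x).T)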
164
python/cutlass_cppgen/backend/evt/passes/pass_manager.py
Normal file
@ -0,0 +1,164 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Pass manager for DAG IR.
"""

from typing import Any

import networkx as nx

from cutlass_cppgen.backend.evt.ir import DAGIR
from cutlass_cppgen.backend.evt.passes.util import cc_map


class EVTPassBase:
    """
    Base class for EVT Passes
    """
    dependencies = []

    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    def requires(self) -> None:
        """
        This function will be called before the pass is run.
        """
        pass

    def call(self) -> None:
        """
        The pass that is run through the self.dag_ir
        """
        raise NotImplementedError(
            f"call is not overwritten in Pass {self.__class__.__name__}")

    def ensures(self) -> None:
        """
        This function will be called after the pass is run.
        """
        pass

    def __call__(self) -> Any:
        self.requires()
        self.call()
        self.ensures()

    def cc_specific_method(self, func):
        """
        This enables defining functions that behave differently under different cc.
        The simplest example of using this function is the following

        .. highlight:: python
        .. code-block:: python

            class ExamplePass(EVTPassBase):

                def call(self):
                    # This automatically selects the smXX_func based on the current cc
                    self.cc_specific_method(self.func)()

                # Interface func, can be empty
                def func(self):
                    pass

                # Sm90-specific func
                def sm90_func(self):
                    # sm90-specific method
                    return

                # Sm80-specific func
                def sm80_func(self):
                    # sm80-specific method
                    return
        """
        func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
        if hasattr(self, func_name):
            return getattr(self, func_name)
        else:
            raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")


class EVTPassManager(nx.DiGraph):
    """
    Topology-based Pass Manager.
    Each registered pass has a list of dependencies. The pass manager organizes
    the passes as a DAG and launches the compiler passes in topological order.
    """
    def __init__(self, dag_ir: DAGIR, pass_list):
        super().__init__()
        self.dag_ir = dag_ir
        for pass_cls in pass_list:
            self.add_pass(pass_cls)

        self.sorted_passes = self.schedule()

    def get_callable(self, pass_name):
        """
        Return the callable of the pass
        """
        return self.nodes[pass_name]["callable"]

    def add_pass(self, pass_cls):
        """
        Add a pass to the pass manager

        :param pass_cls: the class of the pass
        :type pass_cls: derived class of EVTPassBase
        """
        name = pass_cls.__name__
        pass_callable = pass_cls(self.dag_ir)
        self.add_node(name, callable=pass_callable)

    def schedule(self):
        """
        Schedule the added passes in topological order
        """
        # Add edges
        for pass_name in self.nodes:
            callable = self.get_callable(pass_name)
            for dependency_cls in callable.dependencies:
                self.add_edge(
                    dependency_cls.__name__,
                    type(callable).__name__)

        # Topological sort
        return list(nx.topological_sort(self))

    def __call__(self) -> Any:
        """
        Launch the registered passes
        """
        for pass_name in self.sorted_passes:
            callable = self.get_callable(pass_name)
            callable()
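
Dependency scheduling above reduces to a topological sort over dependency edges. A self-contained sketch using the dependency lists declared elsewhere in this diff (networkx only):

import networkx as nx

deps = {
    "PassPreprocessRed": [],
    "PassShapeTypePropagation": ["PassPreprocessRed"],
    "PassLayoutManipulateElimination": ["PassShapeTypePropagation"],
    "PassFixElementD": ["PassLayoutManipulateElimination"],
    "PassGetImpl": ["PassShapeTypePropagation", "PassFixElementD"],
}
g = nx.DiGraph()
for pass_name, pass_deps in deps.items():
    g.add_node(pass_name)
    for dep in pass_deps:
        g.add_edge(dep, pass_name)  # a dependency runs before its dependent

order = list(nx.topological_sort(g))
assert order.index("PassPreprocessRed") < order.index("PassGetImpl")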
@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
No-op elimination pass
"""

from typing import Any

from cutlass_cppgen.backend.evt.ir import NoOpImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase


class PassNoOpElimination(EVTPassBase):
    """
    The no-op elimination pass removes nodes with NoOpImpl from the DAG IR
    """
    dependencies = []

    def call(self) -> Any:
        for node in self.dag_ir.nodes_topological_order():
            node_meta = self.dag_ir.get_node_meta(node)
            if isinstance(node_meta.underlying_impl, NoOpImpl):
                self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])
@ -0,0 +1,97 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Preprocess the reduction nodes.

The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store().
This pass fuses these into a single store node, and then replaces all uses of the
current node with the new store node.
"""

from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase


class PassPreprocessRed(EVTPassBase):
    """
    Preprocess reduction nodes
    """

    def call(self):
        # Step 1: find the compute nodes with op=red
        red_compute_nodes = []
        for node_meta in self.dag_ir.nodes_meta:
            if isinstance(node_meta, ComputeNode):
                if isinstance(node_meta.fn, tuple):
                    # To keep the frontend simple, the reduction nodes
                    # are parsed into compute nodes by default.
                    # The simple heuristic to distinguish between compute
                    # and reduction nodes is that a compute node carries a single
                    # function, while a reduction node carries a tuple of functions
                    # for in-register reduction and atomic global memory reduction
                    red_compute_nodes.append(node_meta.name)

        # Step 2: for each compute node, merge it with the succeeding store
        for node in red_compute_nodes:
            # Verify
            users = self.dag_ir.get_users(node)
            inputs = self.dag_ir.get_all_inputs(node)
            # Has a single user and a single input
            assert len(users) == 1
            assert len(inputs) == 1
            user = users[0]
            input = inputs[0]

            user_meta = self.dag_ir.get_node_meta(user)
            # Must be a store node
            assert isinstance(user_meta, StoreNode)
            # With output degree == 0
            assert self.dag_ir.out_degree(user) == 0
            # Register the reduce op
            node_meta = self.dag_ir.get_node_meta(node)
            user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
            user_meta.element_compute = node_meta.element_compute
            user_meta.round_style = node_meta.round_style

            # Replace all uses
            self.dag_ir.remove_edge(input, node)
            input_users = self.dag_ir.get_users(input)
            for iu in input_users:
                weight = self.dag_ir.get_edge_weight(input, iu)
                self.dag_ir.add_edge(user, iu, weight)
                self.dag_ir.remove_edge(input, iu)
            self.dag_ir.add_edge(input, user)
            self.dag_ir.remove_node(node)

            # Register the reduction name
            self.dag_ir.reduction_names.append(user)
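
A toy sketch of the Step 2 edge rewiring on a bare networkx graph (node names hypothetical): the reduction compute node red is folded into its store store_red, and the other consumer of x is rerouted to read through the store:

import networkx as nx

g = nx.DiGraph()
g.add_edges_from([("x", "red"), ("red", "store_red"), ("x", "relu")])

g.remove_edge("x", "red")
for iu in list(g.successors("x")):       # reroute every other user of the input
    g.add_edge("store_red", iu)
    g.remove_edge("x", iu)
g.add_edge("x", "store_red")
g.remove_node("red")                     # the compute node disappears

assert set(g.successors("store_red")) == {"relu"}
assert set(g.successors("x")) == {"store_red"}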
@ -0,0 +1,59 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Shape and type propagation pass
"""

from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed


class PassShapeTypePropagation(EVTPassBase):
    """
    Propagate the shape and type of all nodes
    """
    dependencies = [PassPreprocessRed]

    def call(self):
        # Propagate the node shapes and types forward
        for node in self.dag_ir.nodes_topological_order():
            node_meta: NodeBase = self.dag_ir.get_node_meta(node)
            input_node_metas = self.dag_ir.get_all_inputs_meta(node)
            node_meta.type_propagation(input_node_metas)
            node_meta.shape_propagation(input_node_metas)

        # Propagate broadcast information backward
        for node in reversed(self.dag_ir.nodes_topological_order()):
            node_meta: NodeBase = self.dag_ir.get_node_meta(node)
            input_node_metas = self.dag_ir.get_all_inputs_meta(node)
            node_meta.broadcast_propagation(input_node_metas)
319
python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py
Normal file
@ -0,0 +1,319 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Compute the shared memory size in bytes
"""

from math import gcd

import cutlass_library
from pycute import flatten, shape_div, product

import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass_cppgen.backend.library import DataType, DataTypeSize


class GetSmemSize:
    """
    Get the size in bytes of the shared memory used by the kernel
    """
    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    #
    # Sm90 epilogue specific
    #

    def sm90_epilogue_tile(self, tile_description):
        # Get the epilogue tile size
        schedule = tile_description.epilogue_schedule
        if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
            element_d = self.dag_ir.get_node_meta("D").element
            nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
            epi_tile_m = min(64, tile_description.threadblock_shape[0])
            epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
            epilogue_tile_mn = (epi_tile_m, epi_tile_n)
        elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
            epi_tile_m = min(128, tile_description.threadblock_shape[0])
            epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
            epilogue_tile_mn = (epi_tile_m, epi_tile_n)
        else:
            raise NotImplementedError(f"Unsupported schedule: {schedule}")

        # Get the pipeline stages
        stages_d = 2
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
        else:
            element_c = None

        element_d = self.dag_ir.get_node_meta("D").element
        if element_c == element_d:
            reuse_smem_c = True
        else:
            reuse_smem_c = False
        stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles

        # Record the epilogue tile
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = element_c is not None
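    # Worked example (illustrative numbers, not taken from the source): with the
    # TmaWarpSpecialized schedule, threadblock_shape = [128, 128, 64], and a
    # 16-bit element D, nperf = 32, so
    #   epi_tile_m = min(64, 128)           = 64
    #   epi_tile_n = gcd(min(32, 128), 128) = 32
    # giving epilogue_tile_mn = (64, 32) and
    #   epi_tiles = (128 // 64) * (128 // 32) = 8 epilogue subtiles.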

    def sm90_or_sm100_epilogue_smem_size(self, tile_description):
        # Get the Fusion Storage
        nodes = self.dag_ir.nodes_topological_order()
        self.smem_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.smem_types[node] = meta.underlying_impl.get_smem_size(
                    self.cta_tile_mnk, self.epilogue_tile_mn,
                    self.stages_c, self.stages_d, self.epi_tiles)
            if node == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_smem_type(node)
            else:
                self.get_evt_smem_type(node)

        thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
        # Get the Tensor Storage
        tensors = []
        if self.is_source_supported:
            smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
            tensors.append((smem_C, 128))
        else:
            tensors.append((0, 1))
        if self.reuse_smem_c:
            tensors.append((0, 128))
        else:
            smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
            tensors.append((smem_D, 128))
        tensors.append((thread_smem_size, 128))

        tensor_smem_size = self.get_struct_size(tensors)
        # Get the pipeline storage size:
        # sizeof(uint64_t) * stages_c * 2, alignment of uint64_t;
        # 2 is for FullBarrier and EmptyBarrier
        pipeline_smem_size = (8 * self.stages_c * 2, 8)

        # Get the SharedStorage size
        smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
        return smem_size[0]

    def sm90_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of the sm90 collective epilogue
        """
        self.sm90_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)

    #
    # Sm100 epilogue specific
    #

    def sm100_epilogue_tile(self, tile_description):
        cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
        mma_tile = cta_tile

        if tile_description.is_2sm:
            cta_tile = (cta_tile[0] // 2, cta_tile[1])

        if tile_description.is_2sm and mma_tile[0] == 128:
            tmem_warps = (2, 2)
        else:
            tmem_warps = (4, 1)

        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
            element_c_size = DataTypeSize[element_c]
        else:
            element_c = None
            element_c_size = 0

        element_d = self.dag_ir.get_node_meta("D").element

        DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void

        CtaM = cta_tile[0]
        CtaN = cta_tile[1]
        WarpM = tmem_warps[0]
        WarpN = tmem_warps[1]
        MaxBits = max(element_c_size, DataTypeSize[element_d])
        DpFull = 32
        M = min(CtaM, DpFull * WarpM)

        if DisableSource:
            # Epilogues w/o residual load are less sensitive to smem allocation.
            # Target a fixed amount of compute per epilogue iteration
            if MaxBits == 4:
                # Make the epilogue tile larger to reduce the epilogue iterations.
                # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
                ComputeElts = 8192
                Nperf = ComputeElts // M
            else:
                ComputeElts = 4096
                Nperf = ComputeElts // M
        else:
            # Epilogues w/ residual load are more sensitive to smem allocation.
            # Target the optimal smem distribution between epilogue+mainloop based on datatype+tilesize
            if MaxBits == 32:
                Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
            elif MaxBits == 16:
                Nperf = 32 if CtaN <= 128 else 64
            else:
                Nperf = 64

        def is_m_major(layout):
            return flatten(layout.stride[0]) == 1

        if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
            N_min_C = 8 * WarpN
        elif element_c_size == 6:
            N_min_C = 128 * WarpN
        else:
            N_min_C = (128 // element_c_size) * WarpN

        if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
            N_min_D = 8 * WarpN
        elif DataTypeSize[element_d] == 6:
            N_min_D = 128 * WarpN
        else:
            N_min_D = (128 // DataTypeSize[element_d]) * WarpN

        N = min(CtaN, max(Nperf, N_min_C, N_min_D))

        tile_m = M
        tile_n_size = N // WarpN * WarpN

        epilogue_tile_mn = (tile_m, tile_n_size)
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))

        stages_d = min(epi_tiles, 2)
        reuse_smem_c = (element_c_size > 8)

        if reuse_smem_c:
            stages_c = max(min(epi_tiles, 4), stages_d + 1)
        else:
            stages_c = min(epi_tiles, 4)

        # Record the epilogue tile
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = not DisableSource

    def sm100_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of the sm100 collective epilogue
        """
        self.sm100_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)

    def __call__(self, tile_description):
        return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)

    #
    # Helper functions
    #

    @staticmethod
    def get_visitor_size(members: list, ebo: bool):
        """
        Get the size of a struct in bytes
        """
        offset = 0
        max_alignment = 1
        if len(members) > 0:
            # Get the alignment
            for _, alignment in members:
                max_alignment = max(max_alignment, alignment)

            for type_size, _ in members:
                if type_size != 0:
                    offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
                if type_size == 0 and not ebo:
                    offset += 1
                else:
                    offset += type_size
            offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
            return (offset, max_alignment)
        else:
            # Struct size is at least 1
            return (1, 1)

    def get_struct_size(self, members: list):
        """
        Get the size of a struct in bytes
        """
        return self.get_visitor_size(members, False)
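
    # Worked example of the struct layout above (mirrors C++ padding rules): for
    # members [(24, 128), (16, 128), (8, 128)], max_alignment = 128 and each
    # non-empty member is first aligned up to a multiple of 128:
    #   0 + 24 = 24; align(24) = 128, +16 = 144; align(144) = 256, +8 = 264
    # and the total rounds up to align(264) = 384, so
    #   get_visitor_size([(24, 128), (16, 128), (8, 128)], ebo=False) == (384, 128)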

    def get_evt_smem_type(self, node):
        # Sort the input nodes by edge weight
        input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
        input_types.append(self.smem_types[node])
        if len(input_types) > 1:
            ebo = len(input_types) > 4
            self.smem_types[node] = self.get_visitor_size(input_types, ebo)

    def get_dag_smem_type(self, node):
        meta = self.dag_ir.get_node_meta(node)
        subgraph = meta.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Visit the unvisited nodes in the subgraph
        for n in subgraph_nodes:
            m = subgraph.get_node_meta(n)
            if m.disabled:
                continue
            else:
                self.smem_types[n] = m.underlying_impl.get_smem_size(
                    self.cta_tile_mnk, self.epilogue_tile_mn,
                    self.stages_c, self.stages_d, self.epi_tiles)
        input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
        if len(input_types) > 0:
            ebo = len(input_types) > 4
            self.smem_types[node] = self.get_visitor_size(input_types, ebo)
46
python/cutlass_cppgen/backend/evt/passes/util.py
Normal file
@ -0,0 +1,46 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Utilities for passes
"""

# Map from the CC of the kernel to the EVT implementation that the CC targets
cc_map = {
    80: 80,
    86: 80,
    89: 80,
    90: 90,
    100: 100,
    101: 100,
    103: 100,
}
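
For example, an Sm89 (Ada) kernel reuses the Sm80 EVT node implementations; this is how the map is consumed elsewhere in this diff (see PassGetImpl.ensures):

cc = 89
assert cc_map[cc] == 80
module_name = f"sm{cc_map[cc]}_nodes"  # -> "sm80_nodes", looked up on the evt backend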
109
python/cutlass_cppgen/backend/frontend.py
Normal file
@ -0,0 +1,109 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations

from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
import numpy as np

from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor


class NumpyFrontend:
    """
    Frontend node for numpy
    """

    @staticmethod
    def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
        """Convert the input numpy tensor to a CUDA device pointer

        :param np_tensor: input numpy nd array
        :param is_output: whether the tensor is an output

        :return: CUDA device pointer
        """
        # copy the data to device
        if is_output:
            return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
        else:
            return todevice(np_tensor)


class TorchFrontend:
    """
    Frontend node for torch
    """

    @staticmethod
    def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
        """Convert the input torch tensor to a CUDA device pointer

        :param torch_tensor: input torch tensor

        :return: CUDA device pointer
        """

        # check the device of torch_tensor
        if not torch_tensor.is_cuda:
            torch_tensor = torch_tensor.to("cuda")

        return cuda.CUdeviceptr(torch_tensor.data_ptr())


class CupyFrontend:
    """
    Frontend node for cupy
    """

    @staticmethod
    def argument(cupy_ndarray: "cp.ndarray"):
        return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))


class TensorFrontend:
    """
    Universal frontend for client-provided tensors
    """

    @staticmethod
    def argument(tensor, is_output=False):
        if is_numpy_tensor(tensor):
            return NumpyFrontend.argument(tensor, is_output)
        elif is_torch_tensor(tensor):
            return TorchFrontend.argument(tensor)
        elif is_cupy_tensor(tensor):
            return CupyFrontend.argument(tensor)
        else:
            raise NotImplementedError("Unknown Tensor Type")
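
A minimal usage sketch (assumes a working CUDA environment; TensorFrontend dispatches on the tensor's library):

import numpy as np

a = np.ones((4, 4), dtype=np.float32)
in_ptr = TensorFrontend.argument(a)                   # copies to device, returns a CUdeviceptr
out_ptr = TensorFrontend.argument(a, is_output=True)  # allocates only, no host-to-device copy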
2145
python/cutlass_cppgen/backend/gemm_operation.py
Normal file
File diff suppressed because it is too large
509
python/cutlass_cppgen/backend/library.py
Normal file
@ -0,0 +1,509 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Common data types and string names/tags for them
"""

import enum

from cutlass_library import (
    ComplexTransform,
    DataType,
    DataTypeSize,
    EpilogueScheduleType,
    KernelScheduleSuffixes,
    KernelScheduleType,
    MathOperation,
    OpcodeClass,
    TileSchedulerType
)


# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility

try:
    from enum import auto as enum_auto
except ImportError:
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i


class DataTypeSizeBytes:
    """
    Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
    data type key is less than a full byte or a non-integer number of bytes.
    """

    @staticmethod
    def __class_getitem__(datatype):
        """
        Returns the size of the data type in bytes. Raises an exception if the data type
        is either less than a full byte or a non-integer number of bytes in size.

        :param datatype: data type to query

        :return: number of bytes the data type occupies
        :rtype: int
        """
        bits = DataTypeSize[datatype]
        if bits < 8:
            raise Exception(
                f"Data type {datatype} is less than one byte in size."
            )
        elif bits % 8 != 0:
            raise Exception(
                f"Data type {datatype} is not an integer number of bytes."
            )
        return bits // 8
|
||||
|
||||
|
||||
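# Example (illustrative sketch): DataTypeSize reports widths in bits, while
# DataTypeSizeBytes converts to whole bytes and rejects sub-byte types.
assert DataTypeSizeBytes[DataType.f32] == 4   # 32 bits -> 4 bytes
try:
    DataTypeSizeBytes[DataType.s4]            # 4 bits: less than one byte, raises
except Exception:
    pass
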
class SchedulerMode(enum.Enum):
    Device = enum_auto()
    Host = enum_auto()


SchedulerModeTag = {
    SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
    SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}


ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}

class FunctionalOp(enum.Enum):
    AtomicAdd = enum_auto()
    AtomicMaximum = enum_auto()
    Divides = enum_auto()
    Maximum = enum_auto()
    Minimum = enum_auto()
    Minus = enum_auto()
    Multiplies = enum_auto()
    MultiplyAdd = enum_auto()
    Plus = enum_auto()
    Exp = enum_auto()


FunctionalOpTag = {
    FunctionalOp.AtomicAdd: "cutlass::atomic_add",
    FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
    FunctionalOp.Divides: "cutlass::divides",
    FunctionalOp.Maximum: "cutlass::maximum",
    FunctionalOp.Minimum: "cutlass::minimum",
    FunctionalOp.Minus: "cutlass::minus",
    FunctionalOp.Multiplies: "cutlass::multiplies",
    FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
    FunctionalOp.Plus: "cutlass::plus",
    FunctionalOp.Exp: "cutlass::fast_exp_op",
}

class ActivationOp(enum.Enum):
    DGelu = enum_auto()
    Gelu = enum_auto()
    GeluTaylor = enum_auto()
    HardSwish = enum_auto()
    Identity = enum_auto()
    LeakyReLU = enum_auto()
    ReLU = enum_auto()
    Sigmoid = enum_auto()
    SiLU = enum_auto()
    Tanh = enum_auto()


ActivationOpTag = {
    ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
    ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
    ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
    ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
    ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
    ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
    ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
    ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
    ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
    ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}

def op_tag(op) -> str:
    """
    Dispatches `op` to the appropriate *Tag dictionary depending on whether
    `op` is an ActivationOp or FunctionalOp. This is useful for cases in which
    either type can be used.

    :param op: operation to emit a tag for
    :type op: ActivationOp | FunctionalOp

    :return: tag corresponding to op
    :rtype: str
    """
    if isinstance(op, ActivationOp):
        return ActivationOpTag[op]
    elif isinstance(op, FunctionalOp):
        return FunctionalOpTag[op]
    else:
        raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")

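# Example (illustrative sketch): op_tag accepts either enum type and returns the
# corresponding C++ tag from the dictionaries above.
assert op_tag(FunctionalOp.Plus) == "cutlass::plus"
assert op_tag(ActivationOp.ReLU) == "cutlass::epilogue::thread::ReLu"
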
class FloatRoundStyle(enum.Enum):
    ToNearest = enum_auto()
    ToNearestSatfinite = enum_auto()
    Indeterminate = enum_auto()
    TowardZero = enum_auto()
    TowardInfinity = enum_auto()
    TowardNegInfinity = enum_auto()
    HalfUlpTruncDntz = enum_auto()
    HalfUlpTruncate = enum_auto()


FloatRoundStyleTag = {
    FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
    FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
    FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
    FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
    FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
    FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
    FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
    FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
}

class MathInstruction:
    """
    Description of the lowest-level matrix-multiply-accumulate operation to be used in a kernel
    """

    def __init__(
        self,
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class=OpcodeClass.Simt,
        math_operation=MathOperation.multiply_add,
    ):
        """
        :param instruction_shape: size of the [M, N, K] dimensions of the instruction
        :type instruction_shape: list or tuple
        :param element_a: data type of operand A
        :param element_b: data type of operand B
        :param element_accumulator: data type used in accumulation
        :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
        :type opcode_class: cutlass_library.library.OpcodeClass
        :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
        :type math_operation: MathOperation
        """
        self.instruction_shape = instruction_shape
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation

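# Example (illustrative sketch; the 16x8x16 shape is the one commonly used for f16
# Tensor Core MMA, but any valid [M, N, K] instruction shape may be supplied):
mma_f16 = MathInstruction(
    instruction_shape=[16, 8, 16],
    element_a=DataType.f16,
    element_b=DataType.f16,
    element_accumulator=DataType.f32,
    opcode_class=OpcodeClass.TensorOp,
)
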
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
    blackwell_threadblock_shape = tile_description.threadblock_shape
    # "2sm" kernel schedules cooperate across a pair of SMs, doubling the effective M extent
    is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
    if cluster_shape[0] > 0:
        # Divide the cluster-wide tile evenly among the threadblocks in the cluster
        blackwell_threadblock_shape = [
            tile_description.threadblock_shape[0] // cluster_shape[0],
            tile_description.threadblock_shape[1] // cluster_shape[1],
            tile_description.threadblock_shape[2] // cluster_shape[2]
        ]
        if is_2sm:
            blackwell_threadblock_shape[0] *= 2
    else:
        # Cluster extent not statically specified: fall back to the MMA instruction shape
        blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape
    return blackwell_threadblock_shape, is_2sm

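# Worked example (illustrative sketch with a hypothetical stand-in object): a
# 256x128x64 cluster-wide tile split over a 2x1x1 cluster yields a 128x128x64
# per-threadblock tile; a "2sm" schedule would double the M extent back to 256.
from types import SimpleNamespace

_td = SimpleNamespace(threadblock_shape=[256, 128, 64],
                      math_instruction=SimpleNamespace(instruction_shape=[64, 128, 16]))
_shape, _is_2sm = to_blackwell_threadblock_shape(_td, [2, 1, 1], None)
assert _shape == [128, 128, 64] and not _is_2sm
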
class TileDescription:
    """
    Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
    stage count, and math instruction specification
    """

    def __init__(
        self,
        threadblock_shape,
        stages,
        warp_count,
        math_instruction,
        cluster_shape=[1, 1, 1],
        kernel_schedule: KernelScheduleType = None,
        epilogue_schedule: EpilogueScheduleType = None,
        tile_scheduler: TileSchedulerType = None
    ):
        """
        :param threadblock_shape: shape of a threadblock tile
        :type threadblock_shape: list or tuple
        :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None`, and the maximum
                       number of stages that can be supported for an operation on a given architecture will be computed at a later time
        :type stages: int or None
        :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
        :type warp_count: list, tuple, or None
        :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
        :type math_instruction: MathInstruction
        :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
        :param kernel_schedule: type of kernel schedule to use (only available for SM90+)
        :type kernel_schedule: cutlass_library.KernelScheduleType
        :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
        :type epilogue_schedule: cutlass_library.EpilogueScheduleType
        :param tile_scheduler: type of tile scheduler to use (only available for SM90+)
        :type tile_scheduler: cutlass_library.TileSchedulerType
        """
        if ((kernel_schedule is None and epilogue_schedule is not None) or
            (kernel_schedule is not None and epilogue_schedule is None)):
            raise Exception("Kernel and epilogue schedules must either both be set or both be None (i.e., Auto).")

        self.threadblock_shape = threadblock_shape
        self.cluster_shape = cluster_shape
        self.kernel_schedule = kernel_schedule
        self.epilogue_schedule = epilogue_schedule
        self.tile_scheduler = tile_scheduler
        self.stages = stages

        self.math_instruction = math_instruction
        self.instruction_shape = math_instruction.instruction_shape

        # Number of warps along x, y, z directions
        self.warp_count = warp_count

        self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)

    def clone_and_update(self, td: dict):
        attrs = {
            "cluster_shape": None,
            "threadblock_shape": None,
            "warp_count": None,
            "stages": None,
            "instruction_shape": None,
            "kernel_schedule": None,
            "epilogue_schedule": None,
            "tile_scheduler": None
        }
        for key in attrs.keys():
            if key in td.keys():
                attrs[key] = td[key]
            else:
                attrs[key] = getattr(self, key)

        attrs["math_instruction"] = MathInstruction(
            attrs["instruction_shape"],
            self.math_instruction.element_a,
            self.math_instruction.element_b,
            self.math_instruction.element_accumulator,
            self.math_instruction.opcode_class,
            self.math_instruction.math_operation
        )

        # Remove the instruction shape, which is now carried by the new MathInstruction
        del attrs["instruction_shape"]

        return TileDescription(**attrs)

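    # Example (illustrative sketch): derive a new TileDescription that differs only
    # in stage count, keeping every other attribute of the original `td`.
    #
    #   td_4stage = td.clone_and_update({"stages": 4})
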
    @property
    def num_threads(self):
        """
        Returns the number of threads in the threadblock

        :return: number of threads in the threadblock
        :rtype: int or None (if warp count is None)
        """
        if self.warp_count is not None:
            threads = 32
            for cnt in self.warp_count:
                threads *= cnt
            return threads
        return None

    def procedural_name(self):
        """
        Returns a name identifying the tile description

        :return: name identifying the tile description
        :rtype: str
        """
        emit_stages = 0 if self.stages is None else self.stages
        name = "%dx%dx%d_%dx%d_%dx%d" % (
            self.cluster_shape[0],
            self.cluster_shape[1],
            self.cluster_shape[2],
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            emit_stages
        )

        return name

    def procedural_name_2x(self):
        """
        Returns a name identifying the tile description for 2.x-style kernels

        :return: name identifying the tile description
        :rtype: str
        """
        return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)

    def __str__(self):
        """
        Returns a string containing each of the tile description's values

        :return: contents of tile description
        :rtype: str
        """
        if self.kernel_schedule is not None:
            kschedule = self.kernel_schedule
        else:
            kschedule = KernelScheduleType.ScheduleAuto

        if self.epilogue_schedule is not None:
            eschedule = self.epilogue_schedule
        else:
            eschedule = EpilogueScheduleType.ScheduleAuto

        if self.tile_scheduler is not None:
            tschedule = self.tile_scheduler.name
        else:
            tschedule = "None"
        return f"""
{{
  ClusterShape: {self.cluster_shape}
  ThreadblockShape: {self.threadblock_shape}
  WarpCount: {self.warp_count}
  Stages: {self.stages if self.stages is not None else 'Auto'}
  InstructionShape: {self.math_instruction.instruction_shape}
  Kernel schedule: {kschedule.name}
  Epilogue schedule: {eschedule.name}
  TileScheduler: {tschedule}
}}"""

class TensorDescription:
    def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
        self.element = element
        self.layout = layout
        if element != DataType.void:
            # Clamp the alignment so that a single access is at most 128 bits wide
            self.alignment = min(128 // DataTypeSize[self.element], alignment)
        else:
            self.alignment = alignment
        self.complex_transform = complex_transform

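# Example (illustrative sketch): for an f16 tensor, 128-bit accesses permit at most
# eight elements per access, so a requested alignment of 16 is clamped to 8.
from cutlass_library import LayoutType

_desc = TensorDescription(DataType.f16, LayoutType.RowMajor, alignment=16)
assert _desc.alignment == 8   # min(128 // 16, 16)
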
def CalculateSmemUsagePerStage(operation):
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param operation: operation for which the shared memory usage should be computed. Any stage count
                      set via the `operation.tile_description.stages` parameter is ignored
                      in the present calculation
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int
    """
    m, n, k = operation.tile_description.threadblock_shape

    if operation.operation_kind == OperationKind.Gemm:
        stage_barrier_bytes = 32
        # One stage holds an M x K tile of A and a K x N tile of B, plus barrier storage
        return (
            (DataTypeSize[operation.A.element] * m * k // 8)
            + (DataTypeSize[operation.B.element] * k * n // 8)
            + stage_barrier_bytes
        )
    else:
        raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))

def CalculateSmemUsage(operation):
    """
    Returns the amount of shared memory in bytes consumed by a kernel.

    :param operation: operation for which the shared memory usage should be computed
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed by the operation
    :rtype: int
    """
    return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)

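# Worked example (illustrative arithmetic, not executable against a real operation):
# for an f16 GEMM with a 128x128x32 threadblock tile, one stage holds
#   A: 16 bits * 128 * 32 // 8 = 8192 bytes
#   B: 16 bits * 32 * 128 // 8 = 8192 bytes
# plus 32 barrier bytes, i.e. 16416 bytes per stage; 3 stages consume 49248 bytes.
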
class ApiVersion(enum.Enum):
    """
    Differentiate between CUTLASS 2.x and 3.x API versions
    """

    v2x = enum_auto()
    v3x = enum_auto()

def api_version(arch, opclass, dtype):
    """
    Returns whether the architecture, opcode class, and data type in question require using the CUTLASS 2.x
    or 3.x API for code emission.

    :param arch: compute capability of device on which to run
    :type arch: int
    :param opclass: class of the operation being performed
    :type opclass: cutlass_library.OpcodeClass
    :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
    :type dtype: cutlass_library.DataType

    :return: API version to be used in code emission
    :rtype: ApiVersion
    """
    if (arch in [90, 100, 101, 103] and
        opclass == OpcodeClass.TensorOp and
        (dtype != DataType.f64)):
        return ApiVersion.v3x
    else:
        return ApiVersion.v2x

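# Example (illustrative sketch): SM90 Tensor Core f16 kernels emit through the 3.x API,
# while an SM80 kernel falls back to the 2.x API.
assert api_version(90, OpcodeClass.TensorOp, DataType.f16) == ApiVersion.v3x
assert api_version(80, OpcodeClass.TensorOp, DataType.f16) == ApiVersion.v2x
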
class EmissionType(enum.Enum):
    """
    Tags for whether to emit a kernel- or device-level operation
    """

    Kernel = enum_auto()
    Device = enum_auto()
121 python/cutlass_cppgen/backend/memory_manager.py Normal file
@ -0,0 +1,121 @@
# (BSD-3-Clause license header omitted)

import numpy as np

import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.utils.lazy_import import lazy_import

if cutlass_cppgen.use_rmm:
    import rmm
else:
    cudart = lazy_import("cuda.cudart")

class PoolMemoryManager:
    def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
        self.pool = rmm.mr.PoolMemoryResource(
            rmm.mr.CudaMemoryResource(),
            initial_pool_size=init_pool_size,
            maximum_pool_size=max_pool_size
        )
        # Track allocations through the pool and make it the default RMM resource
        self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
        rmm.mr.set_current_device_resource(self.mr)

    def pool_size(self):
        return self.pool.pool_size()

class DevicePtrWrapper:
    """
    Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
    (at least in terms of the interface used by the CUTLASS Python interface)
    """
    def __init__(self, dev_ptr):
        self.dev_ptr = dev_ptr

    @property
    def ptr(self):
        return self.dev_ptr

def _todevice(host_data):
    """
    Helper for transferring host data to device memory
    """
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer.to_device(host_data.tobytes())
    else:
        nbytes = len(host_data.tobytes())
        dev_ptr_wrapper = device_mem_alloc(nbytes)
        err, = cudart.cudaMemcpy(
            dev_ptr_wrapper.ptr,
            host_data.__array_interface__['data'][0],
            nbytes,
            cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
        )
        if err != cudart.cudaError_t.cudaSuccess:
            raise Exception(f"cudaMemcpy failed with error {err}")
        return dev_ptr_wrapper


def todevice(host_data, dtype=np.float32):
    """
    Transfers `host_data` (a list or NumPy array) to device memory
    """
    if isinstance(host_data, list):
        return _todevice(np.array(host_data, dtype=dtype))
    elif is_numpy_tensor(host_data):
        return _todevice(host_data)

def device_mem_alloc(size):
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer(size=size)
    else:
        err, ptr = cudart.cudaMalloc(size)
        if err != cudart.cudaError_t.cudaSuccess:
            raise Exception(f"cudaMalloc failed with error {err}")
        return DevicePtrWrapper(ptr)


def align_size(size, alignment=256):
    # Round `size` up to the nearest multiple of `alignment`
    return ((size + alignment - 1) // alignment) * alignment


def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
    if cutlass_cppgen.use_rmm:
        memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
        return memory_pool
    else:
        return None
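# Example (illustrative sketch; the todevice call assumes a CUDA device and either
# RMM or the cuda-python bindings are available):
import numpy as np

_buf = todevice([1.0, 2.0, 3.0], dtype=np.float32)   # DeviceBuffer or DevicePtrWrapper
assert align_size(1000) == 1024                      # next multiple of 256
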
140 python/cutlass_cppgen/backend/operation.py Normal file
@ -0,0 +1,140 @@
# (BSD-3-Clause license header omitted)

import ctypes

from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")

from cutlass_cppgen.backend.utils.device import device_cc

_supports_cluster_launch = None


def supports_cluster_launch():
    global _supports_cluster_launch
    if _supports_cluster_launch is None:
        # Cluster launch requires an SM90+ device and CUDA Python bindings >= 11.8
        from cuda import __version__
        version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
        major, minor = version_splits[0], version_splits[1]
        _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
    return _supports_cluster_launch

class LaunchConfiguration:
    def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0):
        self.grid = grid
        self.block = block
        self.shared_memory_capacity = smem

class ExecutableOperation:
    def __init__(self, operation):
        self.operation = operation
        self.module = None
        self.kernel = None

    def name(self):
        return self.operation.procedural_name()

    def emit(self):
        return ""

    def can_implement(self, configuration, arguments):
        raise NotImplementedError()

    def get_host_workspace_size(self, arguments):
        raise NotImplementedError()

    def get_device_workspace_size(self, arguments):
        raise NotImplementedError()

    def plan(self, arguments):
        raise NotImplementedError()

    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
        raise NotImplementedError()

    def run_with_clusters(self, launch_config, kernel_params, stream=None):
        if not stream:
            stream = cuda.CUstream(0)
        if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
            attr = cuda.CUlaunchAttribute()
            attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
            attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
            attrs = [attr]

            # Allow for non-portable cluster sizes
            err, = cuda.cuFuncSetAttribute(
                self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
            if err != cuda.CUresult.CUDA_SUCCESS:
                return err
        else:
            attrs = []

        config = cuda.CUlaunchConfig()
        config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
        config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
        config.sharedMemBytes = launch_config.shared_memory_capacity
        config.hStream = stream
        config.attrs = attrs
        config.numAttrs = len(attrs)

        err, = cuda.cuLaunchKernelEx(
            config, f=self.kernel, kernelParams=kernel_params, extra=0)
        return err

    def run_without_clusters(self, launch_config, kernel_params, stream=None):
        if not stream:
            stream = cuda.CUstream(0)
        err, = cuda.cuLaunchKernel(
            self.kernel,
            launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
            launch_config.block[0], launch_config.block[1], launch_config.block[2],
            launch_config.shared_memory_capacity,
            stream,
            kernel_params,
            0)

        return err

    def run(self, host_workspace, device_workspace, launch_config, stream=None):
        if not stream:
            stream = cuda.CUstream(0)
        # Pack the host-side argument buffer into the kernelParams format expected by the driver
        cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
        packed = (ctypes.c_void_p * 1)()
        packed[0] = ctypes.addressof(cArg)

        if supports_cluster_launch():
            return self.run_with_clusters(launch_config, packed, stream)
        else:
            return self.run_without_clusters(launch_config, packed, stream)
455 python/cutlass_cppgen/backend/reduction_operation.py Normal file
@ -0,0 +1,455 @@
# (BSD-3-Clause license header omitted)

from __future__ import annotations

import ctypes
from typing import Union

from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np

from cutlass_library import (
    DataTypeNames,
    DataTypeSize,
    DataTypeTag,
    LayoutType,
    SubstituteTemplate
)

import cutlass_cppgen
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.library import TensorDescription
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.shape import MatrixCoord
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor


# Forward declaration so that the type annotations below can reference the class,
# which is defined in full later in this file
class ReductionOperation:
    pass

class ReductionArguments:
    """
    Arguments of reduction
    """

    def __init__(
        self,
        operation: ReductionOperation,
        problem_size: "list[int]",
        partitions: int,
        workspace: cuda.CUdeviceptr,
        destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        **kwargs,
    ) -> None:
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        if "bias" in kwargs.keys():
            self.bias = kwargs["bias"]
        else:
            # by default, tensor_C is not bias
            self.bias = False
        if "stream" in kwargs.keys():
            self.stream = kwargs["stream"]
        else:
            self.stream = cuda.CUstream(0)

        self.operation = operation
        self.ptr_workspace = workspace

        # number of split-k partitions
        self.partitions = partitions

        if is_numpy_tensor(destination):
            self.host_D = destination
            self.destination_buffer = NumpyFrontend.argument(destination, True)
            self.source_buffer = NumpyFrontend.argument(source, False)
            self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
            self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
        elif is_torch_tensor(destination):
            self.ptr_destination = TorchFrontend.argument(destination)
            self.ptr_source = TorchFrontend.argument(source)
        elif isinstance(destination, cuda.CUdeviceptr):
            self.ptr_destination = destination
            self.ptr_source = source
        else:
            raise TypeError(f"Unsupported destination type {type(destination)}")

        self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])

        self.partition_stride = (
            problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
        )

        if "output_op" in kwargs.keys():
            self.output_op = kwargs["output_op"]
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        self.get_arguments()

    @staticmethod
    def get_tensor_ref(
        extent: "tuple[int]",
        device_ptr: cuda.CUdeviceptr,
        layout: LayoutType,
    ):
        if layout == LayoutType.RowMajor:
            return TensorRef2D_(int(device_ptr), extent[1])
        else:
            raise ValueError(f"Unknown layout type {layout}")

    def get_arguments(self):
        ref_workspace = ReductionArguments.get_tensor_ref(
            extent=[
                self.problem_size.row,
                self.problem_size.column,
            ],
            device_ptr=self.ptr_workspace,
            layout=LayoutType.RowMajor,
        )
        if self.bias:
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[0, 0],
                device_ptr=self.ptr_source,
                layout=LayoutType.RowMajor,
            )
        else:
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[
                    self.problem_size.row,
                    self.problem_size.column,
                ],
                device_ptr=self.ptr_source,
                layout=LayoutType.RowMajor,
            )

        ref_destination = ReductionArguments.get_tensor_ref(
            extent=[
                self.problem_size.row,
                self.problem_size.column,
            ],
            device_ptr=self.ptr_destination,
            layout=LayoutType.RowMajor,
        )

        self.c_arguments = self.operation.argument_type(
            self.problem_size,
            self.partitions,
            self.partition_stride,
            ref_workspace,
            ref_destination,
            ref_source,
            self.output_op,
        )

        params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
        self.host_workspace = bytearray(params_.contents)

    def sync(self):
        (err,) = cudart.cudaDeviceSynchronize()
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"CUDA Error {str(err)}")

        if hasattr(self, "host_D"):
            (err,) = cuda.cuMemcpyDtoH(
                self.host_D,
                self.ptr_destination,
                self.host_D.size * self.host_D.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error {str(err)}")

        self.free()

    def free(self):
        """
        Frees allocated device-side memory
        """
        # Free any device memory allocated manually
        if not cutlass_cppgen.use_rmm:
            for attr in ["destination_buffer", "source_buffer"]:
                if hasattr(self, attr):
                    buf = getattr(self, attr)
                    if isinstance(buf, DevicePtrWrapper):
                        err, = cudart.cudaFree(buf.ptr)
                        if err != cudart.cudaError_t.cudaSuccess:
                            raise RuntimeError(f"cudaFree failed with error {err}")
                    del buf

class ReductionRT(ExecutableOperation):
    """
    ReductionRT manages the CUTLASS runtime components for reduction
    """

    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

  ${operation_name}${operation_suffix} op;

  op(params, *shared_storage);
}
"""
    HostTemplate = r"""
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
      output[i] = bytes[i];

    return output;
  }
}
"""

    def __init__(self, operation: ReductionOperation):
        super().__init__(operation)

        self.operation: ReductionOperation = operation
        self.emitter = EmitReductionInstance("_type")

        self.elements_per_access = self.operation.count
        (
            self.argument_type,
            self.epilogue_type,
        ) = get_reduction_params(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type)]

    def emit(self):
        return self.emitter.emit(self.operation)

    def plan(self, arguments: ReductionArguments):
        # Each thread reduces `elements_per_access` elements along a row
        block_shape = [
            self.operation.shape.column // self.elements_per_access,
            self.operation.shape.row,
            1,
        ]
        # Ceiling-divide the problem extent by the threadblock reduction shape
        grid_shape = [
            (arguments.problem_size.row + self.operation.shape.row - 1)
            // self.operation.shape.row,
            (arguments.problem_size.column + self.operation.shape.column - 1)
            // self.operation.shape.column,
            1,
        ]
        return LaunchConfiguration(
            grid_shape,
            block_shape,
            self.shared_memory_capacity,
        )

    def initialize(self):
        (err,) = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            value=self.shared_memory_capacity,
        )
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {err}")

class ReductionOperation:
    """
    CUTLASS reduction Operation
    """

    def __init__(
        self,
        shape: MatrixCoord,
        C: TensorDescription,
        element_accumulator,
        element_workspace=None,
        element_compute=None,
        epilogue_functor=None,
        count: int = 1,
        partitions_per_stage: int = 4,
    ) -> None:
        self.shape = shape
        self.epilogue_functor = epilogue_functor
        self.element_accumulator = element_accumulator

        if element_workspace is None:
            self.element_workspace = element_accumulator
        else:
            self.element_workspace = element_workspace

        if element_compute is None:
            self.element_compute = element_accumulator
        else:
            self.element_compute = element_compute

        self.element_output = C.element
        self.C: TensorDescription = C

        # Reduce op processing size
        self.count: int = count

        # Number of partitions to reduce per stage
        self.partitions_per_stage: int = partitions_per_stage

        self.rt_module: ReductionRT = ReductionRT(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def extended_name(self):
        extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"

        return SubstituteTemplate(
            extend_name,
            {
                "element_workspace": DataTypeNames[self.element_workspace],
                "element_accumulator": DataTypeNames[self.element_accumulator],
                "element_compute": DataTypeNames[self.element_compute],
                "element_output": DataTypeNames[self.element_output],
            },
        )

    def configuration_name(self):
        """The full procedural name indicates architecture, extended name, and tile size"""

        configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"

        threadblock = "%dx%d" % (
            self.shape.row,
            self.shape.column,
        )

        return SubstituteTemplate(
            configuration_name,
            {
                "extended_name": self.extended_name(),
                "threadblock": threadblock,
            },
        )

    def procedural_name(self):
        """The full procedural name indicates architecture, extended name, and tile size"""
        return self.configuration_name()

    def run(self, arguments: ReductionArguments) -> cuda.CUresult:
        """
        Configure and launch the CUDA kernel with input arguments
        """
        launch_config = self.rt_module.plan(arguments)

        host_workspace = arguments.host_workspace
        device_workspace = None

        err = self.rt_module.run(
            host_workspace,
            device_workspace,
            launch_config,
            arguments.stream
        )

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

        return err

class EmitReductionInstance:
    def __init__(self, operation_suffix="") -> None:
        self.operation_suffix = operation_suffix
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/device/gemm.h",
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
            "cutlass/reduction/kernel/reduce_split_k.h",
            "cutlass/reduction/thread/reduction_operators.h",
        ]
        self.template = """
// Reduction kernel instance
using ${operation_name}_base =
    typename cutlass::reduction::kernel::ReduceSplitK<
        cutlass::MatrixShape<${shape_row}, ${shape_column}>,
        ${epilogue_functor},
        cutlass::reduction::thread::ReduceAdd<
            ${element_accumulator},
            ${element_output},
            ${count}>,
        ${partition_per_stage}>;

struct ${operation_name}${operation_suffix}:
    public ${operation_name}_base { };
"""

    def emit(self, operation: ReductionOperation):
        vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
        epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]

        values = {
            "operation_name": operation.configuration_name(),
            "operation_suffix": self.operation_suffix,
            "shape_row": str(operation.shape.row),
            "shape_column": str(operation.shape.column),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "element_output": DataTypeTag[operation.element_output],
            "epilogue_vector_length": str(epilogue_vector_length),
            "element_accumulator": DataTypeTag[operation.element_accumulator],
            "element_compute": DataTypeTag[operation.element_compute],
            "element_workspace": DataTypeTag[operation.element_workspace],
            "count": str(operation.count),
            "partition_per_stage": str(operation.partitions_per_stage),
        }

        return SubstituteTemplate(self.template, values)
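# Example (illustrative sketch): SubstituteTemplate performs simple ${key} substitution,
# which is how configuration_name() and EmitReductionInstance.emit() build C++ names.
# The substituted values below are hypothetical.
from cutlass_library import SubstituteTemplate

_name = SubstituteTemplate(
    "cutlass_reduce_split_k_${extended_name}_${threadblock}",
    {"extended_name": "f32_f32_f32_f16", "threadblock": "4x64"},
)
assert _name == "cutlass_reduce_split_k_f32_f32_f32_f16_4x64"
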
35 python/cutlass_cppgen/backend/type_hint.py Normal file
@ -0,0 +1,35 @@
# (BSD-3-Clause license header omitted)

# String-valued forward references used in type annotations, so that torch, cupy,
# and the CUDA bindings need not be imported at annotation time
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"

Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"
33 python/cutlass_cppgen/backend/utils/__init__.py Normal file
@ -0,0 +1,33 @@
# (BSD-3-Clause license header omitted)

from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc
126 python/cutlass_cppgen/backend/utils/device.py Normal file
@ -0,0 +1,126 @@
# (BSD-3-Clause license header omitted)

"""
|
||||
Utility functions for interacting with the device
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
|
||||
|
||||
def check_cuda_errors(result: list):
    """
    Checks whether `result` contains a CUDA error and, if so, raises it as an exception. Otherwise,
    returns the result contained in the remaining fields of `result`.

    :param result: the results of the `cudart` method, consisting of an error code and any method results
    :type result: list

    :return: non-error-code results from the `result` parameter
    """
    # `result` is of the format: (cudaError_t, result...)
    err = result[0]
    if err.value:
        raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))

    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]

def device_cc(device: int = -1) -> int:
    """
    Returns the compute capability of the device with ID `device`.

    :param device: ID of the device to query
    :type device: int

    :return: compute capability of the queried device (e.g., 80 for SM80)
    :rtype: int
    """
    if device == -1:
        device = cutlass_cppgen.device_id()

    deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
    major = str(deviceProp.major)
    minor = str(deviceProp.minor)
    return int(major + minor)

def device_sm_count(device: int = -1):
    if device == -1:
        device = cutlass_cppgen.device_id()
    err, device_sm_count = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
    )
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(
            "Failed to retrieve SM count. "
            f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
        )

    return device_sm_count

def to_device_ptr(tensor) -> cuda.CUdeviceptr:
    """
    Converts a tensor to a CUdeviceptr

    :param tensor: tensor to convert
    :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int

    :return: device pointer
    :rtype: cuda.CUdeviceptr
    """
    if is_numpy_tensor(tensor):
        ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
    elif is_torch_tensor(tensor):
        ptr = cuda.CUdeviceptr(tensor.data_ptr())
    elif is_cupy_tensor(tensor):
        ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
    elif isinstance(tensor, cuda.CUdeviceptr):
        ptr = tensor
    elif isinstance(tensor, int):
        ptr = cuda.CUdeviceptr(tensor)
    else:
        raise NotImplementedError(tensor)

    return ptr
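# Example (illustrative sketch; assumes torch with a CUDA tensor is available):
# to_device_ptr gives a uniform CUdeviceptr view over the supported tensor types.
#
#   import torch
#   t = torch.empty(16, device="cuda")
#   ptr = to_device_ptr(t)           # cuda.CUdeviceptr wrapping t.data_ptr()
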
33 python/cutlass_cppgen/emit/__init__.py Normal file
@ -0,0 +1,33 @@
# (BSD-3-Clause license header omitted)

from cutlass_cppgen.emit.pytorch import pytorch
267 python/cutlass_cppgen/emit/common.py Normal file
@ -0,0 +1,267 @@
# (BSD-3-Clause license header omitted)

"""
|
||||
Common utilities for emitting CUTLASS kernels
|
||||
"""
|
||||
|
||||
import cutlass_cppgen
|
||||
|
||||
# Strings used for printing information about the generation of emitted scripts
|
||||
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
|
||||
|
||||
|
||||
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
|
||||
"""
|
||||
|
||||
|
||||
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_ARGS_2x = """
|
||||
typename DeviceKernel::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K}, // problem size
|
||||
1,
|
||||
{alpha, beta},
|
||||
A, B, C, D,
|
||||
0, 0, 0, 0, // batch strides
|
||||
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
||||
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
|
||||
};
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
|
||||
typename DeviceKernel::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K}, // problem size
|
||||
1,
|
||||
{alpha, beta},
|
||||
A, B, C, D,
|
||||
0, 0, 0, 0, // batch strides
|
||||
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
||||
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
||||
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
|
||||
-1 // avail_sms
|
||||
};
|
||||
"""
|
||||
|
||||
_CUTLASS_KERNEL_RUN_GEMM_2x = """
|
||||
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
cutlass::Status ${name}_kernel_run(int M, int N, int K,
|
||||
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
||||
ElementCompute alpha, ElementCompute beta) {
|
||||
${args}
|
||||
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
DeviceKernel gemm_op;
|
||||
cutlass::Status status = gemm_op.initialize(arguments,
|
||||
workspace.get(),
|
||||
nullptr); // CUDA stream
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = gemm_op();
|
||||
return status;
|
||||
}
|
||||
"""
_CUTLASS_KERNEL_RUN_GEMM_3x = """
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
using StrideD = typename DeviceKernel::GemmKernel::StrideD;

using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

cutlass::Status ${name}_kernel_run(
    int M, int N, int K, int L,
    const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
    ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {

  typename DeviceKernel::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K, L},  // problem size
    {
      A,  // ptrA
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)),  // stride A
      B,  // ptrB
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)),  // stride B
    },
    {
      {alpha, beta},
      C,  // ptrC
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)),  // stride C
      D,  // ptrD
      cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)),  // stride D
    },
    hw_info
  };

  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.run(arguments,
                                       workspace.get(),
                                       nullptr);  // CUDA stream

  return status;
}
"""


_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

int threadblock_count = DeviceKernel::sufficient();

cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
                                   DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
                                   int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
                                   ElementCompute alpha, ElementCompute beta) {

  typename DeviceKernel::Arguments arguments {
    problem_sizes,
    problem_count,
    threadblock_count,
    {alpha, beta},
    A, B, C, D,
    lda, ldb, ldc, ldd
  };

  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.initialize(arguments,
                                              workspace.get(),
                                              nullptr);  // CUDA stream

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = gemm_op();
  return status;
}
"""
_CUTLASS_KERNEL_RUN_CONV2D_2x = """

using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
namespace {
using TensorRefA = typename UnderlyingKernel::TensorRefA;
using TensorRefB = typename UnderlyingKernel::TensorRefB;
using TensorRefC = typename UnderlyingKernel::TensorRefC;
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
}

template<typename TensorRef, typename Element>
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr) {
  cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
  TensorRef tensor_ref(ptr, layout);
  return tensor_ref;
}

cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
                                   UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
                                   UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
                                   ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
                                   cudaStream_t stream, int device_id=0) {
  // create the tensor references
  cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );
  cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );
  cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );

  TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
  TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
  TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
  TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);

  cutlass::conv::SplitKMode mode;
  if (split_k_mode == "serial") {
    mode = cutlass::conv::SplitKMode::kSerial;
  } else if (split_k_mode == "parallel") {
    mode = cutlass::conv::SplitKMode::kParallel;
  } else {
    throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
  }

  typename DeviceKernel::Arguments arguments{
    *problem_size,
    tensor_ref_A,
    tensor_ref_B,
    tensor_ref_C,
    tensor_ref_D,
    {alpha, beta},
    mode
  };

  DeviceKernel implicit_gemm_op;

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);

  void* workspace_ptr = device_memory_allocation(workspace_size, device_id);

  cutlass::Status status = implicit_gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  //
  // Launch initialized CUTLASS kernel
  //
  status = implicit_gemm_op(stream);

  return status;
}
"""
936
python/cutlass_cppgen/emit/pytorch.py
Normal file
@@ -0,0 +1,936 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Utilities for generating source for building a PyTorch CUDA extension that uses a CUTLASS kernel.
If desired, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.

Example usage with JIT compilation:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
    op = plan.construct()
    mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)

    # Generate inputs for the GEMM
    A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]

    # Run the module
    D = mod.run(A, B, C)


Example usage without JIT compilation:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
    op = plan.construct()
    cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')

After this call, the directory ``output`` contains ``setup.py``,
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.

The module can later be used in Python via:

.. highlight:: python
.. code-block:: python

    import torch
    import cutlass_gemm

    # Generate inputs for the GEMM
    A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]

    # Run the module
    D = cutlass_gemm.run(A, B, C)
"""
import logging
import os

from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate

from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.emit import common
from cutlass_cppgen.utils.datatypes import is_torch_available

if is_torch_available():
    import torch


_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
  if (size > 0) {
    torch::Device device(torch::kCUDA, device_id);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
    at::Tensor device_tensor = torch::empty({(long)size,}, options);
    return reinterpret_cast<void*>(device_tensor.data_ptr());
  } else {
    return nullptr;
  }
}

${includes}
${declaration}
${impl}
"""

_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
  return ${name}_kernel(A, B, C, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""

_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
  return ${name}_kernel(A, B, C, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
        py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""

_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(
  const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
  const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1) {
  return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
        py::overload_cast<
          const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
          std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
            &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
        py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(0, 0), py::arg("dilation") = std::make_tuple(1, 1),
        py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
        py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""

_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(
  std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
  std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1) {
  return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
        py::overload_cast<
          std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
          std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
            &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
        py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(0, 0), py::arg("dilation") = std::make_tuple(1, 1),
        py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
        py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""

_PYTORCH_GEMM_INCLUDES = {
    ApiVersion.v2x: """
#include "cutlass/gemm/device/gemm_universal.h"
""",
    ApiVersion.v3x: """
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
""",
}

_PYTORCH_GROUPED_GEMM_INCLUDES = """
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
"""

_PYTORCH_CONV2D_INCLUDES = """
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
"""

_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::kI8",
    DataType.s32: "torch::kI32",
    DataType.bf16: "torch::kBFloat16",
}
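This mapping supplies the ``${torch_type_C}`` placeholder used by the implementation templates below, so the emitted C++ allocates the output tensor with the dtype of operand C. A small illustrative check (the f16 choice is just an example):

    # For a kernel whose C operand is FP16, the emitted allocation becomes
    # `B.new_empty({M, N}, torch::kF16)` after substitution.
    torch_type_C = _CUTLASS_TYPE_TO_TORCH_TYPE[DataType.f16]
    assert torch_type_C == "torch::kF16"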
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_GEMM_2x
    + """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
  int M = A.size(0);
  int N = B.size(1);
  int K = A.size(1);

  typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                          nullptr :
                                          reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
  at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

  cutlass::Status status = ${name}_kernel_run(M, N, K,
                                              reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                              reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                              ptrC,
                                              reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                              ElementCompute(alpha), ElementCompute(beta));

  TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
  return D;
}
"""
)

_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
    common._CUTLASS_KERNEL_RUN_GEMM_3x
    + """
bool hw_info_queried = false;
cutlass::KernelHardwareInfo hw_info;

at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
  int M = A.size(0);
  int N = B.size(1);
  int K = A.size(1);
  int L = 1;

  // Query hardware info if we haven't already
  if (!hw_info_queried) {
    hw_info.device_id = 0;
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    hw_info_queried = true;
  }

  typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                          nullptr :
                                          reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
  at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

  cutlass::Status status = ${name}_kernel_run(M, N, K, L,
                                              reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                              reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                              ptrC,
                                              reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                              ElementCompute(alpha), ElementCompute(beta),
                                              hw_info);

  TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
  return D;
}
"""
)


_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
    common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
    + """
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
  size_t num = A.size();

  // To avoid performing many small cudaMallocs and host-to-device copies,
  // we serialize the grouped GEMM arguments on the host, allocate one
  // large chunk of device memory, and perform a single cudaMemcpy to
  // copy the host data to the device. Allocation overheads could be
  // avoided by using a memory pool.

  // Calculate the total size of the data to be copied from host to device
  size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
                      sizeof(DeviceKernel::ElementA*) +
                      sizeof(DeviceKernel::ElementB*) +
                      sizeof(DeviceKernel::ElementC*) +
                      sizeof(DeviceKernel::ElementC*) +
                      sizeof(int64_t) +
                      sizeof(int64_t) +
                      sizeof(int64_t);
  total_size *= num;

  // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
  // of sizeof(DeviceKernel::ElementA*) (which will be 8 on a 64-bit system).
  // To ensure that we don't end up having misaligned loads in the kernel,
  // we pad to the nearest multiple of 8.
  //
  // Note that, even on a 32-bit system (for which sizeof(X*) will not equal
  // sizeof(int64_t)), only padding between the list of GemmCoords and the
  // list of ptr_As is sufficient because the set of four equal-length lists of pointers
  // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
  // start on a multiple of 8.
  int64_t padding = 8 - (total_size % 8);
  total_size += padding;

  uint8_t* host_data = new uint8_t[total_size];
  cutlass::DeviceAllocation<uint8_t> device_data(total_size);

  uint8_t* start = host_data;
  cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);

  // Apply the padding after the list of GemmCoords
  start += num * sizeof(cutlass::gemm::GemmCoord) + padding;

  int64_t ptr_A_offset = start - host_data;
  DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
  start += num * sizeof(DeviceKernel::ElementA*);

  int64_t ptr_B_offset = start - host_data;
  DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
  start += num * sizeof(DeviceKernel::ElementB*);

  int64_t ptr_C_offset = start - host_data;
  DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
  start += num * sizeof(DeviceKernel::ElementC*);

  int64_t ptr_D_offset = start - host_data;
  DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
  start += num * sizeof(DeviceKernel::ElementC*);

  int64_t lda_offset = start - host_data;
  int64_t* lda_host = reinterpret_cast<int64_t*>(start);
  start += num * sizeof(int64_t);

  int64_t ldb_offset = start - host_data;
  int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
  start += num * sizeof(int64_t);

  int64_t ldc_offset = start - host_data;
  int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
  start += num * sizeof(int64_t);

  std::vector<at::Tensor> D(num);

  bool need_C = (C != at::nullopt) && (beta != 0.f);
  for (size_t i = 0; i < num; ++i) {
    int M = A[i].size(0);
    int N = B[i].size(1);
    int K = A[i].size(1);
    *(problem_sizes_host + i) = {M, N, K};
    *(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
    *(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());

    if (need_C) {
      *(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
    }
    else {
      *(ptr_C_host + i) = nullptr;
    }

    D[i] = B[i].new_empty({M, N}, ${torch_type_C});
    *(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());

    *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
    *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
    *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
  }

  device_data.copy_from_host(host_data);

  cutlass::Status status = ${name}_kernel_run(
    num,
    reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
    reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
    reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
    reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
    reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
    reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
    reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
    reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
    reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
    ElementCompute(alpha), ElementCompute(beta));

  delete[] host_data;

  TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
  return D;
}
"""
)
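The alignment bookkeeping in the template above is easy to check on the host side. A hedged Python sketch of the same arithmetic, where the pointer and coordinate sizes are assumptions for a typical 64-bit build:

    # Mirror of the C++ size/padding computation for `num` grouped problems,
    # assuming 8-byte pointers, int64 leading dimensions, and a 12-byte GemmCoord.
    def grouped_buffer_layout(num, sizeof_coord=12, sizeof_ptr=8, sizeof_ld=8):
        total = num * (sizeof_coord + 4 * sizeof_ptr + 3 * sizeof_ld)
        padding = 8 - (total % 8)  # same expression as in the template
        total += padding
        ptr_a_offset = num * sizeof_coord + padding
        return total, ptr_a_offset

    total, ptr_a_offset = grouped_buffer_layout(num=3)
    assert ptr_a_offset % 8 == 0  # the ptr_A list starts on an 8-byte boundary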
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  cutlass::Status status = ${name}_kernel_run(
    &problem_size,
    reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
    reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
    ptrC,
    reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
    alpha, beta,
    split_k_mode, stream, B.device().index());

  TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
  return D;
}
"""

_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
                          std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
                          float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
  int N, H, W, C_, K, R, S, P, Q;
  N = A.size(0);
  C_ = A.size(1);
  H = A.size(2);
  W = A.size(3);

  K = B.size(0);
  R = B.size(2);
  S = B.size(3);

  cutlass::conv::Conv2dProblemSize problem_size(
    cutlass::Tensor4DCoord(N, H, W, C_),
    cutlass::Tensor4DCoord(K, R, S, C_),
    cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
    cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
    cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
    cutlass::conv::Mode::kCrossCorrelation,
    split_k_slices
  );

  P = problem_size.P;
  Q = problem_size.Q;

  typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                              nullptr :
                                              reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

  torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
  at::Tensor D = torch::zeros({N, K, P, Q}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)


_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
                          std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
                          std::string split_k_mode="serial", int split_k_slices=1) {
  int N, H, W, C_, K, R, S;
  N = std::get<0>(input_size);
  C_ = std::get<1>(input_size);
  H = std::get<2>(input_size);
  W = std::get<3>(input_size);

  K = B.size(0);
  R = B.size(2);
  S = B.size(3);

  cutlass::conv::Conv2dProblemSize problem_size(
    cutlass::Tensor4DCoord(N, H, W, C_),
    cutlass::Tensor4DCoord(K, R, S, C_),
    cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
    cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
    cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
    cutlass::conv::Mode::kCrossCorrelation,
    split_k_slices
  );

  typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                              nullptr :
                                              reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

  torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
  at::Tensor D = torch::empty({N, C_, H, W}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)


_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_CONV2D_2x
    + """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
                          std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
                          std::string split_k_mode="serial", int split_k_slices=1) {
  int N, H, W, C_, K, R, S;
  K = std::get<0>(weight_size);
  C_ = std::get<1>(weight_size);
  R = std::get<2>(weight_size);
  S = std::get<3>(weight_size);

  N = B.size(0);
  H = B.size(2);
  W = B.size(3);

  cutlass::conv::Conv2dProblemSize problem_size(
    cutlass::Tensor4DCoord(N, H, W, C_),
    cutlass::Tensor4DCoord(K, R, S, C_),
    cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
    cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
    cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
    cutlass::conv::Mode::kCrossCorrelation,
    split_k_slices
  );

  typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
                                              nullptr :
                                              reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());

  torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
  at::Tensor D = torch::empty({K, C_, R, S}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='${name}',
    ext_modules=[
        CUDAExtension('${name}', [
            '${name}.cpp',
            '${name}_kernel.cu',
        ],
        include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
        extra_compile_args={
            'cxx': ['-std=c++17'],
            'nvcc': ['-std=c++17', ${extra_compile_args}],
        },
        libraries=['cuda']
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
"""


def _generate_setup(name: str, sourcedir: str, extra_compile_args: str = ""):
    """
    Generates a setup.py file for the extension

    :param name: name of the module to generate
    :type name: str
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str
    :param extra_compile_args: additional compiler arguments to embed in setup.py
    :type extra_compile_args: str
    """
    setup_py_file = os.path.join(sourcedir, "setup.py")
    setup_source = SubstituteTemplate(
        _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
    )
    with open(setup_py_file, "w") as outfile:
        outfile.write(setup_source)
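For illustration, a hedged sketch of a call and what it produces; the module name and ``output`` directory are assumptions, and the extra flag mirrors the one ``_pytorch_gemm`` passes for sm_90a-class targets:

    # Writes output/setup.py, wired to build output/cutlass_gemm.cpp and
    # output/cutlass_gemm_kernel.cu against the CUTLASS headers.
    _generate_setup(
        name="cutlass_gemm",
        sourcedir="output",
        extra_compile_args="'--generate-code=arch=compute_90a,code=[sm_90a]'",
    )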
class _ArchListSetter:
    """
    Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
    environment variable when building a PyTorch CUDA module.

    ``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch
    CUDA module should be compiled.

    For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
    ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
    compilation of the module.

    This utility wraps the building of a PyTorch CUDA module with a setting of this environment
    variable according to the current compute capability being targeted.

    Example usage:

    .. highlight:: python
    .. code-block:: python

        # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
        with _ArchListSetter(80):
            # Perform JIT compilation and loading of the module
            mod = torch.utils.cpp_extension.load(...)

    :param cc: compute capability
    :type cc: int
    """

    _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"

    def __init__(self, cc: int):
        # Format the integer compute capability as "major.minor"
        # (e.g., 80 -> "8.0", 100 -> "10.0")
        self.cc_str = f"{cc // 10}.{cc % 10}"

    def __enter__(self):
        """
        Saves the old value of TORCH_CUDA_ARCH_LIST and resets it to the new value based on ``cc``
        """
        self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
        os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str

        return self

    def __exit__(self, exc_type, exc_val, traceback):
        """
        Restores the old value of TORCH_CUDA_ARCH_LIST
        """
        if self.old_arch_list is None:
            del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
        else:
            os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list


def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
    """
    JIT compiles and loads a PyTorch CUDA extension.

    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param cpp_file: path to file containing extension's C++ interface
    :type cpp_file: str
    :param cuda_file: path to file containing extension's CUDA interface
    :type cuda_file: str

    :return: loaded PyTorch module
    """

    from torch.utils.cpp_extension import load

    extra_cuda_cflags = ["-std=c++17"]
    if cc in [90, 100, 101, 103]:
        # PyTorch does not currently add the architecture-accelerated target (e.g., sm_90a)
        # when compute capability 9.0 or 10.x is set within TORCH_CUDA_ARCH_LIST.
        # Thus, we manually add the sm_{cc}a target.
        extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")

    with _ArchListSetter(cc):
        jitmodule = load(
            name,
            [cpp_file, cuda_file],
            extra_cuda_cflags=extra_cuda_cflags,
            extra_include_paths=[
                os.path.join(CUTLASS_PATH, "include"),
                os.path.join(CUTLASS_PATH, "tools/util/include"),
            ],
            extra_ldflags=["-lcuda"],
            verbose=(logger.level == logging.DEBUG)
        )
    return jitmodule
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    if op.api == ApiVersion.v3x:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
    else:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
        if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
        else:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GEMM_INCLUDES[op.api],
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        _PYTORCH_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    extra_compile_args = ""
    if cc in [90, 100, 101, 103]:
        extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
    _generate_setup(name, sourcedir, extra_compile_args)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
def _pytorch_grouped_gemm(
    op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if op.api != ApiVersion.v2x:
        raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")

    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
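A hedged sketch of driving a module produced by this function with ``jit=True``; the shapes and the ``grouped_mod`` name are illustrative, and each problem in the group may have its own shape:

    import torch

    # grouped_mod: the module returned by _pytorch_grouped_gemm(..., jit=True).
    # run() takes equal-length lists of A and B operands; C is optional.
    As = [torch.ones((128, 64)).to('cuda'), torch.ones((256, 32)).to('cuda')]
    Bs = [torch.ones((64, 96)).to('cuda'), torch.ones((32, 48)).to('cuda')]
    Ds = grouped_mod.run(As, Bs)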
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that when the conv kind is ``dgrad`` or ``wgrad``, the size of the input ``(N, C, H, W)`` or
    weight ``(K, C, R, S)`` must be provided, because multiple valid values of H/W/R/S can
    produce the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)
    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    if op.conv_kind == ConvKind.Fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
    extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_CONV2D_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        cpp_template,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    The result of this method is files within ``sourcedir`` that can be used for building
    a PyTorch module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module (if ``jit=True``) or None
    """
    device_op = op.device_op()
    if isinstance(op, GemmOperationUniversal):
        return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
    elif isinstance(op, GemmOperationGrouped):
        return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
    elif isinstance(op, Conv2dOperation):
        return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
    else:
        raise Exception(
            f"Operation type {type(op)} is not currently supported for PyTorch emission."
        )
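Tying the pieces together, a hedged end-to-end sketch mirroring the module docstring above; the module name, target architecture, and source directory are illustrative choices:

    import torch
    import cutlass_cppgen

    # Build a GEMM plan, emit extension sources into ./output, then
    # compile and load them just-in-time for an SM80 device.
    plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
    op = plan.construct()
    mod = cutlass_cppgen.emit.pytorch(op, "cutlass_gemm", 80, jit=True, sourcedir="output")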
56
python/cutlass_cppgen/epilogue/__init__.py
Normal file
@@ -0,0 +1,56 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from cutlass_cppgen.epilogue.epilogue import (
    get_activations,
    get_activation_epilogue,
    gelu,
    hardswish,
    identity,
    leaky_relu,
    relu,
    sigmoid,
    silu,
    tanh,
    trace
)

from cutlass_cppgen.epilogue.evt_ops import (
    max,
    multiply_add,
    sum,
    permute,
    reshape,
    maximum,
    minimum,
    exp
)
176
python/cutlass_cppgen/epilogue/epilogue.py
Normal file
@@ -0,0 +1,176 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Registry of elementwise epilogues

Elementwise epilogues can be added to many CUTLASS kernels in the CUTLASS Python interface via
code like the following for GEMM:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
    plan.activation = cutlass_cppgen.epilogue.relu
"""

from cutlass_cppgen.backend import epilogue, device_cc


gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
leaky_relu = epilogue.leaky_relu
relu = epilogue.relu
sigmoid = epilogue.sigmoid
silu = epilogue.silu
tanh = epilogue.tanh


_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]


def get_activations() -> list:
    """
    Returns a list of available activation functions

    :return: list of available activation functions
    :rtype: list
    """
    return _activations


def get_activation_epilogue(
    activation,
    element_output,
    elements_per_access,
    element_accumulator,
    element_compute,
):
    """
    Return an epilogue corresponding to the activation function, data types, and alignment
    used in the kernel

    :param activation: elementwise activation function to use
    :param element_output: data type of the output
    :param elements_per_access: alignment of operand C of the kernel
    :type elements_per_access: int
    :param element_accumulator: data type of the accumulated output C
    :param element_compute: data type in which compute operations should be performed

    :return: epilogue functor
    """
    if activation not in _activations:
        raise Exception(
            f"Unsupported activation type {activation}. Available activations are: {_activations}"
        )

    if activation == identity:
        return epilogue.LinearCombination(
            element_output, elements_per_access, element_accumulator, element_compute
        )
    else:
        return epilogue.LinearCombinationGeneric(
            activation,
            element_output,
            elements_per_access,
            element_accumulator,
            element_compute,
        )
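A hedged usage sketch; the data-type and alignment choices below are illustrative, not prescribed by this file:

    import cutlass_cppgen

    # ReLU epilogue for an FP16-output kernel whose C operand is 8-element
    # aligned, accumulating and computing in FP32.
    epilogue_functor = get_activation_epilogue(
        activation=relu,
        element_output=cutlass_cppgen.DataType.f16,
        elements_per_access=8,
        element_accumulator=cutlass_cppgen.DataType.f32,
        element_compute=cutlass_cppgen.DataType.f32,
    )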
|
||||
|
||||
|
||||
"""
|
||||
Frontend for EVT that generates epilogue functor through tracing the input function
|
||||
"""
|
||||
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
||||
|
||||
|
||||
def trace(fn, example_tensors, **kwargs):
    """
    Trace `fn(**example_tensors)` and generate an epilogue visitor

    :param fn: Python callable or string source of the epilogue function
    :param example_tensors: example inputs for fn
    :type example_tensors: dict

    .. highlight:: python
    .. code-block:: python

        import cutlass_cppgen.backend.evt

        # Define epilogue function as Python callable
        def example_fn(accum, C, alpha, beta, gamma):
            D = ((accum + C) * alpha - gamma) / beta
            return D

        # Define the example tensors
        example_inputs = {
            "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "alpha": 1.5,
            "beta": 0.5,
            "gamma": 2.5,
            "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
        }

        # Generate the epilogue functor
        epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
    """
    if callable(fn):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

        setattr(EpilogueFunctor, "__call__", staticmethod(fn))

        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    elif isinstance(fn, str):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                self.source = textwrap.dedent(fn)
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

            def parse(self, example_inputs) -> None:
                self.example_inputs = example_inputs
                self.ast = ast.parse(self.source)
                self.visit(self.ast)

        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    else:
        raise NotImplementedError("Expected a Python callable or a string containing the epilogue function")
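
# A minimal sketch of the string frontend above (hypothetical shapes; assumes
# PyTorch and a CUDA device are available). Passing source text instead of a
# callable exercises the `isinstance(fn, str)` branch:
#
#   import torch
#   import cutlass_cppgen
#
#   source = '''
#   def evt_fn(accum, alpha):
#       D = accum * alpha
#       return D
#   '''
#   example_inputs = {
#       "accum": torch.empty(128, 128, dtype=torch.float16, device="cuda"),
#       "alpha": 2.0,
#       "D": torch.empty(128, 128, dtype=torch.float16, device="cuda"),
#   }
#   visitor = cutlass_cppgen.epilogue.trace(source, example_inputs)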
98
python/cutlass_cppgen/epilogue/evt_ops.py
Normal file
@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Collection of builtin functions used for host reference in EVT
"""

import numpy as np

from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor

if is_torch_available():
    import torch


def multiply_add(x, y, z):
    return x * y + z


def sum(x, dim):
    if is_numpy_tensor(x):
        return x.sum(axis=tuple(dim))
    elif is_torch_tensor(x):
        return torch.sum(x, dim)


def max(x, dim):
    if is_numpy_tensor(x):
        return x.max(axis=tuple(dim))
    elif is_torch_tensor(x):
        return torch.amax(x, dim)


def maximum(x, y):
    if is_numpy_tensor(x):
        return np.maximum(x, y)
    elif is_torch_tensor(x):
        return torch.maximum(x, torch.tensor(y))


def minimum(x, y):
    if is_numpy_tensor(x):
        return np.minimum(x, y)
    elif is_torch_tensor(x):
        return torch.minimum(x, torch.tensor(y))


def exp(x):
    if is_numpy_tensor(x):
        return np.exp(x)
    elif is_torch_tensor(x):
        return torch.exp(x)


##############################################################################
# Layout manipulation nodes
##############################################################################

def permute(x, indices: tuple):
    if is_numpy_tensor(x):
        return np.transpose(x, axes=indices)
    elif is_torch_tensor(x):
        return x.permute(*indices)


def reshape(x, new_shape: tuple):
    if is_numpy_tensor(x):
        return np.reshape(x, newshape=new_shape)
    elif is_torch_tensor(x):
        return x.view(new_shape)
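
# These helpers mirror the device-side EVT operators on the host. A small
# dispatch sketch (assumes NumPy and that this module is importable as
# cutlass_cppgen.epilogue.evt_ops; note that `sum` here is this module's
# reduction, not Python's builtin):
#
#   import numpy as np
#   from cutlass_cppgen.epilogue import evt_ops
#
#   x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
#   evt_ops.sum(x, dim=[1])        # reduce over axis 1 -> shape (2, 4)
#   evt_ops.permute(x, (2, 0, 1))  # transpose axes -> shape (4, 2, 3)
#   evt_ops.reshape(x, (6, 4))     # -> shape (6, 4)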
569
python/cutlass_cppgen/library_defaults.py
Normal file
@ -0,0 +1,569 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Classes containing valid operations for a given compute capability and data types.
"""

from itertools import combinations_with_replacement
import logging

import cutlass_library
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode

import cutlass_cppgen
from cutlass_cppgen.utils.check import valid_stage_count
from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op


_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]


class KernelsForDataType:
    """
    Container class for keeping track of kernels that correspond to a particular combination
    of data types for operands A, B, and accumulator
    """

    def __init__(self, datatype_comb: tuple, layout_comb: tuple):
        self.datatype_comb = datatype_comb
        self.layout_comb = layout_comb
        self.math_operations = set()

        # Dictionary mapping from an alignment key string ("alignA alignB alignC") to a list
        # of kernels that fit the alignment constraint for the data type combination
        self.kernels_by_alignment = {}

    def add(self, operation):
        """
        Add an operation to the list of supported kernels
        """
        alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
        if alignment_key not in self.kernels_by_alignment:
            self.kernels_by_alignment[alignment_key] = []
        self.kernels_by_alignment[alignment_key].append(operation)
        self.math_operations.add(operation.tile_description.math_instruction.math_operation)

    def alignments(self, operand: str):
        """
        Returns an unsorted list of alignments supported by this data type combination

        :param operand: identifier of operand in question (e.g., A, B, C)
        :type operand: str

        :return: unsorted list of alignments supported by this data type combination
        :rtype: list
        """
        operand_idx = self._operand_idx(operand)
        return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()]

    @property
    def all_operations(self):
        """
        Returns a list of all operations supported by this data type combination

        :return: list of all operations supported by this data type combination
        :rtype: list
        """
        ops = []
        for _, alignment_ops in self.kernels_by_alignment.items():
            ops.extend(alignment_ops)
        return ops

    def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
        key = sorted(list(self.kernels_by_alignment.keys()))[0]
        kernels = self.kernels_by_alignment[key]
        if math_operation is not None:
            kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
        return kernels[0]

    def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
        """
        Returns operations satisfying the alignment constraints

        :param alignment_A: alignment constraint on operand A of operations to return
        :type alignment_A: int
        :param alignment_B: alignment constraint on operand B of operations to return
        :type alignment_B: int
        :param alignment_C: alignment constraint on operand C of operations to return
        :type alignment_C: int
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: list of operations
        :rtype: list
        """
        key = f"{alignment_A} {alignment_B} {alignment_C}"

        if key not in self.kernels_by_alignment:
            og_key = key
            # Reconcile A, B, and C alignments by trying to align to the minimum
            min_alignment = min(alignment_A, alignment_B, alignment_C)
            key = f"{min_alignment} {min_alignment} {min_alignment}"
            if key not in self.kernels_by_alignment:
                # Finally, go through all available alignment combinations and find
                # one for which all values divide those passed in.
                key = None
                alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
                for align_A, align_B, align_C in alignments:
                    if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
                        key = f"{align_A} {align_B} {align_C}"
                        break

            if key is None:
                raise Exception(
                    f"No operations of alignment {og_key} found for data type and layout "
                    f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
                    f"are {self.kernels_by_alignment.keys()}"
                )

        ops = self.kernels_by_alignment[key]
        if math_operation is not None:
            ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
        return ops
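
    # Illustration of the alignment-key fallback above as a standalone sketch
    # (hypothetical registry keys; mirrors the logic, not a public API):
    #
    #   keys = {"8 8 8", "4 4 4", "1 1 1"}
    #   request = (2, 2, 2)   # "2 2 2" absent; minimum-alignment key also absent
    #   candidates = sorted((tuple(map(int, k.split())) for k in keys), reverse=True)
    #   match = next(
    #       (c for c in candidates if all(r % a == 0 for r, a in zip(request, c))),
    #       None,
    #   )                     # -> (1, 1, 1)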

    def _operand_idx(self, key: str) -> int:
        operand_list = ["A", "B", "C"]
        if key not in operand_list:
            raise Exception(f"Unexpected operand {key}")

        return operand_list.index(key)

    def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand: str) -> int:
        """
        Returns the most preferable alignment for a given shape and layout

        :param shape: extent of each dimension of the tensor
        :type shape: tuple
        :param layout: layout of the tensor
        :type layout: cutlass_cppgen.LayoutType
        :param operand: descriptor of the operand in question
        :type operand: str

        :return: maximum alignment supported by the data type combination and tensor size
        :rtype: int
        """
        operand_idx = self._operand_idx(operand)

        # Determine the leading dimension of the shape
        if layout == cutlass_cppgen.LayoutType.ColumnMajor:
            ld = shape[-2]
        elif layout == cutlass_cppgen.LayoutType.RowMajor:
            ld = shape[-1]
        elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
            ld = shape[-1]
        else:
            raise Exception(f"Unexpected or unsupported layout {layout}")

        for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True):
            alignment = int(alignments.split(" ")[operand_idx])
            if ld % alignment == 0:
                return alignment

        # Default to alignment of 1 if no others match
        return 1
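
    # Worked example of the leading-dimension check above (hypothetical shapes):
    # a row-major tensor of shape (128, 96) has leading dimension 96, so with
    # registered keys {"8 8 8", "4 4 4", "2 2 2"} the scan returns 8 for any
    # operand (96 % 8 == 0), while shape (128, 100) falls through to 4.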

    def sort(self):
        """
        Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
        """
        key = lambda op: (
            op.tile_description.threadblock_shape[0]
            * op.tile_description.threadblock_shape[1]
            * op.tile_description.threadblock_shape[2]
        )
        for alignment in self.kernels_by_alignment.keys():
            self.kernels_by_alignment[alignment].sort(key=key, reverse=True)

    def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
        """
        Returns whether `math_operation` is supported by at least one operation.

        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether math_operation is supported by at least one operation
        :rtype: bool
        """
        return math_operation is None or math_operation in self.math_operations


class ArchOptions:
    """
    Structure for keeping track of kernels available on a given compute capability

    :param target_cc: compute capability of the device on which kernels will be run
    :type target_cc: int
    :param kernel_cc: compute capability of the kernels to generate
    :type kernel_cc: int
    :param operation_kind: type of operation to register
    :type operation_kind: cutlass_library.OperationKind
    :param gemm_kinds: types of GEMM operations that can be included
    :type gemm_kinds: list
    :param allowed_math_operations: types of primitive math operations allowed
    :type allowed_math_operations: list
    """

    def __init__(
        self,
        target_cc: int,
        kernel_cc: int,
        operation_kind: cutlass_library.OperationKind,
        gemm_kinds: list,
        allowed_math_operations: list = [
            cutlass_library.MathOperation.multiply_add,
            cutlass_library.MathOperation.multiply_add_saturate,
            cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
            cutlass_library.MathOperation.multiply_add_fast_f32
        ]
    ):
        self.cc = kernel_cc

        # Dictionary with the following structure:
        #   Key: OpcodeClass
        #   Value: Dictionary with the following structure:
        #      Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType)),
        #           representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
        #      Value: KernelsForDataType
        self.operations_by_opclass = {}
        self.op_class = None
        self.allowed_math_operations = allowed_math_operations

        if (target_cc == 100 and kernel_cc == 90) or (target_cc == 90 and kernel_cc == 100):
            return

        # Identify the method within the CUTLASS generator script that generates kernel
        # descriptions for the target CC
        generate_function_name = "GenerateSM" + str(kernel_cc)
        if not hasattr(cutlass_library.generator, generate_function_name):
            cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
            return
        generate_function = getattr(cutlass_library.generator, generate_function_name)

        # Initialize a default manifest and populate it with valid kernel descriptions
        # for the target CC
        args = [
            "--kernels=all",
            f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
        ]
        manifest_args = cutlass_library.generator.define_parser().parse_args(args)
        manifest = cutlass_library.manifest.Manifest(manifest_args)
        generate_function(manifest, cutlass_cppgen._nvcc_version)

        if operation_kind not in manifest.operations:
            # No kernels were generated for this architecture; this could be because the CUDA
            # toolkit is insufficient to support operations in this CC
            cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
            return

        # Only one CC should be returned, given the setup above of calling only the generation scripts
        # for a given CC
        if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
            raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
                            "is sufficient for the architecture in question.")

        # Iterate through the available operations for this operation kind and
        # find available opclasses and data types
        for name, op_list in manifest.operations[operation_kind][kernel_cc].items():
            for op in op_list:

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    if op.gemm_kind not in gemm_kinds:
                        continue

                mi = op.tile_description.math_instruction
                if mi.math_operation not in self.allowed_math_operations:
                    continue

                # Prune operations that don't fit in shared memory
                td = td_from_profiler_op(op)
                if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
                    continue

                if mi.opcode_class not in self.operations_by_opclass:
                    self.operations_by_opclass[mi.opcode_class] = {}

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
                layout_comb = (op.A.layout, op.B.layout)

                # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
                if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
                    # TF32 kernels only supported on SM80 and beyond
                    if self.cc < 80:
                        continue
                    elif self.cc == 90 or self.cc == 100:
                        if (op.A.element != cutlass_library.DataType.f32
                            or op.B.element != cutlass_library.DataType.f32
                            or op.C.element != cutlass_library.DataType.f32):
                            continue

                    datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)

                opclass_dict = self.operations_by_opclass[mi.opcode_class]
                key = (datatype_comb, layout_comb)
                if key not in opclass_dict:
                    opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
                opclass_dict[key].add(op)

        # Set the default opclass to TensorOp, if available. Otherwise default to SIMT
        if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
            self.op_class = cutlass_library.OpcodeClass.TensorOp
        else:
            self.op_class = cutlass_library.OpcodeClass.Simt

        # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
        # Here, we generate additional versions via a generic TileDescription.
        if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
            self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}

        if operation_kind == cutlass_library.OperationKind.Gemm:
            types = [
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            # Add FP8 A/B/C
            fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
            for type_comb in combinations_with_replacement(fp8_types, 3):
                types.append(type_comb)

            # Add FP8 A/B with FP32 C
            for type_comb in combinations_with_replacement(fp8_types, 2):
                types.append(type_comb + (cutlass_cppgen.DataType.f32,))

            layouts = [
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
            ]
        elif operation_kind == cutlass_library.OperationKind.Conv2d:
            types = [
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            layouts = [
                (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
            ]
        else:
            raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")

        alignment = 1
        epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
        swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8
        for type_comb in types:
            for layout_comb in layouts:
                comb = (type_comb, layout_comb)
                if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
                    continue

                A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
                B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
                C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
                math_inst = cutlass_library.MathInstruction(
                    [1, 1, 1],
                    type_comb[0],
                    type_comb[1],
                    type_comb[2],
                    cutlass_library.OpcodeClass.Simt,
                    cutlass_library.MathOperation.multiply_add
                )

                td = cutlass_library.TileDescription(
                    [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)

                # Prune operations that don't fit in shared memory
                if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
                    continue

                new_kernels = KernelsForDataType(type_comb, layout_comb)

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    new_operation = cutlass_library.manifest.GemmOperation(
                        cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
                        td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
                    new_kernels.add(new_operation)
                elif operation_kind == cutlass_library.OperationKind.Conv2d:
                    for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
                        new_operation = cutlass_library.manifest.Conv2dOperation(
                            conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
                            A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
                            group_mode=GroupMode.SingleGroup
                        )
                        new_kernels.add(new_operation)

                self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels

        # Sort all operations
        for oc in self.operations_by_opclass.keys():
            for comb in self.operations_by_opclass[oc].keys():
                self.operations_by_opclass[oc][comb].sort()

    def opclass_supports_combination(
        self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
    ) -> bool:
        """
        Returns whether the provided operation class supports the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
        :type datatype_comb: tuple[cutlass_library.DataType]
        :param layout_comb: tuple of layouts for (layout_A, layout_B)
        :type layout_comb: tuple[cutlass_library.LayoutType]
        :param math_operation: math operation to consider, or None if any can be considered
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether the operation class supports the provided data type and layout combination
        :rtype: bool
        """
        if op_class not in self.operations_by_opclass:
            raise Exception(f"Unexpected or unsupported operation class {op_class}")

        if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
            if math_operation is not None:
                return operations.supports_math_operation(math_operation)
            else:
                return True

        return False

    def supporting_opclasses(
        self,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> set:
        """
        Returns a set of operation classes that support the provided data type combination

        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: set of operation classes that support the provided data type combination
        :rtype: set
        """
        supporting_op_classes = set()
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)

        for op_class in self.operations_by_opclass.keys():
            if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
                supporting_op_classes.add(op_class)
        return supporting_op_classes

    def operations(
        self,
        op_class: cutlass_library.OpcodeClass,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> KernelsForDataType:
        """
        Returns the kernels supported by the provided operation class for the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: container of kernels by alignment supported by the provided combination of parameters
        :rtype: KernelsForDataType
        """
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)
        if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
            raise Exception(
                f"Data type and layout combination {datatype_comb}, {layout_comb} "
                f"is not supported by opcode class {op_class} on CC {self.cc}."
            )
        return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]


class OptionRegistry:
    """
    Container of all architecture-specific options

    :param target_cc: compute capability of the device on which operations will be run
    :type target_cc: int
    """

    def __init__(self, target_cc: int):
        self.registry = {}

        if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
            raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")

        gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
        operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
        # Construct options for each CC
        for kernel_cc in _generator_ccs:
            self.registry[kernel_cc] = {}
            for opkind in operation_kinds:
                self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)

    def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
        return self.registry.get(cc, None)[op_kind]
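
# A usage sketch (assumes an SM80 device and a CUDA toolkit recent enough to
# generate SM80 kernels):
#
#   registry = OptionRegistry(target_cc=80)
#   gemm_opts = registry.options_for_cc(80)   # ArchOptions for GEMM on SM80
#   gemm_opts.op_class                        # e.g., OpcodeClass.TensorOp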
36
python/cutlass_cppgen/op/__init__.py
Normal file
@ -0,0 +1,36 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
from cutlass_cppgen.op.op import OperationBase
997
python/cutlass_cppgen/op/conv.py
Normal file
@ -0,0 +1,997 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
|
||||
Ease-of-use interface for constructing, compiling, and running CONVs
|
||||
|
||||
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
|
||||
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
|
||||
Under the hood, the interface will select sensible default parameters for the many template
|
||||
parameters for CUTLASS CONVs.
|
||||
|
||||
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
||||
performance, one should specify and tune each configuration parameter.
|
||||
|
||||
The simplest example of using this interface is the following:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# A, B, C, and D are torch/numpy/cupy tensor objects
|
||||
plan = cutlass_cppgen.op.Conv(A, B, C, D)
|
||||
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
One can also use the interface by specifying data types of operands at construction
|
||||
and using different tensor objects with these data types at runtime:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# The following is shorthand for:
|
||||
# cutlass_cppgen.op.Conv2d(kind="fprop",
|
||||
# element_A=torch.float32, element_B=torch.float32,
|
||||
# element_C=torch.float32, element_D=torch.float32,
|
||||
# element_accumulator=torch.float32)
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
|
||||
|
||||
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
|
||||
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
||||
D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda')
|
||||
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
A = torch.rand((32, 128), dtype=torch.float32, device='cuda')
|
||||
B = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
C = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
||||
D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda')
|
||||
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
||||
kernel from its execution:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
Elementwise activation functions are easily fused to the GEMM via the interface:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan.activation = cutlass_cppgen.epilogue.relu
|
||||
|
||||
Operations can also be run asynchronously:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
args = plan.run()
|
||||
|
||||
# Do other work...
|
||||
|
||||
args.sync()
|
||||
"""

from __future__ import annotations
from typing import Optional
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
from cutlass_library import (
    ConvKind,
    ConvMode,
    DataTypeSize,
    IteratorAlgorithm,
    OperationKind,
    SplitKMode,
    StrideSupport,
)

import cutlass_cppgen
from cutlass_cppgen import epilogue
from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
from cutlass_cppgen.utils import check, datatypes


class Conv2d(OperationBase):
    """
    Constructs a ``Conv2d`` object.

    The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
    along with the data type of output D and that used for accumulation, are bound to the ``Conv2d``
    object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. The following
    constructors are equivalent:

    .. highlight:: python
    .. code-block:: python

        # Use F32 for A, B, C, D, and accumulation in fprop

        # Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
        Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)

        # Explicitly specify the data types to use for A, B, C, and D.
        Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
               element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)

        # Set the data types and elements from existing tensors. Note that one can use different tensors when
        # executing the Conv2d via the ``run()`` method than passed in here (though those passed in to ``run()`` must
        # have the same data type as those passed in here).
        # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
        Conv2d(kind="fprop", A=A, B=B, C=C, D=D)

        # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
        # those passed in via the generic ``element``
        Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
               element=cutlass_cppgen.DataType.f32)

    The order of precedence for the setting of the data type for a given operand/output is as follows:
    1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
    2) Otherwise, if the data type (e.g., ``element_A``) is specified, use it
    3) Otherwise, use the generic value (e.g., ``element``)

    :param kind: the convolution kind (i.e., fprop, wgrad, or dgrad)
    :type kind: str
    :param A: tensor representing data type of operand A
    :param B: tensor representing data type of operand B
    :param C: tensor representing data type of operand C
    :param D: tensor representing data type of operand D
    :param alpha: scalar parameter alpha from the convolution computation that scales the product of operands A and B
    :param beta: scalar parameter beta from the convolution operation that scales operand C
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass_cppgen.DataType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass_cppgen.DataType
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
    :type element_accumulator: cutlass_cppgen.DataType
    :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
    :type cc: int
    :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
    :type kernel_cc: int
    """
    def __init__(
        self, kind="fprop",
        A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
        element=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None
    ):
        super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
        # Verify the kernel cc
        if self.current_cc in [90, 100, 101, 103]:
            # Conv2d kernels for SM90 and newer architectures are currently unsupported.
            # Revert to using SM80-tagged kernels
            cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
            self.specified_kernel_cc = 80
            self._reset_options(80)

        # The arch is used in testing
        self.arch = self.current_cc
        self.name = "conv2d" + kind

        # The convolution kind (concept: cutlass_library.library.ConvKind)
        self.conv_kind = datatypes.getattr_enum(ConvKind, kind)

        # The element types (concept: cutlass library types) of A, B, C, and D
        elements = []
        layouts = []

        # Complete the data types based on user-provided arguments
        for elt, tens, name in zip([element_A, element_B, element_C, element_D],
                                   [A, B, C, D],
                                   ["A", "B", "C", "D"]):
            if elt is not None and tens is not None:
                raise Exception(f'Must not specify both element_{name} and tensor {name}')
            if elt is None and tens is None and element is None:
                raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')

            elt_to_set = None
            lay_to_set = None

            if tens is not None:
                elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
            else:
                elt_to_set = elt if elt is not None else element

            assert elt_to_set is not None

            # Currently we only support layout TensorNHWC
            lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(lay_to_set)

        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

        if element_accumulator is None:
            self._element_accumulator = self._element_c
        else:
            self._element_accumulator = datatypes.library_type(element_accumulator)

        # Default inputs if none are supplied in run()
        self.A = A
        self.B = B
        self.C = C
        self.D = D

        self.alpha = alpha
        self.beta = beta

        # We only specify the stride of the swizzling functor here.
        # The actual swizzling functor is determined in run() based on conv_kind and stride
        self._swizzling_stride = 1

        # Arguments that will be set to default values in _reset_operations.
        # The default tile_description and op_class are fetched from the manifest of the CUTLASS library
        self._tile_description = None
        self.op_class = None
        # The default identity epilogue will be created
        self.epilogue_functor = None

        self._reset_operations()

        # Arguments that will be determined online when run() is called,
        # based on stride, input/output channels, alignment, and conv_kind
        self._iterator_algorithm = None
        self._stride_support = None

    def _reset_operations(self, reset_epilogue: bool = True):
        # Set the default op class
        datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
        layout_comb = (self._layout_a, self._layout_b)

        self.possible_op_classes = self.options.supporting_opclasses(
            self._element_a, self._element_b, self._element_accumulator,
            self._layout_a, self._layout_b, self._math_operation
        )

        if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
            self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
        elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
            self.opclass = cutlass_cppgen.OpcodeClass.Simt
        else:
            if self._math_operation is not None:
                math_op_str = f' and math operation {self._math_operation}'
            else:
                math_op_str = ''

            raise Exception(f'No kernel configuration found for supported data type and layout '
                            f'combination {datatype_comb}x{layout_comb}{math_op_str}')

        if reset_epilogue:
            self._reset_epilogue_functor_activation(epilogue.identity)

        self.alignment_pref_A = min(
            128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
        self.alignment_pref_B = min(
            128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
        self.alignment_pref_C = min(
            128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C")))

    #
    # Tile description related
    #

    @property
    def tile_description(self) -> TileDescription:
        """
        Returns the tile description
        """
        return self._tile_description

    @tile_description.setter
    def tile_description(
        self, td=None):
        """
        Set the tile description

        :param td: tile description
        :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
            {
                "threadblock_shape": [int, int, int],
                "warp_count": [int, int, int],
                "stages": int,
                "instruction_shape": [int, int, int] (optional),
                "cluster_shape": [int, int, int] (optional)
            }
        """
        if td is None:
            return
        if isinstance(td, dict):
            if self._tile_description is None:
                op = self.possible_operations.default_operation(self._math_operation)
                self._tile_description = datatypes.td_from_profiler_op(op)
            if "cluster_shape" in td.keys():
                if td["cluster_shape"] != [1, 1, 1]:
                    cutlass_cppgen.logger.warning("Conv2d currently only supports 'cluster_shape'=[1, 1, 1].")
                    td["cluster_shape"] = [1, 1, 1]
            td = self._tile_description.clone_and_update(td)

        valid, msg = self._valid_tile_description(td)
        if valid:
            self._tile_description = td
        else:
            raise Exception(msg)

    def _valid_tile_description(self, td: TileDescription) -> tuple:
        """
        Checks whether the provided tile description is valid for the given compute capability. At present,
        this checks the following:

        - Does the tile description use a number of stages supported by the compute capability in question?
        - Does the tile size requested fit within shared memory?
        - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
          more non-unit cluster dimensions for pre-SM90 architectures)?
        - Is the kernel schedule being used supported on the architecture in question?

        :param td: tile description to validate
        :type td: cutlass_cppgen.backend.TileDescription
        :return: tuple in which the first element is a bool indicating whether the tile description is valid
                 and the second element is a string providing an optional error message.
        :rtype: tuple
        """
        valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
        if not valid:
            return (valid, msg)

        valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
        if not valid:
            return (valid, msg)

        return valid, msg

    def tile_descriptions(self) -> list:
        """
        Returns a list of valid tile descriptions for the operations

        :returns: list of valid tile descriptions for the operations
        :rtype: list
        """
        descriptions = []
        description_str = []
        for op in self.possible_operations.all_operations:
            td = datatypes.td_from_profiler_op(op)

            if self._math_operation is not None:
                if td.math_instruction.math_operation != self._math_operation:
                    continue

            if str(td) not in description_str:
                description_str.append(str(td))
                descriptions.append(td)
        return descriptions

    #
    # Swizzling functor related
    #

    @property
    def swizzling_stride(self):
        """
        Returns the stride of the swizzling currently being used by the Conv2d

        :return: swizzling stride
        """
        return self._swizzling_stride

    @swizzling_stride.setter
    def swizzling_stride(self, stride: int):
        """
        Sets the swizzling stride used when selecting the swizzling functor
        """
        if not isinstance(stride, int):
            raise Exception(f"Expected an integer (1, 2, 4, or 8), got {stride}")
        self._swizzling_stride = stride

    def _propose_swizzling_functor(self, stride):
        """
        Automatically propose the swizzling functor based on the stride
        """
        if self.conv_kind == ConvKind.Dgrad:
            if stride[0] != 1 or stride[1] != 1:
                return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")

        return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")

    #
    # Iterator algorithm related
    #

    @property
    def iterator_algorithm(self) -> IteratorAlgorithm:
        """
        Returns the iterator algorithm
        """
        return self._iterator_algorithm

    @iterator_algorithm.setter
    def iterator_algorithm(self, alg: str):
        """
        Sets the iterator algorithm

        :param alg: the iterator algorithm
        :type alg: str, one of "analytic", "optimized", "few_channels", or "fixed_channels"
        """
        iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)

        # Check if the iterator algorithm is valid
        if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
            raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")

        self._iterator_algorithm = iterator_alg

    def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
        """
        Propose a valid iterator algorithm based on problem size and alignment
        """
        if self.conv_kind == ConvKind.Fprop:
            # Check whether the fixed-channels algorithm is applicable
            if problem_size.C == alignment_a:
                return IteratorAlgorithm.FixedChannels
            elif (problem_size.C % alignment_a == 0 and
                  problem_size.R <= 32 and problem_size.S <= 32):
                return IteratorAlgorithm.Optimized
            else:
                return IteratorAlgorithm.Analytic
        elif self.conv_kind == ConvKind.Dgrad:
            if (problem_size.K % alignment_a == 0 and
                problem_size.R <= 32 and problem_size.S <= 32 and
                problem_size.C % alignment_b == 0):
                return IteratorAlgorithm.Optimized
            else:
                return IteratorAlgorithm.Analytic
        elif self.conv_kind == ConvKind.Wgrad:
            if (problem_size.K % alignment_a == 0 and
                problem_size.C % alignment_b == 0):
                return IteratorAlgorithm.Optimized
            else:
                return IteratorAlgorithm.Analytic
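
    # Worked example of the proposal rules above (hypothetical fprop problem):
    # with C=64, R=S=3, and alignment_a=8, C % alignment_a == 0 and R, S <= 32,
    # so the Optimized iterator is proposed; with C=8 and alignment_a=8,
    # C == alignment_a and FixedChannels is proposed instead.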

    def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
        """
        Validate whether the user-provided iterator algorithm works for the given problem size
        """
        if self.conv_kind == ConvKind.Fprop:
            if iterator_algorithm == IteratorAlgorithm.FixedChannels:
                return problem_size.C == alignment_a
            elif iterator_algorithm == IteratorAlgorithm.Optimized:
                return (problem_size.C % alignment_a == 0 and
                        problem_size.R <= 32 and problem_size.S <= 32)
            elif iterator_algorithm == IteratorAlgorithm.FewChannels:
                return problem_size.C % alignment_a == 0
        elif self.conv_kind == ConvKind.Dgrad:
            if iterator_algorithm == IteratorAlgorithm.Optimized:
                return (problem_size.K % alignment_a == 0 and
                        problem_size.R <= 32 and problem_size.S <= 32 and
                        problem_size.C % alignment_b == 0)
        elif self.conv_kind == ConvKind.Wgrad:
            if iterator_algorithm == IteratorAlgorithm.Optimized:
                return (problem_size.K % alignment_a == 0 and
                        problem_size.C % alignment_b == 0)

        return True

    #
    # Stride Support Related
    #

    def _propose_stride_support(self, stride):
        if self.conv_kind == ConvKind.Dgrad:
            if stride[0] == 1 and stride[1] == 1:
                return StrideSupport.Unity

        return StrideSupport.Strided
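
    # Example (illustrative): only dgrad with unit stride resolves to
    # StrideSupport.Unity; every other case resolves to StrideSupport.Strided:
    #
    #   plan._propose_stride_support((1, 1))  # Unity for a dgrad plan
    #   plan._propose_stride_support((2, 2))  # Strided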

    #
    # Construct and Compilation
    #

    def construct(
        self, tile_description: TileDescription = None,
        alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
        iterator_algorithm: IteratorAlgorithm = None,
        stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
        epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
        """
        Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
        kernel specification of the ``Conv2d`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param iterator_algorithm: the iterator algorithm used
        :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
        :param stride_support: the stride support of dgrad
        :type stride_support: cutlass_library.library.StrideSupport
        :param swizzling_functor: the swizzling functor
        :type swizzling_functor: cutlass_cppgen.swizzle
        :param epilogue_functor: the epilogue functor

        :return: operation that was constructed
        :rtype: cutlass_cppgen.backend.Conv2dOperation
        """
        # Get alignment
        alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
        alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
        alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)

        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            if self.tile_description is not None:
                tile_description = self.tile_description
            else:
                op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
                tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description

        if iterator_algorithm is None:
            # If the iterator algorithm is already set
            if self.iterator_algorithm is not None:
                iterator_algorithm = self.iterator_algorithm
            else:
                # Otherwise, we conservatively use the analytic iterator for correctness
                iterator_algorithm = IteratorAlgorithm.Analytic

        if stride_support is None:
            # If the stride support is already set
            if self._stride_support is not None:
                stride_support = self._stride_support
            else:
                # Otherwise, we assume strided
                stride_support = StrideSupport.Strided

        if swizzling_functor is None:
            # Otherwise, propose a default functor (stride (2, 2) selects the
            # strided-dgrad variant for dgrad kernels)
            swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))

        if epilogue_functor is None:
            if self.epilogue_functor is not None:
                epilogue_functor = self.epilogue_functor
            else:
                epilogue_functor = self._create_epilogue_functor_activation(self._activation)

        # Reset the alignment of the epilogue functor
        epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)

        operation = Conv2dOperation(
            conv_kind=self.conv_kind,
            iterator_algorithm=iterator_algorithm,
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A, B=tensor_B, C=tensor_C,
            stride_support=stride_support,
            epilogue_functor=epilogue_functor,
            swizzling_functor=swizzling_functor,
        )

        return operation
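
    # Example (illustrative): build an operation without compiling it,
    # overriding only the iterator algorithm; `plan` is a hypothetical Conv2d:
    #
    #   op = plan.construct(iterator_algorithm=IteratorAlgorithm.Analytic)
    #   print(op.procedural_name())   # assumes the backend operation exposes a name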

    def compile(self, tile_description: TileDescription = None,
                alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
                iterator_algorithm: IteratorAlgorithm = None,
                stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
                epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
        """
        Emits and compiles the kernel currently specified. If ``tile_description`` and any
        of the ``alignment`` parameters are set, the kernel will be chosen using this
        tile description and alignments. Otherwise, a default tile description and alignment
        will be used.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param iterator_algorithm: the iterator algorithm used
        :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
        :param stride_support: the stride support of dgrad
        :type stride_support: cutlass_library.library.StrideSupport
        :param swizzling_functor: the swizzling functor
        :type swizzling_functor: cutlass_cppgen.swizzle
        :param epilogue_functor: the epilogue functor

        :return: operation that was compiled
        :rtype: cutlass_cppgen.backend.Conv2dOperation
        """

        self.operation = self.construct(
            tile_description, alignment_A, alignment_B, alignment_C,
            iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)

        if print_module:
            print(self.operation.rt_module.emit())

        compiler.add_module([self.operation,])
        return self.operation
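
    # Example (illustrative): compile ahead of time and inspect the emitted C++:
    #
    #   plan.compile(print_module=True)   # prints the generated kernel source
    #   # Subsequent plan.run(...) calls reuse the cached compiled module.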

    #
    # Run Related
    #

    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
        Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
        is raised if it does not.

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_type: data type for the tensor that this object was initialized to
        :param ref_layout: layout for the tensor that this object was initialized to
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str
        """
        dtype, _ = datatypes.get_datatype_and_layout(tensor)
        if dtype != ref_type:
            raise Exception(f'Tensor {name} with type and layout {dtype} '
                            f'does not match the expected type of {ref_type}.')

    def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
        if self.conv_kind == ConvKind.Fprop:
            input = A
            weight = B
            output = C
            output_tensor = "C"
        elif self.conv_kind == ConvKind.Dgrad:
            output = A
            weight = B
            input = C
            output_tensor = "A"
        elif self.conv_kind == ConvKind.Wgrad:
            output = A
            input = B
            weight = C
            output_tensor = "A"
        else:
            raise Exception(f"Convolution kind {self.conv_kind} is not supported")

        N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
        K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
        _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")

        problem_size = Conv2DProblemSize(
            N_, H_, W_, C_,
            K_, R_, S_, C_,
            padding[0], padding[1],
            stride[0], stride[1],
            dilation[0], dilation[1],
            ConvMode.CrossCorrelation,
            1, 1
        )

        if P_ != problem_size.P or Q_ != problem_size.Q:
            raise Exception(
                f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")

        return problem_size
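
    # Example (illustrative): for an NHWC fprop with input (1, 224, 224, 3),
    # weight (64, 7, 7, 3), stride (2, 2), padding (3, 3), dilation (1, 1),
    # the implied output extents are P = Q = (224 + 2*3 - 7) // 2 + 1 = 112,
    # so the output tensor must have shape (1, 112, 112, 64).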

    def run(self, A=None, B=None, C=None, D=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1),
            alpha=None, beta=None,
            split_k=("serial", 1), sync: bool = True,
            print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
        ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
        parameters provided in the call, or from those
        passed in on the construction of this object -- one of the two must be specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: tensor representing data type and layout of operand A
        :param B: tensor representing data type and layout of operand B
        :param C: tensor representing data type and layout of operand C
        :param D: tensor representing data type and layout of operand D
        :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
        :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
        :param dilation: (dilation_h, dilation_w) describing the dilation of the convolution. Default: (1, 1)
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param split_k: a tuple (split_k_mode, split_k_slices)
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass_cppgen.backend.Conv2dArguments
        """
        if not stream:
            stream = cuda.CUstream(0)
        super().run_setup()

        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # Handle the case when there is no C
        if C is None:
            if beta != 0:
                raise Exception(f"With beta {beta} != 0, C has to be provided.")
            else:
                C = D

        # Construct problem size based on input.
        # This also verifies that A, B, C, D, stride, padding, and dilation match.
        problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

        # Propose stride support based on input
        stride_support = self._propose_stride_support(stride)

        # Propose swizzling functor
        swizzling_functor = self._propose_swizzling_functor(stride)

        shape_a = datatypes.get_tensor_shape(A, op="CONV")
        shape_b = datatypes.get_tensor_shape(B, op="CONV")
        shape_c = datatypes.get_tensor_shape(C, op="CONV")

        # Get the alignment
        alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
        alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
        alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")

        alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
        alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
        alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)

        # Propose iterator algorithm based on input
        if self._iterator_algorithm is None:
            # Propose a default iterator algorithm based on the problem size
            iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
        else:
            if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
                iterator_algorithm = self._iterator_algorithm
            else:
                raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")

        epilogue_args = [alpha, beta]

        if hasattr(self, "_activation_args"):
            if isinstance(self._activation_args, list):
                epilogue_args += self._activation_args
            else:
                epilogue_args.append(self._activation_args)

        if split_k[0] == "parallel" and split_k[1] > 1:
            epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
        else:
            epilogue_functor = self.epilogue_functor

        # Compile with the alignments and iterator algorithm determined above
        self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
                     swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)

        # Create reduction operation for parallel split-k
        if split_k[0] == "parallel" and split_k[1] > 1:
            epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
            self.reduction_operation = ReductionOperation(
                shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
                element_accumulator=self._element_accumulator,
                element_compute=self._element_accumulator,
                epilogue_functor=epilogue_functor_reduction,
                count=alignment_c
            )
            if print_module:
                print(self.reduction_operation.rt_module.emit())
            compiler.add_module([self.reduction_operation,])

        arguments = Conv2dArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D,
            output_op=self.operation.epilogue_type(*epilogue_args),
            split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
            split_k_slices=split_k[1],
            stream=stream
        )

        self.operation.run(arguments)

        if split_k[0] == "parallel" and split_k[1] > 1:
            implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
            reduction_arguments = ReductionArguments(
                self.reduction_operation,
                problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
                partitions=split_k[1],
                workspace=arguments.ptr_D,
                destination=D,
                source=C,
                output_op=self.reduction_operation.epilogue_type(*epilogue_args),
                stream=stream
            )
            self.reduction_operation.run(reduction_arguments)

        if sync:
            if split_k[0] == "parallel" and split_k[1] > 1:
                reduction_arguments.sync()

                # Free memory allocated by args because we are not
                # calling `arguments.sync()` in this case (which will free memory)
                arguments.free()
            else:
                arguments.sync()

        return arguments
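
    # Example (illustrative; tensors are hypothetical NHWC arrays on the GPU
    # with dtypes matching the plan): run with parallel split-K across 4
    # slices and wait for completion:
    #
    #   args = plan.run(A, B, C, D, stride=(2, 2), padding=(1, 1),
    #                   split_k=("parallel", 4), sync=True)
    #
    # With sync=False, call args.sync() later before reading D.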

    #
    # Helper functions
    #

    @staticmethod
    def output_size(input_size, weight_size, padding, stride, dilation):
        problem_size = Conv2DProblemSize(
            *input_size,
            *weight_size,
            padding[0], padding[1],
            stride[0], stride[1],
            dilation[0], dilation[1],
            ConvMode.CrossCorrelation,
            1, 1
        )
        return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
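
    # Example (illustrative): compute the NPQK output extents for a 3x3
    # convolution without constructing any tensors:
    #
    #   Conv2d.output_size((1, 56, 56, 64), (128, 3, 3, 64),
    #                      padding=(1, 1), stride=(1, 1), dilation=(1, 1))
    #   # -> (1, 56, 56, 128)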


#
# Easy-to-use interfaces for fprop, wgrad, and dgrad
#

class Conv2dFprop(Conv2d):
    def __init__(
        self,
        input=None, weight=None, C=None, output=None, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        A, B, D = input, weight, output
        element_A, element_B, element_D = element_input, element_weight, element_output
        super().__init__(
            "fprop", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False,
        stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:

        if not stream:
            stream = cuda.CUstream(0)

        A, B, D = input, weight, output
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)


class Conv2dDgrad(Conv2d):
    def __init__(
        self,
        grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        A, B, D = grad_output, weight, grad_input
        element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
        super().__init__(
            "dgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        if not stream:
            stream = cuda.CUstream(0)

        A, B, D = grad_output, weight, grad_input
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)


class Conv2dWgrad(Conv2d):
    def __init__(
        self,
        grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        A, B, D = grad_output, input, grad_weight
        element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
        super().__init__(
            "wgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        if not stream:
            stream = cuda.CUstream(0)

        A, B, D = grad_output, input, grad_weight
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
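

# Example (illustrative; x, w, y are hypothetical fp16 NHWC arrays on the GPU):
# the three wrappers only remap their named operands onto Conv2d's A/B/C/D:
#
#   fprop = cutlass_cppgen.op.Conv2dFprop(element=np.float16, element_accumulator=np.float32)
#   fprop.run(input=x, weight=w, output=y, alpha=1.0, beta=0.0)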
725
python/cutlass_cppgen/op/gemm.py
Normal file
@ -0,0 +1,725 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Ease-of-use interface for constructing, compiling, and running GEMMs.

The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS GEMMs.

Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.

The simplest example of using this interface is the following:

.. highlight:: python
.. code-block:: python

    # A, B, C, and D are torch/numpy/cupy tensor objects
    plan = cutlass_cppgen.op.Gemm(A, B, C, D)
    plan.run()


One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:

.. highlight:: python
.. code-block:: python

    # The following is shorthand for:
    #   cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
    #                          element_C=torch.float32, element_D=torch.float32,
    #                          element_accumulator=torch.float32,
    #                          layout=cutlass_cppgen.LayoutType.RowMajor)
    plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)

    A0 = torch.rand((128, 256), device='cuda')
    B0 = torch.rand((256, 64), device='cuda')
    C0 = torch.zeros((128, 64), device='cuda')
    D0 = torch.zeros((128, 64), device='cuda')
    plan.run(A0, B0, C0, D0)

    A1 = torch.rand((32, 128), device='cuda')
    B1 = torch.rand((128, 256), device='cuda')
    C1 = torch.zeros((32, 256), device='cuda')
    D1 = torch.zeros((32, 256), device='cuda')
    plan.run(A1, B1, C1, D1)

The interface additionally enables one to decouple the compilation of the underlying CUTLASS
kernel from its execution:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
    plan.compile()

    # Do other work...

    plan.run(A0, B0, C0, D0)

    # Do other work...

    plan.run(A1, B1, C1, D1)

Elementwise activation functions are easily fused to the GEMM via the interface:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
    plan.activation = cutlass_cppgen.epilogue.relu

Operations can also be run asynchronously:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
    args = plan.run()

    # Do other work...

    args.sync()
"""
from __future__ import annotations
from typing import Optional
from math import prod

from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import (
    DataType,
    DataTypeSize,
    GemmUniversalMode,
    KernelScheduleSuffixes,
)

import cutlass_cppgen
from cutlass_cppgen import epilogue, swizzle
from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils import check, datatypes


class Gemm(OperationBase):
    """
    Constructs a ``Gemm`` object.

    The data types and layouts of operands A, B, and C, along with the data type of output D
    and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
    these are not to be changed after a ``Gemm`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. The following
    constructors are equivalent:

    .. highlight:: python
    .. code-block:: python

        # Use F32 for A, B, C, D, and accumulation. All operands are row major.

        # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
        # for operands to the same values.
        Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)

        # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
        Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
             element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)

        # Set the data types and elements from existing tensors. Note that one can use different tensors when
        # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
        # have the same data type and layout as those passed in here).
        # A, B, C, and D are row-major torch.Tensor objects of type torch.float32
        Gemm(A=A, B=B, C=C, D=D)

        # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (the layout
        # of D is the same as that of C, at present)
        Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
             layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)

        # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
        # and layouts will inherit those passed in via the generic ``element`` and ``layout``
        Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
             element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)

    The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
        1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
        2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
        3) Otherwise, use the generic values (e.g., ``element``, ``layout``)

    :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
    :type cc: int
    :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
    :type kernel_cc: int
    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
    :type element_accumulator: cutlass_cppgen.DataType
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass_cppgen.DataType
    :param layout: generic layout type to be used for operands A, B, C, and D
    :type layout: cutlass_cppgen.LayoutType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass_cppgen.DataType
    :param layout_A: layout of operand A
    :type layout_A: cutlass_cppgen.LayoutType
    :param layout_B: layout of operand B
    :type layout_B: cutlass_cppgen.LayoutType
    :param layout_C: layout of operand C (also used for operand D)
    :type layout_C: cutlass_cppgen.LayoutType
    """

    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None, kernel_cc: int = None
    ):
        super().__init__(cc=cc, kernel_cc=kernel_cc)
        self.name = "gemm"
        self.compiled = False

        elements = []
        layouts = []

        # Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
        # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``.
        # (``layout_C`` appears twice below because the layout of D is currently required to match that of C.)
        for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
                                        [layout_A, layout_B, layout_C, layout_C],
                                        [A, B, C, D],
                                        ["A", "B", "C", "D"]):
            if elt is not None and tens is not None:
                raise Exception(f'Must not specify both element_{name} and tensor {name}')
            if lay is not None and tens is not None:
                raise Exception(f'Must not specify both layout_{name} and tensor {name}')
            if elt is None and tens is None and element is None:
                raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
            if lay is None and tens is None and layout is None:
                raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')

            elt_to_set = None
            lay_to_set = None
            if tens is not None:
                elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
            else:
                elt_to_set = elt if elt is not None else element
                lay_to_set = lay if lay is not None else layout

            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(lay_to_set)

        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

        if element_accumulator is None:
            self._element_accumulator = self._element_c
        else:
            self._element_accumulator = datatypes.library_type(element_accumulator)

        self.A = A
        self.B = B
        self.C = C
        self.D = D

        self.alpha = alpha
        self.beta = beta

        self.epilogue_functor = None
        self.op_class = None
        self._tile_description = None

        self._reset_operations()

        self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1

    def _reset_operations(self, reset_epilogue: bool = True):
        # Set the default op class
        datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
        layout_comb = (self._layout_a, self._layout_b)

        self.possible_op_classes = self.options.supporting_opclasses(
            self._element_a, self._element_b, self._element_accumulator,
            self._layout_a, self._layout_b, self._math_operation)

        if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
            self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
        elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
            self.opclass = cutlass_cppgen.OpcodeClass.Simt
        else:
            if self._math_operation is not None:
                math_op_str = f' and math operation {self._math_operation}'
            else:
                math_op_str = ''

            raise Exception(f'No kernel configuration found for supported data type and layout '
                            f'combination {datatype_comb}x{layout_comb}{math_op_str}')

        if reset_epilogue:
            self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)

    @property
    def swizzling_functor(self):
        """
        Returns the type of the swizzling functor currently being used by the GEMM

        :return: swizzling functor type
        """
        return self._swizzling_functor

    @swizzling_functor.setter
    def swizzling_functor(self, swizzling_functor):
        """
        Sets the swizzling functor to the type specified by `swizzling_functor`
        """
        if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
            if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
                raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')

            if self.current_cc in [90, 100, 101, 103]:
                raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')
        self._swizzling_functor = swizzling_functor
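
    # Example (illustrative): opt in to stream-K scheduling on a pre-SM90
    # TensorOp plan; this raises on SIMT plans and on SM90+ devices:
    #
    #   plan.swizzling_functor = cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK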

    #
    # Tile description Related
    #

    @property
    def tile_description(self) -> TileDescription:
        """
        Returns the tile description
        """
        return self._tile_description

    @tile_description.setter
    def tile_description(
        self, td=None):
        """
        Set the tile description

        :param td: tile description
        :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
            {
                "threadblock_shape": [int, int, int],
                "warp_count": [int, int, int],
                "stages": int,
                "instruction_shape": [int, int, int] (optional),
                "cluster_shape": [int, int, int] (optional)
            }
        """
        if td is None:
            return
        if isinstance(td, dict):
            if self._tile_description is None:
                op = self.possible_operations.default_operation(self._math_operation)
                self._tile_description = datatypes.td_from_profiler_op(op)
            td = self._tile_description.clone_and_update(td)

        valid, msg = self._valid_tile_description(td)
        if valid:
            self._tile_description = td
        else:
            raise Exception(msg)
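
    # Example (illustrative): update only a few fields via the dict form; the
    # remaining fields are inherited from the current (or default) description:
    #
    #   plan.tile_description = {
    #       "threadblock_shape": [128, 128, 32],
    #       "stages": 4,
    #   }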

    def _valid_tile_description(self, td: TileDescription) -> tuple:
        """
        Checks whether the provided tile description is valid for the given compute capability. At present,
        this checks the following:

        - Does the tile description use a number of stages supported by the compute capability in question?
        - Does the tile size requested fit within shared memory?
        - Are the requested cluster dimensions within the valid range for the given architecture (e.g.,
          no more than one non-unit cluster dimension for pre-SM90 architectures)?
        - Is the kernel schedule being used supported on the architecture in question?

        :param td: tile description to validate
        :type td: cutlass_cppgen.backend.TileDescription
        :return: tuple in which the first element is a bool indicating that the tile description is valid
                 and the second element is a string providing an optional error message.
        :rtype: tuple
        """
        valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
        if not valid:
            return (valid, msg)

        valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
        if not valid:
            return (valid, msg)

        valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)

        if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0:
            valid = False
            msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"

        return valid, msg

    def tile_descriptions(self) -> list:
        """
        Returns a list of valid tile descriptions for the operations

        :returns: list of valid tile descriptions for the operations
        :rtype: list
        """
        tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
        if self._math_operation is not None:
            tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation]
        return tds
    def construct(
        self, tile_description: TileDescription = None,
        alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
        """
        Constructs a ``cutlass_cppgen.backend.GemmOperationUniversal`` based on the input parameters and current
        kernel specification of the ``Gemm`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int

        :return: operation that was constructed
        :rtype: cutlass_cppgen.backend.GemmOperationUniversal
        """
        alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
        alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
        alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
        alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)

        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)

        if alignment_C is None:
            alignment_C = max(self.possible_operations.alignments("C"))
            if self._element_c != DataType.void:
                alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)

        if tile_description is None:
            if self._tile_description is None:
                op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
                tile_description = datatypes.td_from_profiler_op(op)

                # The selected op may have lower alignment than that determined above, so we must
                # reset alignment here.
                alignment_C = op.C.alignment
            else:
                tile_description = self._tile_description
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self._tile_description = tile_description

        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        operation = GemmOperationUniversal(
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A, B=tensor_B, C=tensor_C,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self._swizzling_functor,
        )

        return operation
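
    # Example (illustrative): the preferred alignment is capped at one 128-bit
    # access per operand, e.g. 128 // DataTypeSize[f16] = 8 elements for fp16:
    #
    #   op = plan.construct(alignment_A=8, alignment_B=8, alignment_C=8)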

    def compile(self, tile_description: TileDescription = None,
                alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
                print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
        """
        Emits and compiles the kernel currently specified. If ``tile_description`` and any
        of the ``alignment`` parameters are set, the kernel will be chosen using this
        tile description and alignments. Otherwise, a default tile description and alignment
        will be used.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool

        :return: operation that was compiled
        :rtype: cutlass_cppgen.backend.GemmOperationUniversal
        """
        self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)

        if print_module:
            print(self.operation.rt_module.emit())

        compiler.add_module([self.operation,])
        return self.operation

    def _verify_rank(self, tensor):
        """
        Verifies that ``tensor`` has rank greater than 1

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        """
        if len(tensor.shape) < 2:
            raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")

    def _get_batch_count(self, A, B, C, D) -> int:
        """
        Returns the batch count specified by the tensors A, B, C, and D and verifies that these
        tensors match in batch size. Presence of a batch dimension is detected by one of the
        tensors being rank 3. If a batch dimension is present, it must be present in one of
        operands A, B, or C (but need not be in all), and must be present in D.

        :param A: tensor A
        :type A: numpy/cupy/torch array/tensor object
        :param B: tensor B
        :type B: numpy/cupy/torch array/tensor object
        :param C: tensor C
        :type C: numpy/cupy/torch array/tensor object
        :param D: tensor D
        :type D: numpy/cupy/torch array/tensor object

        :return: batch count shared by the tensors
        :rtype: int
        """
        A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
        B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1

        if 1 not in [A_batch, B_batch]:
            if A_batch != B_batch:
                raise Exception(f"Got invalid batch counts: A={A_batch}, B={B_batch}")
        return max(A_batch, B_batch)

    def _get_batch_stride(self, tensor) -> int:
        """
        Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, the batch stride is 0.

        :param tensor: tensor object to process
        :type tensor: numpy/cupy/torch array/tensor object

        :return: stride between each matrix in the batch
        :rtype: int
        """
        if tensor is not None and len(tensor.shape) > 2:
            return tensor.shape[-2] * tensor.shape[-1]
        else:
            return 0
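
    # Example (illustrative): for a batched tensor of shape (8, 128, 64) the
    # batch count is 8 and the batch stride is 128 * 64 = 8192 elements; a
    # rank-2 tensor contributes batch count 1 and batch stride 0.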

    def _get_problem_args(self, A, B, C, D) -> tuple:
        """
        Returns the problem size and GEMM universal mode to use for the
        given operands.

        :param A: tensor A
        :type A: numpy/cupy/torch array/tensor object
        :param B: tensor B
        :type B: numpy/cupy/torch array/tensor object
        :param C: tensor C
        :type C: numpy/cupy/torch array/tensor object
        :param D: tensor D
        :type D: numpy/cupy/torch array/tensor object

        :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
        :rtype: tuple
        """
        M, K = A.shape[-2:]
        N = B.shape[-1]
        mode = GemmUniversalMode.Gemm

        batch_count = self._get_batch_count(A, B, C, D)
        returned_batch_count = batch_count

        # If we are running a batched GEMM in which there is a nonzero batch stride
        # only for A, then we can fold the batched dimension of A into the M dimension
        # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
        # and C are row major. A similar operation can be performed if only B has a nonzero
        # batch dimension
        if batch_count > 1:
            A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
            B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
            C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor

            # Consider a tensor to be batched if its rank is > 2 and
            # the product of the modes beyond rank 2 equals our pre-determined batch size.
            batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)

            if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
                M *= batch_count
                returned_batch_count = 1
            elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
                N *= batch_count
                returned_batch_count = 1
            else:
                mode = GemmUniversalMode.Batched

        return GemmCoord(M, N, K), mode, returned_batch_count
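
    # Example (illustrative): A of shape (8, 128, 64) and B of shape (64, 32),
    # with A and C row major, fold into a single (1024, 32, 64) problem in
    # GemmUniversalMode.Gemm; if both A and B are batched, the mode becomes
    # GemmUniversalMode.Batched with batch count 8.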

    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
        Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
        is raised if it does not.

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_type: data type for the tensor that this object was initialized to
        :param ref_layout: layout for the tensor that this object was initialized to
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str
        """
        dtype, layout = datatypes.get_datatype_and_layout(tensor)
        if dtype != ref_type or layout != ref_layout:
            try:
                # Attempt to transpose the tensor to fit the desired layout
                tensor = tensor.transpose(-1, -2)
            except:
                raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
                                f'does not match the expected type and '
                                f'layout of ({ref_type}, {ref_layout}) and transpose failed.')

    def run(self, A=None, B=None, C=None, D=None,
            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
            stream: Optional[cuda.CUstream] = None) -> GemmArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
        ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
        parameters provided in this call, or from those
        passed in on the construction of this object -- one of the two must be specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: tensor representing data type and layout of operand A
        :param B: tensor representing data type and layout of operand B
        :param C: tensor representing data type and layout of operand C
        :param D: tensor representing data type and layout of operand D
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass_cppgen.backend.GemmArguments
        """
        if not stream:
            stream = cuda.CUstream(0)
        super().run_setup()
        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        is_void_c = self._element_c == DataType.void

        self._verify_rank(A)
        self._verify_rank(B)
        if not is_void_c:
            self._verify_rank(C)
        self._verify_rank(D)

        alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
        alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")

        # Set C alignment based on D.shape so as to correctly get an alignment with void-C
        # kernels, for which `C` is None.
        alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
        self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

        if mode == GemmUniversalMode.Gemm or batch_count == 1:
            kwargs = {'split_k_slices': 1}
        else:
            kwargs = {
                'batch': batch_count,
                'batch_strides': {
                    'A': self._get_batch_stride(A),
                    'B': self._get_batch_stride(B),
                    'C': self._get_batch_stride(C),
                    'D': self._get_batch_stride(D)
                }
            }

        kwargs['stream'] = stream

        if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
            output_op = self.operation.epilogue_type(visitor_args)
        else:
            output_op = self.operation.epilogue_type(alpha, beta)

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D,
            output_op=output_op,
            gemm_mode=mode,
            **kwargs
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()

        return arguments
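
    # Example (illustrative): asynchronous launch on an explicit stream;
    # `s` is a hypothetical cuda.CUstream created by the caller:
    #
    #   args = plan.run(A, B, C, D, sync=False, stream=s)
    #   ...  # other host work
    #   args.sync()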
269
python/cutlass_cppgen/op/gemm_grouped.py
Normal file
@ -0,0 +1,269 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Ease-of-use interface for constructing, compiling, and running grouped GEMMs.

The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS grouped GEMMs.

Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.

The simplest example of using this interface is the following:

.. highlight:: python
.. code-block:: python

    # As, Bs, Cs, and Ds are lists of torch/numpy/cupy tensor objects
    plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
    plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from cutlass_library import DataTypeSize
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_cppgen.backend.gemm_operation import (
|
||||
GemmGroupedArguments,
|
||||
GemmOperationGrouped,
|
||||
)
|
||||
from cutlass_cppgen.backend.library import (
|
||||
SchedulerMode,
|
||||
TensorDescription,
|
||||
TileDescription,
|
||||
)
|
||||
from cutlass_cppgen.op.gemm import Gemm
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class GroupedGemm(Gemm):
    """
    Constructs a ``GroupedGemm`` object.

    The data types and layouts of operands A, B, and C, along with the data type of output D
    and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
    these are not to be changed after a ``GroupedGemm`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
    for ``Gemm`` for examples of these.

    :param cc: compute capability of device to generate kernels for
    :type cc: int
    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
    :type element_accumulator: cutlass_cppgen.DataType
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass_cppgen.DataType
    :param layout: generic layout type to be used for operands A, B, C, and D
    :type layout: cutlass_cppgen.LayoutType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass_cppgen.DataType
    :param layout_A: layout of operand A
    :type layout_A: cutlass_cppgen.LayoutType
    :param layout_B: layout of operand B
    :type layout_B: cutlass_cppgen.LayoutType
    :param layout_C: layout of operand C
    :type layout_C: cutlass_cppgen.LayoutType
    """

    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None,
    ):
        super().__init__(
            A=A, B=B, C=C, D=D,
            alpha=alpha, beta=beta,
            element_accumulator=element_accumulator,
            element=element, layout=layout,
            element_A=element_A, element_B=element_B,
            element_C=element_C, element_D=element_D,
            layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
            cc=cc
        )

        # Grouped GEMM specializations for SM90 and newer are currently unavailable. Revert to using SM80 kernels.
        if self.current_cc in [90, 100, 101, 103]:
            self._reset_options(80)
            self._reset_operations(reset_epilogue=False)

        self.name = "grouped_gemm"

    @Gemm.swizzling_functor.setter
    def swizzling_functor(self, swizzling_functor):
        """
        Sets the swizzling functor to the type specified by `swizzling_functor`
        """
        raise Exception('Grouped GEMM does not currently support different swizzling functors')

    def construct(self, tile_description: TileDescription = None,
                  alignment_A: int = None,
                  alignment_B: int = None,
                  alignment_C: int = None) -> GemmOperationGrouped:
        """
        Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
        kernel specification of the ``Gemm`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass_cppgen.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int

        :return: operation that was constructed
        :rtype: cutlass_cppgen.backend.GemmOperationGrouped
        """
        alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
        alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
        alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))

        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
            tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description

        operation = GemmOperationGrouped(
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A, B=tensor_B, C=tensor_C,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self._swizzling_functor,
            precompute_mode=SchedulerMode.Device)

        return operation

    def run(self, A, B, C, D,
            alpha=None, beta=None, sync: bool = True,
            print_module: bool = False,
            stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
        """
        Runs the kernel currently specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: list of tensors representing data type and layout of operand A
        :type A: list
        :param B: list of tensors representing data type and layout of operand B
        :type B: list
        :param C: list of tensors representing data type and layout of operand C
        :type C: list
        :param D: list of tensors representing data type and layout of operand D
        :type D: list
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass_cppgen.backend.GemmGroupedArguments
        """
        if not stream:
            stream = cuda.CUstream(0)

        super().run_setup()

        if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
            raise Exception("Lengths of A, B, C, and D lists must be equal")

        problem_sizes = []
        As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
        for i in range(len(A)):
            As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
            Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
            Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
            Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
            problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))

        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
        alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
        alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        arguments = GemmGroupedArguments(
            operation=self.operation,
            problem_sizes=problem_sizes,
            A=As, B=Bs, C=Cs, D=Ds,
            output_op=self.operation.epilogue_type(alpha, beta),
            stream=stream
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()

        return arguments
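
# A minimal end-to-end sketch of the run() flow above, assuming a CUDA-capable
# device with numpy and cutlass_cppgen installed; the shapes are arbitrary
# illustration values, not part of the diff.
import numpy as np
import cutlass_cppgen

plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16,
                                     layout=cutlass_cppgen.LayoutType.RowMajor)
# Two independent GEMM problems launched as one grouped kernel
As = [np.random.rand(128, 64).astype(np.float16), np.random.rand(256, 32).astype(np.float16)]
Bs = [np.random.rand(64, 96).astype(np.float16), np.random.rand(32, 48).astype(np.float16)]
Cs = [np.zeros((128, 96), np.float16), np.zeros((256, 48), np.float16)]
Ds = [np.zeros_like(c) for c in Cs]
args = plan.run(As, Bs, Cs, Ds, sync=False)  # launch without blocking
args.sync()                                  # wait before reading Ds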
431
python/cutlass_cppgen/op/op.py
Normal file
@@ -0,0 +1,431 @@
#################################################################################################
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text omitted; identical to the license header shown earlier.)
#################################################################################################

"""
|
||||
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
||||
"""
|
||||
|
||||
from bisect import bisect_left
|
||||
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
MathOperation,
|
||||
OperationKind,
|
||||
SharedMemPerCC
|
||||
)
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import get_option_registry
|
||||
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
|
||||
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
|
||||
from cutlass_cppgen.swizzle import get_swizzling_functors
|
||||
from cutlass_cppgen.utils import datatypes, check
|
||||
|
||||
|
||||
class OperationBase:
    """
    Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
    """

    def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
        """
        :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
        :type cc: int
        :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
        :type kernel_cc: int
        :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
        :type operation_kind: cutlass_library.OperationKind
        """
        self.operation_kind = operation_kind
        self.cc = cc if cc is not None else device_cc()
        self.specified_kernel_cc = kernel_cc is not None
        self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
        self.tile_description = None
        self._math_operation = None

        self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)

        if self.options is None:
            raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")

        # Default activation function: identity
        self._activation = identity

    def _find_closest_cc(self, cc: int) -> int:
        """
        Returns the closest CC in _generator_ccs less than or equal to `cc`

        :param cc: compute capability to query
        :type cc: int

        :returns: closest CC in _generator_ccs less than or equal to `cc`
        :rtype: int
        """
        if cc in _generator_ccs:
            return cc

        # Find the closest CC lower than this CC
        idx = bisect_left(_generator_ccs, cc)
        if idx == 0:
            raise Exception(f'No valid CC to fall back to for {cc}')
        return _generator_ccs[idx - 1]

    def activations(self) -> list:
        """
        Returns possible activation functions that can be used

        :return: list of activation functions that can be used
        :rtype: list
        """
        return get_activations()

    def swizzling_functors(self) -> list:
        """
        Returns possible swizzling functions that can be used

        :return: list of swizzling functions that can be used
        :rtype: list
        """
        return get_swizzling_functors()
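
    # Illustration of the bisect-based fallback above: assuming a generator list
    # such as [50, 60, 70, 75, 80, 90], a device with cc=86 hits bisect_left's
    # index for 90, so _find_closest_cc returns _generator_ccs[idx - 1] == 80
    # and SM80 kernels are generated.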
    def _reset_options(self, cc: int):
        """
        Resets the kernel options based on cc

        :param cc: compute capability to reset to
        :type cc: int
        """
        if cc != self.current_cc:
            if cc not in _generator_ccs:
                raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
            self.current_cc = cc
            self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)

    def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
        """
        Verifies the following properties:
            1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)
            2) If ``scalar`` is not ``None``, its data type must match the one currently
               set by the plan (i.e., that in ``ref_dtype``)

        If either of these properties does not hold, an exception is raised. If these properties hold and
        ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.

        :param scalar: object representing a scalar passed in to verify, or ``None`` if no scalar was passed in
        :type scalar: numpy/cupy/torch scalar
        :param ref_scalar: object representing a scalar passed in on construction of this object, or ``None`` if no scalar was passed in
        :type ref_scalar: numpy/cupy/torch scalar
        :param ref_dtype: data type for the scalar that this object was initialized to
        :param name: identifier of the scalar to verify. Used in raising exceptions
        :type name: str

        :return: valid scalar to use
        :rtype: numpy/cupy/torch scalar
        """
        if scalar is None:
            if ref_scalar is None:
                raise Exception(f"Scalar {name} must be set.")
            return ref_scalar
        if hasattr(scalar, "dtype"):
            dtype = datatypes.library_type(scalar.dtype)
            if dtype != ref_dtype:
                raise Exception(
                    f"Scalar {name} with type {dtype} does not match expected type {ref_dtype}."
                )
        return scalar

    def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
        """
        Verifies the following properties:
            If ref_dtype is not void:
                1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
                2) If ``tensor`` is not ``None``, its data type and layout must match the current versions
                   set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
            If ref_dtype is void:
                Neither ``tensor`` nor ``ref_tensor`` may be set

        If either of these properties does not hold, an exception is raised. If these properties hold and
        ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
        :type ref_tensor: numpy/cupy/torch array/tensor object
        :param ref_dtype: data type for the tensor that this object was initialized to
        :param ref_layout: layout for the tensor that this object was initialized to
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str

        :return: valid tensor object to use
        :rtype: numpy/cupy/torch array/tensor object
        """
        if ref_dtype == DataType.void:
            if tensor is not None or ref_tensor is not None:
                raise Exception("Operands with element DataType.void must not be provided a tensor")
            return None

        if tensor is None:
            if ref_tensor is None:
                raise Exception(f"Tensor {name} must be set.")
            return ref_tensor

        self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
        return tensor

    @property
    def opclass(self) -> cutlass_cppgen.OpcodeClass:
        """
        Returns the opcode class currently in use

        :return: opcode class currently in use
        :rtype: cutlass_cppgen.OpcodeClass
        """
        return self.op_class

    @opclass.setter
    def opclass(self, oc: cutlass_cppgen.OpcodeClass):
        if isinstance(oc, str):
            oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
        if oc in self.possible_op_classes:
            self.op_class = oc
        else:
            raise Exception(
                f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
                f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
                f'layout combination ({self._layout_a}, {self._layout_b}).')

        # Changing the op class also changes the possible operations available. Reset these.
        self.possible_operations = self.options.operations(
            self.op_class, self._element_a, self._element_b,
            self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)

        # Changing the op class changes the elements per access in the epilogue. Reset this.
        if self.epilogue_functor is not None:
            self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)

    @property
    def math_operation(self) -> cutlass_cppgen.MathOperation:
        """
        Returns the math operation currently in use

        :return: math operation currently in use
        :rtype: cutlass_cppgen.MathOperation
        """
        return self._math_operation

    @math_operation.setter
    def math_operation(self, mo: cutlass_cppgen.MathOperation):
        if isinstance(mo, str):
            mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)

        if not self.specified_kernel_cc:
            if self.current_cc in [90, 100, 101, 103]:
                # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
                # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
                cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
                self._reset_options(80)
                self._reset_operations(reset_epilogue=False)
        elif self.current_cc in [90, 100, 101, 103]:
            raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
                            "To use 2.x kernels with a specific math operation, do not set the `kernel_cc` "
                            "parameter when constructing the plan.")

        self._math_operation = mo
        self._reset_operations()

    def _elements_per_access(self):
        # Vectorized epilogue accesses are 128 bits wide; for example, f16
        # (16 bits) yields 128 // 16 = 8 elements per access.
        if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
            return 1
        elif self._element_c != DataType.void:
            return 128 // DataTypeSize[self._element_c]
        else:
            return 128 // max(self.possible_operations.alignments("C"))

    def _create_epilogue_functor_activation(self, activation):
        """
        Returns the epilogue functor with given activation function
        """
        if self.epilogue_functor is None:
            elements_per_access = self._elements_per_access()
        else:
            elements_per_access = self.epilogue_functor.epilogue_vector_length

        if not self.specified_kernel_cc:
            if self.current_cc in [90, 100, 101, 103] and activation != identity:
                # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests
                # a non-identity activation, revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
                cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
                if self._element_c != self._element_d:
                    raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
                self._reset_options(80)
                self._reset_operations(reset_epilogue=False)
            elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103]
                  and activation == identity and self._math_operation is None):
                # SM80 fallback kernels are currently used. Since an identity activation is requested,
                # we can switch back to using SM90 kernels.
                self._reset_options(self.cc)
                self._reset_operations(reset_epilogue=False)
        else:
            if self.current_cc in [90, 100, 101, 103] and activation != identity:
                raise Exception("Epilogues with elementwise fusion are not currently supported "
                                "in the Python interface for 3.x kernels. To use 2.x kernels "
                                "with fused elementwise epilogues, do not set the `kernel_cc` "
                                "parameter when constructing the plan.")

        return get_activation_epilogue(
            activation,
            self._element_d,
            elements_per_access,
            self._element_accumulator,
            self._element_accumulator,
        )

    def _reset_epilogue_functor_activation(self, activation):
        """
        Sets the epilogue functor based on the provided activation function
        """
        self.epilogue_functor = self._create_epilogue_functor_activation(activation)

    def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
        """
        Resets the alignment of the current epilogue functor based on alignment C
        """
        if isinstance(epilogue_functor, EpilogueFunctorVisitor):
            return epilogue_functor

        if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
            # The identity epilogue does not have an 'activation_functor' attribute
            activation = identity
        else:
            activation = epilogue_functor.activation_functor

        epilogue_functor = get_activation_epilogue(
            activation,
            self._element_d,
            alignment,
            self._element_accumulator,
            self._element_accumulator,
        )
        return epilogue_functor

    @property
    def activation(self):
        """
        Returns the type of the current activation function used
        """
        if hasattr(self.epilogue_functor, "activation_functor"):
            return self.epilogue_functor.activation_functor
        else:
            return identity

    @activation.setter
    def activation(self, act):
        """
        Sets the type of the activation function to use.
        An activation can come with a set of arguments.

        :param act: type of activation function to use
        :type act: str or tuple, e.g. "relu" or ("leaky_relu", 0.01)
        """
        if isinstance(act, tuple):
            if isinstance(act[0], str):
                act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
            else:
                act_fn = act[0]
            self._reset_epilogue_functor_activation(act_fn)
            self._activation_args = act[1]
            self._activation = act[0]
        else:
            if isinstance(act, str):
                act = getattr(cutlass_cppgen.backend.epilogue, act)
            self._reset_epilogue_functor_activation(act)
            self._activation = act
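
    # Usage sketch for the setter above (hedged; `plan` denotes an existing plan object):
    #   plan.activation = "relu"                 # look up the activation by name
    #   plan.activation = ("leaky_relu", 0.01)   # activation plus its argument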
    @property
    def epilogue_visitor(self):
        """
        Returns the epilogue functor
        """
        return self.epilogue_functor

    @epilogue_visitor.setter
    def epilogue_visitor(self, visitor):
        """
        Creates the epilogue visitor
        """
        self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)

        # The epilogue_functor may consume too much shared memory.
        # Reset the possible operations.
        if self.cc not in [90, 100, 101, 103]:
            # Shared memory is only a concern for the SM90+ epilogue.
            # On SM80, the epilogue and mainloop share the shared memory.
            return

        datatype_comb = self.possible_operations.datatype_comb
        layout_comb = self.possible_operations.layout_comb
        new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
        for operation in self.possible_operations.all_operations:
            td = datatypes.td_from_profiler_op(operation)
            # Filter invalid epilogue schedules
            if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
                cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
                cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
                continue
            epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)

            # Verify the maximum number of mainloop stages
            mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
            smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
            mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
            if mainloop_stages < 2:
                # The mainloop must have at least 2 stages
                continue

            new_possible_operations.add(operation)
        if len(new_possible_operations.all_operations) == 0:
            raise RuntimeError(
                "The epilogue consumes too much shared memory. "
                "No valid tile description is found in the generator.")
        self.possible_operations = new_possible_operations

    def run_setup(self):
        """
        Steps that must be taken before calling `plan.run()`
        """
        # Initialize the memory pool, if not already done
        cutlass_cppgen.get_memory_pool()
184
python/cutlass_cppgen/shape.py
Normal file
@@ -0,0 +1,184 @@
#################################################################################################
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text omitted; identical to the license header shown earlier.)
#################################################################################################

"""
|
||||
Utilities for expressing shapes
|
||||
"""
|
||||
|
||||
from cutlass_library import (
|
||||
ConvMode,
|
||||
ConvKind,
|
||||
LayoutType
|
||||
)
|
||||
from cutlass_cppgen.backend.c_types import (
|
||||
Conv2DProblemSize_,
|
||||
GemmCoord_,
|
||||
GemmCoordBatched_
|
||||
)
|
||||
|
||||
|
||||
class MatrixCoord:
    def __init__(self, row, col):
        self._row = row
        self._col = col

    @property
    def row(self):
        return self._row

    @property
    def column(self):
        return self._col

    def leading_dimension(self, layout: LayoutType) -> int:
        """
        Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.

        :param layout: layout of matrix
        :type layout: cutlass_library.LayoutType

        :returns: leading dimension
        :rtype: int
        """
        if layout == LayoutType.RowMajor:
            return self._col
        elif layout == LayoutType.ColumnMajor:
            return self._row
        else:
            raise Exception(f'Unsupported layout for leading dimension calculation: {layout}')
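
# A minimal sketch of leading_dimension: for a 4x6 matrix, the leading dimension
# is the column count (6) in row-major layout and the row count (4) in column-major.
assert MatrixCoord(4, 6).leading_dimension(LayoutType.RowMajor) == 6
assert MatrixCoord(4, 6).leading_dimension(LayoutType.ColumnMajor) == 4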

class GemmCoord:
    def __init__(self, m: int, n: int, k: int):
        self._m = m
        self._n = n
        self._k = k

    @property
    def m(self) -> int:
        return self._m

    @property
    def n(self) -> int:
        return self._n

    @property
    def k(self) -> int:
        return self._k

    @property
    def mk(self) -> MatrixCoord:
        return MatrixCoord(self._m, self._k)

    @property
    def mn(self) -> MatrixCoord:
        return MatrixCoord(self._m, self._n)

    @property
    def kn(self) -> MatrixCoord:
        return MatrixCoord(self._k, self._n)

    @property
    def ctype(self) -> GemmCoord_:
        return GemmCoord_(self._m, self._n, self._k)

    def batched_ctype(self, batch_count: int) -> GemmCoordBatched_:
        return GemmCoordBatched_(self._m, self._n, self._k, batch_count)


class Conv2DProblemSize:
    def __init__(
        self, n: int, h: int, w: int, c: int,
        k: int, r: int, s: int, c_: int,
        pad_h: int, pad_w: int, stride_h: int, stride_w: int,
        dilation_h: int, dilation_w: int, mode: ConvMode = ConvMode.CrossCorrelation,
        split_k_slices: int = 1, groups: int = 1):

        self.N = n
        self.H = h
        self.W = w
        self.C = c
        self.K = k
        self.R = r
        self.S = s
        self.pad_h = pad_h
        self.pad_w = pad_w
        self.stride_h = stride_h
        self.stride_w = stride_w
        self.dilation_h = dilation_h
        self.dilation_w = dilation_w
        self.mode = int(mode)
        self.split_k_slices = split_k_slices
        self.groups = groups
        self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1
        self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1

    @property
    def ctype(self) -> Conv2DProblemSize_:
        return Conv2DProblemSize_(self)

    def implicit_gemm_size(self, kind: ConvKind):
        if kind == ConvKind.Fprop:
            return GemmCoord(
                self.N * self.P * self.Q,
                self.K,
                self.R * self.S * self.C // self.groups
            )
        elif kind == ConvKind.Dgrad:
            return GemmCoord(
                self.N * self.H * self.W,
                self.C,
                self.R * self.S * self.K
            )
        elif kind == ConvKind.Wgrad:
            return GemmCoord(
                self.K,
                self.R * self.S * self.C,
                self.N * self.P * self.Q
            )

    @staticmethod
    def from_sizes(input_size, weight_size):
        K, R, S, _ = weight_size
        pad_h = R // 2
        pad_w = S // 2
        stride_h = 1
        stride_w = 1
        dilation_h = 1
        dilation_w = 1
        return Conv2DProblemSize(
            *input_size,
            *weight_size,
            pad_h, pad_w,
            stride_h, stride_w,
            dilation_h, dilation_w
        )
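
# A worked sketch of the implicit-GEMM mapping above, with assumed sizes:
# NHWC input 1x8x8x16, KRSC filter 32x3x3x16, padding 1, unit stride/dilation.
# Then P = ((8 + 2*1 - 3*1) // 1) + 1 = 8 (and likewise Q = 8), so for Fprop the
# implicit GEMM is (M, N, K) = (1*8*8, 32, 3*3*16) = (64, 32, 144).
ps = Conv2DProblemSize(1, 8, 8, 16, 32, 3, 3, 16, 1, 1, 1, 1, 1, 1)
gemm = ps.implicit_gemm_size(ConvKind.Fprop)
assert (gemm.m, gemm.n, gemm.k) == (64, 32, 144)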
65
python/cutlass_cppgen/swizzle.py
Normal file
@@ -0,0 +1,65 @@
#################################################################################################
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text omitted; identical to the license header shown earlier.)
#################################################################################################

"""
|
||||
Registry of swizzling functions
|
||||
"""
|
||||
|
||||
from cutlass_library import SwizzlingFunctor
|
||||
|
||||
|
||||
IdentitySwizzle1 = SwizzlingFunctor.Identity1
|
||||
IdentitySwizzle2 = SwizzlingFunctor.Identity2
|
||||
IdentitySwizzle4 = SwizzlingFunctor.Identity4
|
||||
IdentitySwizzle8 = SwizzlingFunctor.Identity8
|
||||
HorizontalSwizzle = SwizzlingFunctor.Horizontal
|
||||
ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK
|
||||
StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1
|
||||
StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4
|
||||
StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal
|
||||
|
||||
|
||||
_swizzling_functors = [
|
||||
IdentitySwizzle1,
|
||||
IdentitySwizzle2,
|
||||
IdentitySwizzle4,
|
||||
IdentitySwizzle8,
|
||||
HorizontalSwizzle,
|
||||
ThreadblockSwizzleStreamK,
|
||||
StridedDgradIdentitySwizzle1,
|
||||
StridedDgradIdentitySwizzle4,
|
||||
StridedDgradHorizontalSwizzle,
|
||||
]
|
||||
|
||||
|
||||
def get_swizzling_functors():
|
||||
return _swizzling_functors
|
||||
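
# Usage sketch (hedged): plans consume entries from this registry, e.g. for an
# assumed 2.x plan object named `plan`:
#   plan.swizzling_functor = ThreadblockSwizzleStreamK
assert ThreadblockSwizzleStreamK in get_swizzling_functors()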
41
python/cutlass_cppgen/utils/__init__.py
Normal file
@@ -0,0 +1,41 @@
#################################################################################################
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text omitted; identical to the license header shown earlier.)
#################################################################################################

from cutlass_cppgen.utils.check import (
    alignment_or_default,
    calculate_smem_usage,
    calculate_smem_usage_per_stage,
    valid_cluster_shape,
    valid_schedule,
    valid_stage_count,
    update_alignment,
)
262
python/cutlass_cppgen/utils/check.py
Normal file
@@ -0,0 +1,262 @@
#################################################################################################
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text omitted; identical to the license header shown earlier.)
#################################################################################################

"""
|
||||
Utility functions for checking constraints on kernels and calculating kernel attributes
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.library import TileDescription
|
||||
|
||||
|
||||
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param td: tile description to compute shared memory of
    :type td: TileDescription
    :param operation_kind: identifier for the type of operation being performed
    :type operation_kind: cutlass_library.OperationKind

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int
    """
    m, n, k = td.blackwell_threadblock_shape
    if td.is_2sm:
        m //= 2

    if operation_kind == OperationKind.Gemm:
        stage_barrier_bytes = 32
        return (
            (DataTypeSize[td.math_instruction.element_a] * m * k // 8)
            + (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
            + stage_barrier_bytes
        )
    else:
        raise Exception(f"No available shared memory calculation for operation kind {operation_kind}")

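# A worked instance of the formula above (assumed tile values): for f16 operands
# A and B with a per-stage tile of (m, n, k) = (128, 128, 32),
#   smem = 16*128*32 // 8 + 16*32*128 // 8 + 32 = 8192 + 8192 + 32 = 16416 bytes per stage.
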
def calculate_smem_usage(operation) -> int:
    """
    Returns the amount of shared memory in bytes consumed by a kernel.

    :return: number of bytes of shared memory consumed by the operation
    :rtype: int
    """
    _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind)
    return _per_stage * operation.tile_description.stages


def valid_stage_count(
        cc: int,
        kernel_cc: int,
        td: TileDescription,
        element_C: cutlass_cppgen.DataType = None,
        element_D: cutlass_cppgen.DataType = None,
        verbose: bool = True) -> tuple:
    """
    Checks whether a device with `cc` supports the number of stages within `tile_description`, both
    based on raw limits on the number of stages and based on shared memory capacity

    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
    :type kernel_cc: int
    :param td: tile description to check
    :type td: TileDescription
    :param element_C: data type of operand C
    :type element_C: cutlass_cppgen.DataType
    :param element_D: data type of operand D
    :type element_D: cutlass_cppgen.DataType
    :param verbose: whether to log warnings
    :type verbose: bool

    :return: tuple with the first element indicating whether the provided tile description is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    if kernel_cc in [90, 100, 101, 103]:
        if (td.stages is None or td.stages == 0):
            # A stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
            # determines the stage count to use. Thus, all settings are valid in these scenarios.
            return (True, "")
        elif verbose:
            cutlass_cppgen.logger.warning(
                "Setting an explicit stage count for SM90 kernels currently may "
                "result in compilation errors if the combination of tile shape, "
                "stage count, and shared memory requirement of the epilogue exceeds "
                "the available shared memory per SM.")

    if td.stages <= 0:
        return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")

    if cc < 80 and td.stages != 2:
        return (False, f"Tile description has stage count of {td.stages}, "
                       f"but only 2 stages are supported on SM{cc}.")

    # The calculation below does not consider shared memory used by the epilogue and, thus,
    # only catches cases in which the mainloop exceeds the device's shared memory capacity.
    # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
    # mainloop and epilogue is shared.
    smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
    smem_usage_mainloop = (smem_per_stage * td.stages)
    smem_arch = SharedMemPerCC[cc] << 10
    if smem_usage_mainloop > smem_arch:
        return (False,
                "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
                f"Details:\n"
                f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
                f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
                f"The maximum amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")

    return (True, "")


def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
    """
    Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.

    :param cc: compute capability of device in question
    :type cc: int
    :param cluster_shape: dimensions of thread block cluster shape to check
    :type cluster_shape: list

    :return: tuple with the first element indicating whether the provided cluster shape is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """

    if cc < 90 or cc in [120, 121]:
        if cluster_shape != [1, 1, 1]:
            return (False,
                    f"Cluster shape for pre-SM90 architectures and SM120 and SM121 must be [1, 1, 1]. Received cluster shape of "
                    f"{cluster_shape} for SM{cc}.")
        else:
            return (True, "")

    if len(cluster_shape) != 3:
        return (False,
                f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)})")

    if cluster_shape[2] != 1:
        return (False,
                "CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
                f"Received cluster shape of {cluster_shape}.")

    return (True, "")
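
# Sketch of the validation contract above (return values follow the docstring):
ok, _ = valid_cluster_shape(90, [2, 1, 1])       # multi-CTA clusters are allowed on SM90
bad, msg = valid_cluster_shape(80, [2, 1, 1])    # pre-SM90 must use [1, 1, 1]
assert ok and not bad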

def valid_schedule(
        cc: int,
        kernel_schedule: cutlass_cppgen.KernelScheduleType,
        epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
        tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
    """
    Checks that the kernel and epilogue schedules passed in are a valid combination for
    a device of compute capability ``cc``.

    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_schedule: kernel schedule type
    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
    :param epilogue_schedule: epilogue schedule type
    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
    :param tile_scheduler: tile scheduler type
    :type tile_scheduler: cutlass_cppgen.TileSchedulerType

    :return: tuple with the first element indicating whether the provided schedules are
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
    epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
    tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
    if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default):
        return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)")

    if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)):
        return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")

    if not tile_scheduler_default:
        cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
                               cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
        if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
            return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
    return (True, "")


def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
    """
    Returns `alignment_provided` if it is set (checking that it does not exceed
    `default_alignment`); otherwise, returns `default_alignment`.

    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None
    :type default_alignment: int

    :return: alignment to use
    :rtype: int
    """
    if alignment_provided is not None:
        if alignment_provided > default_alignment:
            raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
        return alignment_provided

    return default_alignment


def update_alignment(alignment_provided: int, default_alignment: int) -> int:
    """
    Returns `alignment_provided` if it is set and does not exceed `default_alignment`;
    otherwise, returns `default_alignment`. If `alignment_provided` exceeds
    `default_alignment` but is a multiple of it, `default_alignment` is returned;
    any other excess raises an exception.

    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None
    :type default_alignment: int

    :return: alignment to use
    :rtype: int
    """
    if alignment_provided is not None:
        if alignment_provided > default_alignment:
            if alignment_provided % default_alignment == 0:
                return default_alignment
            raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
        return alignment_provided

    return default_alignment
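
# A small sketch contrasting the two helpers above:
assert alignment_or_default(None, 8) == 8   # fall back to the default
assert alignment_or_default(4, 8) == 4      # smaller alignments pass through
assert update_alignment(16, 8) == 8         # 16 is a multiple of 8, so clamp to 8
# alignment_or_default(16, 8) would raise, since 16 exceeds the maximum of 8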
362
python/cutlass_cppgen/utils/datatypes.py
Normal file
@@ -0,0 +1,362 @@
#################################################################################################
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text omitted; identical to the license header shown earlier.)
#################################################################################################

"""
|
||||
Utility functions for converting between frontend datatypes and CUTLASS datatypes
|
||||
"""
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_library import (
|
||||
DataTypeSize,
|
||||
MathOperation,
|
||||
MathInstruction
|
||||
)
|
||||
from cutlass_cppgen.backend.library import (
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
bfloat16_available = None
|
||||
cupy_available = None
|
||||
numpy_available = None
|
||||
torch_available = None
|
||||
_library_to_cupy_dict = None
|
||||
_library_to_numpy_dict = None
|
||||
_library_to_torch_dict = None
|
||||
_torch_to_library_dict = None
|
||||
|
||||
|
||||
def is_numpy_available():
    global numpy_available, _library_to_numpy_dict
    if numpy_available is None:
        try:
            import numpy as np

            numpy_available = True
            _library_to_numpy_dict = {
                cutlass_cppgen.DataType.f16: np.float16,
                cutlass_cppgen.DataType.f32: np.float32,
                cutlass_cppgen.DataType.f64: np.float64,
                cutlass_cppgen.DataType.s8: np.int8,
                cutlass_cppgen.DataType.s32: np.int32,
            }
        except ImportError:
            numpy_available = False
            _library_to_numpy_dict = {}
    return numpy_available


def is_numpy_tensor(inp) -> bool:
    if is_numpy_available():
        import numpy as np
        return isinstance(inp, np.ndarray)
    return False


def numpy_library_type(inp) -> cutlass_cppgen.DataType:
    if is_numpy_available():
        import numpy as np
        if inp == np.float16:
            return cutlass_cppgen.DataType.f16
        elif inp == np.float32:
            return cutlass_cppgen.DataType.f32
        elif inp == np.float64:
            return cutlass_cppgen.DataType.f64
        elif inp == np.int8:
            return cutlass_cppgen.DataType.s8
        elif inp == np.int32:
            return cutlass_cppgen.DataType.s32
    return None


def numpy_type(inp):
    # Ensure the lazy availability check has populated the mapping before use
    if not is_numpy_available():
        return None
    return _library_to_numpy_dict.get(inp, None)


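A quick round-trip through the numpy mapping, assuming numpy is installed (a sketch, not part of the module):

import numpy as np
import cutlass_cppgen
from cutlass_cppgen.utils import datatypes

assert datatypes.numpy_library_type(np.float32) == cutlass_cppgen.DataType.f32
assert datatypes.numpy_type(cutlass_cppgen.DataType.f32) == np.float32
assert datatypes.is_numpy_tensor(np.zeros((4, 4), dtype=np.float16))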
def is_cupy_available():
    # _library_to_cupy_dict must be declared global here; otherwise the
    # assignment below would only bind a function-local name.
    global cupy_available, _library_to_cupy_dict
    if cupy_available is None:
        try:
            import cupy as cp

            cupy_available = True
            _library_to_cupy_dict = {
                cutlass_cppgen.DataType.f16: cp.float16,
                cutlass_cppgen.DataType.f32: cp.float32,
                cutlass_cppgen.DataType.f64: cp.float64,
                cutlass_cppgen.DataType.s8: cp.int8,
                cutlass_cppgen.DataType.s32: cp.int32,
            }
        except ImportError:
            cupy_available = False
            _library_to_cupy_dict = {}
    return cupy_available


def is_cupy_tensor(inp) -> bool:
    if is_cupy_available():
        import cupy as cp
        return isinstance(inp, cp.ndarray)
    return False


def cupy_library_type(inp) -> cutlass_cppgen.DataType:
    if is_cupy_available():
        import cupy as cp
        if inp == cp.float16:
            return cutlass_cppgen.DataType.f16
        elif inp == cp.float32:
            return cutlass_cppgen.DataType.f32
        elif inp == cp.float64:
            return cutlass_cppgen.DataType.f64
    return None


def cupy_type(inp):
    # Ensure the lazy availability check has populated the mapping before use
    if not is_cupy_available():
        return None
    return _library_to_cupy_dict.get(inp, None)


def is_torch_available():
    global torch_available, _library_to_torch_dict, _torch_to_library_dict
    if torch_available is None:
        try:
            import torch

            torch_available = True
            # torch.half, torch.float, and torch.double are aliases of
            # torch.float16, torch.float32, and torch.float64, respectively,
            # so a single entry per datatype suffices.
            _torch_to_library_dict = {
                torch.float16: cutlass_cppgen.DataType.f16,
                torch.bfloat16: cutlass_cppgen.DataType.bf16,
                torch.float32: cutlass_cppgen.DataType.f32,
                torch.float64: cutlass_cppgen.DataType.f64,
                torch.int8: cutlass_cppgen.DataType.s8,
                torch.int32: cutlass_cppgen.DataType.s32,
                torch.uint8: cutlass_cppgen.DataType.u8,
            }

            _library_to_torch_dict = {
                cutlass_cppgen.DataType.f16: torch.float16,
                cutlass_cppgen.DataType.bf16: torch.bfloat16,
                cutlass_cppgen.DataType.f32: torch.float32,
                cutlass_cppgen.DataType.f64: torch.float64,
                cutlass_cppgen.DataType.s8: torch.int8,
                cutlass_cppgen.DataType.s32: torch.int32,
                cutlass_cppgen.DataType.u8: torch.uint8,
            }

            def possibly_add_type(torch_type_name, cutlass_type):
                # Only try adding the type if the version of torch being used supports it
                if hasattr(torch, torch_type_name):
                    torch_type = getattr(torch, torch_type_name)
                    _torch_to_library_dict[torch_type] = cutlass_type
                    _library_to_torch_dict[cutlass_type] = torch_type

            possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
            possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)

        except ImportError:
            torch_available = False
            _torch_to_library_dict = {}
            _library_to_torch_dict = {}
    return torch_available


def is_torch_tensor(inp) -> bool:
    if is_torch_available():
        import torch
        return isinstance(inp, torch.Tensor)
    return False


def torch_library_type(inp) -> cutlass_cppgen.DataType:
    # Ensure the lazy availability check has populated the mappings before use
    if not is_torch_available():
        return None
    return _torch_to_library_dict.get(inp, None)


def torch_type(inp):
    if not is_torch_available():
        return None
    return _library_to_torch_dict.get(inp, None)


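A small sanity check of the torch mappings above, assuming PyTorch is installed (illustrative sketch only):

import torch
import cutlass_cppgen
from cutlass_cppgen.utils import datatypes

# torch.half aliases torch.float16, so both map to DataType.f16
assert datatypes.torch_library_type(torch.half) == cutlass_cppgen.DataType.f16
assert datatypes.torch_type(cutlass_cppgen.DataType.bf16) is torch.bfloat16
# float8 entries are present only if this torch build defines them
if hasattr(torch, "float8_e4m3fn"):
    assert datatypes.torch_library_type(torch.float8_e4m3fn) == cutlass_cppgen.DataType.e4m3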
def is_bfloat16_available():
    global bfloat16_available

    if bfloat16_available is None:
        try:
            import bfloat16

            bfloat16_available = True
        except ImportError:
            bfloat16_available = False
    return bfloat16_available


def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
    if is_bfloat16_available():
        import bfloat16
        if inp == bfloat16.bfloat16:
            return cutlass_cppgen.DataType.bf16
    return None


def bfloat16_type(inp):
    if is_bfloat16_available():
        import bfloat16
        if inp == cutlass_cppgen.DataType.bf16:
            return bfloat16.bfloat16
    return None


def library_type(inp):
    if inp in DataTypeSize:
        return inp

    for cvt_fn in [
        bfloat16_library_type,
        cupy_library_type,
        numpy_library_type,
        torch_library_type,
    ]:
        out = cvt_fn(inp)
        if out is not None:
            return out

    raise Exception(f"No available conversion from type {inp} to a library type.")


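For example, assuming numpy is installed, `library_type` normalizes either a CUTLASS enum or a frontend dtype to the same enum value (a sketch):

import numpy as np
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import library_type

# Already a library type: returned unchanged
assert library_type(cutlass_cppgen.DataType.f32) == cutlass_cppgen.DataType.f32
# A frontend dtype: converted via the numpy mapping
assert library_type(np.float32) == cutlass_cppgen.DataType.f32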
def _tensor_from_numpy(np_tensor):
    dtype = library_type(np_tensor.dtype)
    if np_tensor.flags.c_contiguous:
        layout = cutlass_cppgen.LayoutType.RowMajor
    elif np_tensor.flags.f_contiguous:
        layout = cutlass_cppgen.LayoutType.ColumnMajor
    else:
        # Avoid referencing an unbound `layout` for non-contiguous tensors
        raise Exception("Unable to infer layout: tensor is neither C- nor F-contiguous.")
    return (dtype, layout)


def _tensor_from_torch(pt_tensor):
    dtype = library_type(pt_tensor.dtype)
    return (dtype, cutlass_cppgen.LayoutType.RowMajor)


def get_datatype_and_layout(tensor):
    if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
        return _tensor_from_numpy(tensor)
    elif is_torch_tensor(tensor):
        return _tensor_from_torch(tensor)
    elif isinstance(tensor, (float, int)):
        return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
    else:
        raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")


def get_tensor_shape(tensor, op="GEMM"):
    if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
        return tensor.shape
    elif is_torch_tensor(tensor):
        size = tensor.size()
        if op == "CONV":
            # PyTorch tensors have shape NCHW; report NHWC for convolutions
            return (size[0], size[2], size[3], size[1])
        else:
            return tuple(tensor.size())
    elif isinstance(tensor, (float, int)):
        return (1,)
    else:
        raise Exception(f"Unable to convert tensor of type {type(tensor)} to a tensor shape.")


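For instance, assuming torch is installed, a 4-D activation tensor is reported in NHWC order for convolutions (a sketch):

import torch
from cutlass_cppgen.utils.datatypes import get_tensor_shape

x = torch.empty(8, 3, 32, 32)                            # NCHW
assert get_tensor_shape(x) == (8, 3, 32, 32)
assert get_tensor_shape(x, op="CONV") == (8, 32, 32, 3)  # NHWC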
_math_operation_value_map = {x.value: x for x in MathOperation}


def backend_math_operation(math_op: MathOperation):
    if math_op.value not in _math_operation_value_map:
        raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
    return _math_operation_value_map[math_op.value]


def construct_backend_td(td: cutlass_cppgen.TileDescription,
                         kernel_schedule: cutlass_cppgen.KernelScheduleType,
                         epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
                         tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
    mi = td.math_instruction
    backend_mi = MathInstruction(
        mi.instruction_shape,
        mi.element_a,
        mi.element_b,
        mi.element_accumulator,
        mi.opcode_class,
        backend_math_operation(mi.math_operation)
    )
    cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
    return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
                           backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)


def td_from_profiler_op(op) -> TileDescription:
    """
    Converts the profiler's TileDescription in ``op`` into the backend TileDescription

    :param op: profiler Operation

    :returns: backend TileDescription
    :rtype: cutlass_cppgen.backend.TileDescription
    """
    kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
    eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
    tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
    return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)


def td_from_profiler_td(td: TileDescription) -> TileDescription:
    """
    Converts the profiler's TileDescription into the backend TileDescription

    :param td: profiler TileDescription
    :type td: cutlass_cppgen.TileDescription

    :returns: backend TileDescription
    :rtype: cutlass_cppgen.backend.TileDescription
    """
    return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)


def to_camel_case(snake_str):
    return "".join(x.capitalize() for x in snake_str.lower().split("_"))


def getattr_enum(obj, attr_name):
    # attr_name is given in snake_case; enum members are named in CamelCase
    camel_attr = to_camel_case(attr_name)
    if hasattr(obj, camel_attr):
        return getattr(obj, camel_attr)
    else:
        raise Exception(f"Invalid option: {attr_name}")
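For example (a sketch; `KernelScheduleType.TmaWarpSpecialized` is used purely as an illustrative member name):

from cutlass_cppgen.utils.datatypes import to_camel_case, getattr_enum

assert to_camel_case("tma_warp_specialized") == "TmaWarpSpecialized"
# Resolves a snake_case option string against a CamelCase enum, e.g.:
# getattr_enum(KernelScheduleType, "tma_warp_specialized")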
41
python/cutlass_cppgen/utils/lazy_import.py
Normal file
@ -0,0 +1,41 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import importlib
from typing import Any

def lazy_import(mod_name: str) -> Any:
    class Lazy:
        def __getattr__(self, name: str) -> Any:
            module = importlib.import_module(mod_name)
            return getattr(module, name)

    return Lazy()
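The returned proxy defers the import until the first attribute access, so merely importing a module that uses `lazy_import` does not require the target package to be installed. A usage sketch:

from cutlass_cppgen.utils.lazy_import import lazy_import

cuda = lazy_import("cuda.cuda")   # no import happens yet
# `cuda.cuda` is imported only when an attribute is first touched, e.g.:
# err, version = cuda.cuDriverGetVersion()

Note that each attribute access calls importlib.import_module again; Python caches modules in sys.modules, so repeat lookups are cheap.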
196
python/cutlass_cppgen/utils/profiler.py
Normal file
@ -0,0 +1,196 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
|
||||
Profiler based on the cuda events
|
||||
"""
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
import numpy as np
|
||||
|
||||
from cutlass_cppgen import CUTLASS_PATH
|
||||
from cutlass_cppgen.backend.library import DataTypeSize
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
||||
|
||||
|
||||
class GpuTimer:
    def __init__(self) -> None:
        self.events = [
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
        ]

    def start(self, stream=None):
        if not stream:
            stream = cuda.CUstream(0)

        (err,) = cuda.cuEventRecord(self.events[0], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

    def stop(self, stream=None):
        if not stream:
            stream = cuda.CUstream(0)

        (err,) = cuda.cuEventRecord(self.events[1], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

    def stop_and_wait(self, stream=None):
        # Record the stop event, then wait for it. Do not substitute the default
        # stream here: a None stream should fall through to the full device
        # synchronization below.
        self.stop(stream)
        if stream:
            (err,) = cuda.cuStreamSynchronize(stream)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error {str(err)}")
        else:
            (err,) = cudart.cudaDeviceSynchronize()
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError(f"CUDA Error {str(err)}")

    def duration(self, iterations=1):
        # Elapsed time between the start and stop events, in milliseconds
        err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")
        return duration / float(iterations)


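Timing a region with the timer above follows a start/launch/stop pattern; a minimal sketch, assuming `launch_kernel()` stands in for any asynchronous GPU work on the default stream:

timer = GpuTimer()
timer.start()
for _ in range(100):
    launch_kernel()           # hypothetical asynchronous launch
timer.stop_and_wait()
avg_ms = timer.duration(100)  # average milliseconds per iteration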
class CUDAEventProfiler:
    def __init__(self, op: OperationBase, warmup_iterations: int = 500, iterations: int = 500, *args, **kwargs) -> None:
        self.arguments = op.run(*args, **kwargs)
        self.operation = op.operation
        self.warmup_iterations = warmup_iterations
        self.iterations = iterations
        self.timer = GpuTimer()

    #
    # CUTLASS Python Interface Profiler
    #

    def __call__(self):
        for _ in range(self.warmup_iterations):
            self.operation.run(self.arguments)

        self.timer.start()
        for _ in range(self.iterations):
            self.operation.run(self.arguments)

        self.timer.stop_and_wait()
        runtime = self.timer.duration(self.iterations)
        return runtime

    #
    # CUTLASS Profiler
    #

    def run_cutlass_profiler(self):
        alpha = 1.0
        beta = 1.0

        profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
        kernel_name = self.operation.procedural_name()
        verification_providers = "device"
        provider = "cutlass"
        problem_size = self.arguments.problem_size

        if "cutlass3x" in kernel_name:
            # The cutlass3x generator only emits kernels with column-major output.
            # If this kernel's output is row-major, profile the transposed problem
            # instead: swap M and N and use the all-column-major kernel name.
            layout_name = self.operation.layout_name_3x()
            if layout_name[-1] == "t":
                new_layout_name = "n" * len(layout_name)
                problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
                kernel_name = kernel_name.replace(layout_name, new_layout_name)

        batch_count = self.arguments.batch_count

        cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
              f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
              f"--batch_count={batch_count} --alpha={alpha} --beta={beta} " \
              f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"

        result = subprocess.getoutput(cmd)

        m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result)
        runtime = float(m.group("runtime"))

        m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
        bytes = int(m.group("bytes"))

        m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
        flops = int(m.group("flops"))

        # Check that the profiler ran the problem size we expect
        assert bytes == self.bytes(problem_size, batch_count, beta)
        assert flops == self.flops(problem_size, batch_count, beta)

        return runtime

    def bytes(self, problem_size, batch_count=1, beta=0.0):
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()

        # Bytes read for A (m x k) and B (k x n), plus bytes written for D (m x n)
        bytes = (
            (DataTypeSize[self.operation.A.element] * m // 8) * k
            + (DataTypeSize[self.operation.B.element] * n // 8) * k
            + (DataTypeSize[self.operation.C.element] * m // 8) * n
        )

        # A nonzero beta additionally reads the source C (m x n)
        if beta != 0:
            bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n

        bytes *= batch_count

        return bytes

    def flops(self, problem_size, batch_count=1, beta=0.0):
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()

        # 2 * m * n * k FLOPs (multiply + add) for the GEMM itself
        flops_ = (m * n * k) * 2 * batch_count

        # A nonzero beta adds an m x n multiply-add in the epilogue
        if beta != 0:
            flops_ += m * n * batch_count * 2

        return flops_
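As a concrete check of the arithmetic above, a worked example (assuming fp16 A/B/C, m = n = k = 1024, beta = 0, batch_count = 1):

# bytes = 3 * (2 bytes * 1024 * 1024) = 6,291,456   (A and B read, D written)
# flops = 2 * 1024 * 1024 * 1024      = 2,147,483,648
# These are the values the asserts in run_cutlass_profiler compare against.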