CUTLASS 3.1 (#915)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2023-04-14 20:19:34 -07:00
committed by GitHub
parent 9b8166e3f0
commit d572cc1aab
482 changed files with 37184 additions and 16419 deletions

180
python/README.md Normal file

@ -0,0 +1,180 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS Python Interface
The CUTLASS Python interface enables one to compile and run CUTLASS operations from within Python.
```python
import cutlass
import numpy as np
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
A, B, C, D = [np.ones((4096, 4096), dtype=np.float16) for i in range(4)]
plan.run(A, B, C, D)
```
**NOTE** The CUTLASS Python interface is currently an experimental release. The API may change in the future.
We welcome feedback from the community.
## Overview
The CUTLASS Python interface aims to provide an easy-to-use interface for using CUTLASS via Python. Toward this goal,
the CUTLASS Python interface attempts to:
* Present high-level interfaces for operators that require only a few parameters
* Select sensible default configurations for an operator given the parameters that have been specified
* Enumerate configurations for users that are known to work in a given setting
* Reduce the occurrence of C++ compile-time errors in favor of descriptive Python exceptions
* Make it easy to export CUTLASS kernels to framework extensions (e.g., PyTorch CUDA extensions)
### Non-goals
The CUTLASS Python interface does not intend to:
**Select optimal kernel configurations.**
As an ease-of-use interface, the default selections for operator parameters made by the CUTLASS Python interface may
not achieve the highest possible performance in all scenarios. Users wishing to achieve the highest performance possible
should consider profiling different combinations of configuration parameters, or use a library such as [cuBLAS](https://developer.nvidia.com/cublas)
that contains heuristics for selecting kernels.
**Act as a fast container for CUTLASS kernels.**
The CUTLASS Python interface does not strive to minimize overhead in its Python functions surrounding the running of a kernel.
Those wishing to deploy a CUTLASS kernel should consider either using the C++ emitted by the Python interface directly, or using
one of the CUTLASS emitters for automatically creating a framework extension for the kernel (e.g., a PyTorch CUDA extension).
**Act as a Python-to-CUDA-kernel JIT compilation engine.**
The CUTLASS Python interface intends to enable one to use CUTLASS via Python. It can be used by frameworks for JIT compiling
Python to CUDA kernels, but does not set out to be such a framework.
### Comparison to PyCUTLASS
The CUTLASS Python interface builds atop CUTLASS's [PyCUTLASS](https://github.com/NVIDIA/cutlass/tree/v3.0.0/tools/library/scripts/pycutlass) library. PyCUTLASS enables
one to declare, compile, and run GEMMs, convolutions, and grouped GEMM operators with nearly the same configuration
space as CUTLASS's C++ interface. While this flexibility enables one to achieve similar levels of functionality
as available in CUTLASS's C++ interface, it comes with the burden of needing to specify many configuration parameters
to operators -- similar to what one must do in specifying template parameters to operations in CUTLASS's C++ interface.
In contrast, the CUTLASS Python interface aims to provide a higher-level API for declaring, emitting, and compiling
kernels that does not require exhaustively defining template parameters.
#### Transitioning from PyCUTLASS
At present, existing PyCUTLASS functionality remains available via the CUTLASS Python interface. One can
continue to use PyCUTLASS by replacing references to the PyCUTLASS `cutlass` module with `cutlass_bindings`
and the PyCUTLASS `pycutlass` module with `cutlass.backend`.
For example, the following code using PyCUTLASS:
```python
import pycutlass
import cutlass
math_inst = pycutlass.MathInstruction(
[1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32,
cutlass.OpClass.Simt, pycutlass.MathOperation.multiply_add
)
```
can work with the Python interface via:
```python
import cutlass.backend as pycutlass
import cutlass_bindings
math_inst = pycutlass.MathInstruction(
[1, 1, 1], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32,
cutlass_bindings.OpClass.Simt, pycutlass.MathOperation.multiply_add
)
```
**NOTE:** backwards compatibility of `cutlass.backend` with `pycutlass` will not be maintained moving forward.
## Current functionality
The CUTLASS Python interface currently supports the following operations:
* GEMMs
* GEMMs with fused elementwise epilogues (e.g., ReLU) (for pre-SM90 kernels)
* Stream K swizzling (for pre-SM90 kernels)
* Grouped GEMM (for pre-SM90 kernels)
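As a rough reference for what these operations compute, the sketch below uses plain NumPy. It assumes the conventional GEMM form `D = alpha * (A @ B) + beta * C`, with an optional elementwise ReLU applied to the result to stand in for a fused epilogue. The function name `reference_gemm` and its parameters are illustrative only and are not part of the CUTLASS Python interface.

```python
import numpy as np


def reference_gemm(A, B, C, alpha=1.0, beta=0.0, relu=False):
    """NumPy reference (assumed form): D = alpha * (A @ B) + beta * C, optionally followed by ReLU."""
    # Accumulate in float32 to limit rounding error for float16 inputs
    acc = A.astype(np.float32) @ B.astype(np.float32)
    D = alpha * acc + beta * C.astype(np.float32)
    if relu:
        # Elementwise epilogue applied to the linear combination
        D = np.maximum(D, 0.0)
    return D.astype(C.dtype)


A = np.ones((4, 8), dtype=np.float16)
B = -np.ones((8, 4), dtype=np.float16)
C = np.ones((4, 4), dtype=np.float16)

D = reference_gemm(A, B, C, alpha=1.0, beta=1.0)
D_relu = reference_gemm(A, B, C, alpha=1.0, beta=1.0, relu=True)
```

A reference like this can be useful for checking the output of a compiled kernel on small problem sizes, for example with `np.allclose`.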
## Getting started
We recommend using the CUTLASS Python interface via one of the Docker images located in the [docker](/python/docker) directory.
```bash
docker build -t cutlass-cuda12.0:latest -f docker/Dockerfile-cuda12.0-pytorch .
docker run --gpus all -it --rm cutlass-cuda12.0:latest
```
The CUTLASS Python interface has been tested with CUDA 11.8 and CUDA 12.0 on Python 3.8.10 and 3.9.7.
### Optional environment variables
Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables:
* `CUTLASS_PATH`: the path to the cloned CUTLASS repository
* `CUDA_INSTALL_PATH`: the path to the installation of CUDA
If these environment variables are not set, the installation process will infer them to be the following:
* `CUTLASS_PATH`: one directory level above the current directory (i.e., `$(pwd)/..`)
* `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`)
**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`.
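The inference rules above can be sketched in a few lines of Python. This is only an illustration of the described defaults, not the code the installer runs: the helper names `infer_cutlass_path` and `infer_cuda_install_path` are hypothetical, and `shutil.which` is used here in place of the shell pipeline.

```python
import os
import shutil


def infer_cutlass_path(cwd=None):
    """Return CUTLASS_PATH if set, else one directory level above `cwd`."""
    env = os.environ.get("CUTLASS_PATH")
    if env:
        return env
    cwd = os.getcwd() if cwd is None else cwd
    return os.path.dirname(os.path.abspath(cwd))


def infer_cuda_install_path():
    """Return CUDA_INSTALL_PATH if set, else the directory holding bin/nvcc for the first nvcc on PATH."""
    env = os.environ.get("CUDA_INSTALL_PATH")
    if env:
        return env
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        # nvcc is not on PATH; the caller would need to set CUDA_INSTALL_PATH explicitly
        return None
    # <prefix>/bin/nvcc -> <prefix>
    return os.path.dirname(os.path.dirname(nvcc))


print("CUTLASS_PATH:", infer_cutlass_path())
print("CUDA_INSTALL_PATH:", infer_cuda_install_path())
```

Printing both values before installing is a quick way to confirm that the inferred CUDA installation is the one `cuda-python` was built against.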
### Installation
The CUTLASS Python interface can currently be installed via:
```bash
python setup.py develop --user
```
This will allow changes to the Python interface source to be reflected when using the Python interface.
We plan to add support for installing via `python setup.py install` in a future release.
## Examples
Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python).
To launch these notebooks from this directory, run:
```bash
jupyter-lab ../examples/python
```
## Building documentation
The CUTLASS Python interface uses [Sphinx](https://www.sphinx-doc.org/en/master/) for documentation.
Building the documentation requires additional packages. These can be installed via:
```bash
sudo apt-get install pandoc
pip install --upgrade Sphinx furo pandoc myst-parser sphinx-copybutton nbsphinx nbsphinx-link sphinx-inline-tabs
```
To build documentation, you must first have installed the CUTLASS Python interface via the
[installation instructions](#installation).
Documentation can then be built via the following commands:
```bash
sphinx-apidoc -o docs_src/source/ cutlass/ cutlass/backend*
cd docs_src
make html
mv _build/* ../docs
```
# Copyright
Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

117
python/cutlass/__init__.py Normal file

@ -0,0 +1,117 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import logging
import os
import sys
def _cutlass_path_from_dir() -> str:
    cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../')
    if not os.path.isdir(cutlass_path):
        raise Exception(f'Environment variable "CUTLASS_PATH" is not defined, '
                        f'and default path of {cutlass_path} does not exist.')
    return cutlass_path


def _cuda_install_path_from_nvcc() -> str:
    import subprocess
    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
    result = subprocess.run(['which', 'nvcc'], capture_output=True)
    if result.returncode != 0:
        raise Exception('Unable to find nvcc via `which` utility.')

    cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0]
    if not os.path.isdir(cuda_install_path):
        raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, '
                        f'and default path of {cuda_install_path} does not exist.')
    return cuda_install_path


# Run the fallbacks only when the variable is unset or empty. Passing fallback() as the
# default of os.getenv would evaluate it eagerly and could raise even when the variable is defined.
CUTLASS_PATH = os.getenv("CUTLASS_PATH") or _cutlass_path_from_dir()
CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH") or _cuda_install_path_from_nvcc()
CACHE_FILE = "compiled_cache.db"
# Add the path to the CUTLASS profiler generation/manifest scripts to PYTHONPATH
sys.path.insert(0, os.path.join(CUTLASS_PATH, "tools/library/scripts/"))
# Import types/methods from the CUTLASS utility libraries for profiler generation/emission under
from library import (
ArchitectureNames,
DataType,
DataTypeSize,
EpilogueFunctor,
GemmKind,
LayoutTag,
LayoutType,
KernelScheduleSuffixes,
KernelScheduleType,
KernelScheduleTag,
MathInstruction,
MathOperation,
OpcodeClass,
OperationKind,
SharedMemPerCC,
SwizzlingFunctor,
TensorDescription,
TileDescription,
)
this = sys.modules[__name__]
this.logger = logging.getLogger(__name__)
def set_log_level(level: int):
    """
    Sets the log level

    :param level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options
    :type level: int
    """
    this.logger.setLevel(level)
set_log_level(logging.ERROR)
from cutlass.library_defaults import OptionRegistry
from cutlass.backend.utils.device import device_cc
this.option_registry = OptionRegistry(device_cc())
this.__version__ = '3.1.0'
from cutlass.backend import get_memory_pool
from cutlass.emit.pytorch import pytorch
from cutlass.op.gemm import Gemm
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
get_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32)


@ -0,0 +1,27 @@
# module-wide variables
import os
from cutlass.backend.arguments import *
from cutlass.backend.c_types import *
from cutlass.backend.compiler import ArtifactManager
from cutlass.backend.conv2d_operation import *
from cutlass.backend.epilogue import *
from cutlass.backend.frontend import *
from cutlass.backend.gemm_operation import *
from cutlass.backend.library import *
from cutlass.backend.memory_manager import PoolMemoryManager
from cutlass.backend.operation import *
from cutlass.backend.parser import *
from cutlass.backend.reduction_operation import *
from cutlass.backend.tensor_ref import *
from cutlass.backend.type_hint import *
from cutlass.backend.utils import *
from cutlass.backend.utils.device import device_cc
from cutlass.backend.utils.software import (
CheckPackages,
SubstituteTemplate,
device_sm_count,
get_memory_pool,
)
compiler = ArtifactManager()


@ -0,0 +1,119 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from typing import Union
from cuda import cuda, cudart
import numpy as np
from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
from cutlass.backend.utils.software import CheckPackages
torch_available = CheckPackages().check_torch()
if torch_available:
import torch
cupy_available = CheckPackages().check_cupy()
if cupy_available:
import cupy as cp
class ArgumentBase:
    """
    Base class for operation arguments
    """

    def __init__(
        self,
        A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        **kwargs,
    ) -> None:
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        if "bias" in kwargs.keys():
            self.bias = kwargs["bias"]
        else:
            # by default, tensor_C is not bias
            self.bias = False

        # preprocessing input tensors
        if isinstance(A, np.ndarray):
            self.host_D = D
            self.buffer_A = NumpyFrontend.argument(A, False)
            self.buffer_B = NumpyFrontend.argument(B, False)
            self.buffer_C = NumpyFrontend.argument(C, False)
            self.buffer_D = NumpyFrontend.argument(D, True)
            self.ptr_A = self.buffer_A.ptr
            self.ptr_B = self.buffer_B.ptr
            self.ptr_C = self.buffer_C.ptr
            self.ptr_D = self.buffer_D.ptr
            # number of elements in C
            self.tensor_c_numel = C.size
        elif torch_available and isinstance(A, torch.Tensor):
            self.ptr_A = TorchFrontend.argument(A)
            self.ptr_B = TorchFrontend.argument(B)
            self.ptr_C = TorchFrontend.argument(C)
            self.ptr_D = TorchFrontend.argument(D)
            # number of elements in C
            self.tensor_c_numel = C.numel()
        elif isinstance(A, cuda.CUdeviceptr):
            self.ptr_A = A
            self.ptr_B = B
            self.ptr_C = C
            self.ptr_D = D
        elif cupy_available and isinstance(A, cp.ndarray):
            self.ptr_A = CupyFrontend.argument(A)
            self.ptr_B = CupyFrontend.argument(B)
            self.ptr_C = CupyFrontend.argument(C)
            self.ptr_D = CupyFrontend.argument(D)
            # number of elements in C
            self.tensor_c_numel = C.size
        else:
            raise TypeError(
                "Unsupported frontend. Supported inputs are numpy arrays, torch tensors, "
                "cupy arrays, and cuda.CUdeviceptr."
            )

    def sync(self, stream_sync=True):
        if stream_sync:
            (err,) = cudart.cudaDeviceSynchronize()
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

        # Copy the result back to the host only when the inputs were numpy arrays
        if hasattr(self, "host_D"):
            (err,) = cuda.cuMemcpyDtoH(
                self.host_D,
                self.ptr_D,
                self.host_D.size * self.host_D.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))


@ -0,0 +1,405 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
import cutlass_bindings
from cutlass import (
DataType,
KernelScheduleType
)
from cutlass.backend.library import DataTypeSizeBytes
class GemmCoord_(ctypes.Structure):
_fields_ = [
("m", ctypes.c_int),
("n", ctypes.c_int),
("k", ctypes.c_int)
]
def __init__(self, gemm_coord) -> None:
for field_name, _ in self._fields_:
setattr(self, field_name, getattr(gemm_coord, field_name)())
class GemmCoordBatched_(ctypes.Structure):
"""
Wrapper around a GemmCoord that also contains batch count. This is used for encoding
batched GEMM inputs to CUTLASS 3 GEMMs.
"""
_fields_ = [
("m", ctypes.c_int),
("n", ctypes.c_int),
("k", ctypes.c_int),
("batch_count", ctypes.c_int)
]
def __init__(self, gemm_coord, batch_count) -> None:
for field_name, _ in self._fields_[:-1]:
setattr(self, field_name, getattr(gemm_coord, field_name)())
setattr(self, "batch_count", batch_count)
class MatrixCoord_(ctypes.Structure):
_fields_ = [
("row", ctypes.c_int),
("column", ctypes.c_int)
]
class dim3_(ctypes.Structure):
_fields_ = [
("x", ctypes.c_int),
("y", ctypes.c_int),
("z", ctypes.c_int)
]
class StrideBatched_(ctypes.Structure):
"""
CUTLASS 3.0 strides for operands contain one static dimension and two variable dimensions. The
variable dimensions represent the stride along non-unit-stride dimension of the row/column major
layout, and the batch stride. This structure encodes the two variable dimensions.
"""
_fields_ = [
("major_stride", ctypes.c_int64),
("batch_stride", ctypes.c_int64)
]
dtype2ctype = {
cutlass_bindings.float16: ctypes.c_uint16,
cutlass_bindings.float32: ctypes.c_float,
cutlass_bindings.float64: ctypes.c_double,
cutlass_bindings.int32: ctypes.c_int32,
}
class GenericMainloopArguments3x_(ctypes.Structure):
"""
Structure representing the superset of possible mainloop arguments.
This structure should not be passed to kernels directly, but, rather,
be used as an input to one of the more specific schedule arguments, which
will each select those arguments relevant to the particular schedule.
"""
_fields_ = [
("ptr_A", ctypes.c_void_p),
("stride_A", StrideBatched_),
("ptr_B", ctypes.c_void_p),
("stride_B", StrideBatched_),
]
def get_mainloop_arguments_3x(
    kernel_schedule: KernelScheduleType,
    element_A,
    element_B,
    alignment_A: int,
    alignment_B: int) -> ctypes.Structure:
    """
    Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters.

    :param kernel_schedule: type of kernel schedule to be used in the mainloop
    :type kernel_schedule: cutlass.KernelScheduleType
    :param element_A: data type of operand A
    :param element_B: data type of operand B
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int

    :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters
    :rtype: ctypes.Structure
    """
    class _MainloopArgumentsTma(ctypes.Structure):
        _fields_ = [
            ("ptr_A", ctypes.c_void_p),
            ("stride_A", StrideBatched_),
            ("ptr_B", ctypes.c_void_p),
            ("stride_B", StrideBatched_),
        ]

        @staticmethod
        def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
            return _MainloopArgumentsTma(
                args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
            )

    class _MainloopArgumentsMultistage(ctypes.Structure):
        _fields_ = [
            ("ptr_A", ctypes.c_void_p),
            ("stride_A", StrideBatched_),
            ("ptr_B", ctypes.c_void_p),
            ("stride_B", StrideBatched_),
        ]

        @staticmethod
        def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
            return _MainloopArgumentsMultistage(
                args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
            )

    # An operand is TMA-aligned when (element size in bytes * alignment) is a multiple of 16 bytes
    tma_alignment_bytes = 16
    is_tma_aligned_A = ((DataTypeSizeBytes[element_A] * alignment_A) % tma_alignment_bytes) == 0
    is_tma_aligned_B = ((DataTypeSizeBytes[element_B] * alignment_B) % tma_alignment_bytes) == 0
    is_tma_aligned = is_tma_aligned_A and is_tma_aligned_B

    if kernel_schedule == KernelScheduleType.Multistage:
        return _MainloopArgumentsMultistage
    elif kernel_schedule == KernelScheduleType.ScheduleAuto:
        if is_tma_aligned:
            return _MainloopArgumentsTma
        else:
            return _MainloopArgumentsMultistage
    else:
        if is_tma_aligned:
            return _MainloopArgumentsTma
        else:
            raise Exception(f"Specified a kernel schedule using TMA ({kernel_schedule}), but "
                            "the provided data types and alignments are not properly aligned for "
                            "using TMA.")
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
("ptr_C", ctypes.c_void_p),
("stride_C", StrideBatched_),
("ptr_D", ctypes.c_void_p),
("stride_D", StrideBatched_),
]
class _GemmArguments(ctypes.Structure):
_fields_ = [
("mode", ctypes.c_int),
("problem_size", GemmCoordBatched_),
("mainloop", mainloop_arguments),
("epilogue", _EpilogueArguments)
]
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams
def get_gemm_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GemmArguments(ctypes.Structure):
_fields_ = [
# Arguments from UniversalArgumentsBase
("mode", ctypes.c_int),
("problem_size", GemmCoord_),
("batch_count", ctypes.c_int),
("batch_stride_D", ctypes.c_longlong),
# Remaining arguments
("epilogue", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("batch_stride_A", ctypes.c_longlong),
("batch_stride_B", ctypes.c_longlong),
("batch_stride_C", ctypes.c_longlong),
("stride_a", ctypes.c_longlong),
("stride_b", ctypes.c_longlong),
("stride_c", ctypes.c_longlong),
("stride_d", ctypes.c_longlong),
("lda", ctypes.c_longlong),
("ldb", ctypes.c_longlong),
("ldc", ctypes.c_longlong),
("ldd", ctypes.c_longlong),
("ptr_gather_A_indices", ctypes.c_void_p),
("ptr_gather_B_indices", ctypes.c_void_p),
("ptr_scatter_D_indices", ctypes.c_void_p)
]
return _GemmArguments, _EpilogueOutputOpParams
def get_gemm_arguments_streamk(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GemmArguments(ctypes.Structure):
_fields_ = [
("mode", ctypes.c_int),
("problem_size", GemmCoord_),
("batch_count", ctypes.c_int),
("epilogue", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("batch_stride_A", ctypes.c_longlong),
("batch_stride_B", ctypes.c_longlong),
("batch_stride_C", ctypes.c_longlong),
("batch_stride_D", ctypes.c_longlong),
("stride_a", ctypes.c_longlong),
("stride_b", ctypes.c_longlong),
("stride_c", ctypes.c_longlong),
("stride_d", ctypes.c_longlong),
("lda", ctypes.c_longlong),
("ldb", ctypes.c_longlong),
("ldc", ctypes.c_longlong),
("ldd", ctypes.c_longlong),
("avail_sms", ctypes.c_int)
]
return _GemmArguments, _EpilogueOutputOpParams
###########################################################################################
# GEMM Grouped
###########################################################################################
def get_gemm_grouped_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GEMMGroupedArguments(ctypes.Structure):
_fields_ = [
("problem_sizes", ctypes.c_void_p),
("problem_count", ctypes.c_int),
("threadblock_count", ctypes.c_int),
("output_op", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("lda", ctypes.c_void_p),
("ldb", ctypes.c_void_p),
("ldc", ctypes.c_void_p),
("ldd", ctypes.c_void_p),
("host_problem_sizes", ctypes.c_void_p)
]
return _GEMMGroupedArguments, _EpilogueOutputOpParams
############################################################################################
# Convolution2D
############################################################################################
class Conv2DProblemSize(ctypes.Structure):
_fields_ = [
("N", ctypes.c_int),
("H", ctypes.c_int),
("W", ctypes.c_int),
("C", ctypes.c_int),
("P", ctypes.c_int),
("Q", ctypes.c_int),
("K", ctypes.c_int),
("R", ctypes.c_int),
("S", ctypes.c_int),
("pad_h", ctypes.c_int),
("pad_w", ctypes.c_int),
("stride_h", ctypes.c_int),
("stride_w", ctypes.c_int),
("dilation_h", ctypes.c_int),
("dilation_w", ctypes.c_int),
("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1
("split_k_slices", ctypes.c_int),
("groups", ctypes.c_int)
]
def __init__(self, problem_size) -> None:
for field_name, _ in self._fields_:
setattr(self, field_name, getattr(problem_size, field_name))
class Layout4D(ctypes.Structure):
_fields_ = [("stride", ctypes.c_int * 3)]
def __init__(self, tensor_ref):
stride = tensor_ref.stride()
setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2)))
class TensorRef_(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("layout", Layout4D)
]
def __init__(self, tensor_ref):
setattr(self, "ptr", tensor_ref.data())
setattr(self, "layout", Layout4D(tensor_ref.layout()))
class TensorRef2D_(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("stride", ctypes.c_int)
]
def get_conv2d_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _Conv2dArguments(ctypes.Structure):
_fields_ = [
("problem_size", Conv2DProblemSize),
("ref_A", TensorRef_),
("ref_B", TensorRef_),
("ref_C", TensorRef_),
("ref_D", TensorRef_),
("output_op", _EpilogueOutputOpParams),
("split_k_mode", ctypes.c_int)
]
return _Conv2dArguments, _EpilogueOutputOpParams
############################################################################################
# Reduction
############################################################################################
def get_reduction_params(epilogue_functor):
_EpilogueOutputParams = epilogue_functor.epilogue_type
class _ReductionParams(ctypes.Structure):
_fields_ = [
("problem_size", MatrixCoord_),
("partitions", ctypes.c_int),
("partition_stride", ctypes.c_longlong),
("workspace", TensorRef2D_),
("destination", TensorRef2D_),
("source", TensorRef2D_),
("output_op", _EpilogueOutputParams),
]
return _ReductionParams, _EpilogueOutputParams


@ -0,0 +1,469 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes
import json
import os
import sqlite3
import tempfile
from cuda import cuda, nvrtc
import cutlass_bindings
from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH
from cutlass.backend.gemm_operation import GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.device import device_cc
from cutlass.backend.utils.software import SubstituteTemplate
IncludeTemplate = r"""#include "${include}"
"""
class CompilationOptions:
    """
    Compilation options, available as a command-line string (NVCC backend)
    or as a list of encoded options (NVRTC backend and host compilation).
    """

    def __init__(self, flags, arch, include_paths=None):
        self.includes = []
        # Avoid sharing one mutable default list across instances
        self.include_paths = include_paths if include_paths is not None else []
        self.flags = flags
        self.arch = arch

    def get_str(self):
        """Return all options joined into a single command-line string."""
        options = ""
        for flag in self.flags:
            options += " " + flag
        for incl in self.include_paths:
            options += " --include-path=%s" % incl
        arch_flag = " -arch=sm_%d" % self.arch
        if self.arch == 90:
            # SM90 is compiled for the sm_90a target
            arch_flag += "a"
        options += arch_flag
        return options

    def get(self):
        """Return all options as a list of encoded byte strings."""
        options = []
        for flag in self.flags:
            options.append(bytes(str.encode(flag)))
        for incl in self.include_paths:
            options.append(bytes(str.encode("--include-path=%s" % incl)))
        arch_flag = " -arch=sm_%d" % self.arch
        if self.arch == 90:
            arch_flag += "a"
        options.append(bytes(str.encode(arch_flag)))
        return options
def convertToBinaryData(filename):
    """Read a file and return its contents as bytes."""
    with open(filename, "rb") as file:
        blobData = file.read()
    return blobData


def CDLLBin(host_binary):
    """Write a host binary to a temporary shared object and load it with ctypes."""
    tempfile.tempdir = "./"
    temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True)
    with open(temp_so.name, "wb") as file:
        file.write(host_binary)
    host_lib = ctypes.CDLL(temp_so.name)
    return host_lib
class ArtifactManager:
"""
Artifact manager
"""
def __init__(self) -> None:
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
# Create the table if it does not already exist
sqlite_create_table_query = """
CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE,
cubin BLOB NOT NULL,
hostbin BLOB NOT NULL,
op_name TEXT NOT NULL,
op_attrs TEXT NOT NULL)
"""
cursor.execute(sqlite_create_table_query)
connection.commit()
cursor.close()
self.nvcc()
self.compiled_cache_device = cutlass_bindings.CompileCache()
self.compiled_cache_host = cutlass_bindings.CompileCache()
def nvrtc(self):
self.backend = "nvrtc"
self.default_compile_options = ["-std=c++17", "-default-device"]
def nvcc(self):
self.backend = "nvcc"
self.default_compile_options = [
"-std=c++17",
"--expt-relaxed-constexpr",
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""
hostbin = convertToBinaryData(hostfile)
data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))
cursor.execute(sqlite_insert_blob_query, data_tuple)
connection.commit()
cursor.close()
def load_operation(self, op_key, extra_funcs):
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
cursor.execute(sqlite_fetch_blob_query, (op_key,))
record = cursor.fetchall()
if len(record) == 0:
return False
for row in record:
key, cubin_image, host_binary, operation_name, op_attr = row
op_attr = json.loads(op_attr)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
self.compiled_cache_device.insert(key, kernel)
compiled_host_fns = {}
host_lib = CDLLBin(host_binary)
func_name = operation_name + "_get_params"
func = getattr(host_lib, func_name)
func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
compiled_host_fns["get_args"] = func
func_name = operation_name + "_shared_memory_size"
func = getattr(host_lib, func_name)
compiled_host_fns["shared_memory_capacity"] = func()
for attr in op_attr:
if isinstance(attr, str):
func_name = operation_name + "_" + attr
func = getattr(host_lib, func_name)
# Set the return type of the function
if attr in extra_funcs and extra_funcs[attr] is not None:
func.restype = extra_funcs[attr]
compiled_host_fns[attr] = func
self.compiled_cache_host.insert(key, compiled_host_fns)
return True
def emit_compile_(self, operation_list, compilation_options, requires_nvcc_hostlib_compilation):
"""
Compile a list of kernels and store them into database
"""
source_buffer_device = ""
source_buffer_host = ""
# 1. include
includes = []
for operation in operation_list:
for incl in operation.emitter.includes:
if incl not in includes:
includes.append(incl)
includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
for incl in includes:
source_buffer_device += SubstituteTemplate(
IncludeTemplate,
{"include": incl},
)
for incl in includes_host:
if "/device/" not in incl:
source_buffer_host += SubstituteTemplate(
IncludeTemplate,
{"include": incl},
)
# 2. Operations
for operation in operation_list:
source_buffer_device += operation.emit()
source_buffer_host += operation.emit()
values = {
"operation_name": operation.name(),
"operation_suffix": operation.emitter.operation_suffix,
}
source_buffer_device += SubstituteTemplate(
operation.KernelTemplate,
values,
)
source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)
if self.backend == "nvrtc":
# 3. compile
err, program = nvrtc.nvrtcCreateProgram(
str.encode(source_buffer_device),
bytes(str.encode("module.cu")),
0, [], [])
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
# Compile program
options = compilation_options.get()
err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
error_string = "NVRTC Error: {}\n".format(err)
# Get log from compilation
err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
log = b" " * logSize
err, = nvrtc.nvrtcGetProgramLog(program, log)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
raise RuntimeError(error_string + log.decode() + source_buffer_device)
# Get data from compilation
err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
cubin_image = b" " * dataSize
(err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("NVRTC Error: {}".format(err))
else: # with nvcc backend
# emit code
tempfile.tempdir = "./"
temp_cu = tempfile.NamedTemporaryFile(
prefix="kernel", suffix=".cu", delete=True)
temp_cubin = tempfile.NamedTemporaryFile(
prefix="kernel", suffix=".cubin", delete=True)
with open(temp_cu.name, "w") as file:
file.write(source_buffer_device)
# compile with nvcc
cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}"
values = {
"cuda_install_path": CUDA_INSTALL_PATH,
"options": compilation_options.get_str(),
"srcfile": temp_cu.name,
"tarfile": temp_cubin.name,
}
cmd = SubstituteTemplate(cmd_template, values)
os.system(cmd)
# load the cubin image
with open(temp_cubin.name, "rb") as file:
cubin_image = file.read()
# Set up the host-side library code
if requires_nvcc_hostlib_compilation:
cmd_template = (
"echo '%s'|${cuda_install_path}/bin/nvcc -x cu -Xcompiler=\"-fpermissive -w -fPIC\" ${options}"
% source_buffer_host
)
cmd = SubstituteTemplate(
cmd_template,
{
"cuda_install_path": CUDA_INSTALL_PATH,
"options": compilation_options.get_str(),
},
)
else:
options = compilation_options.get()
cmd = (
"echo '%s'|g++ -x c++ -fpermissive -w -fPIC -DCUTLASS_PYTHON_HOST_CC=1"
% source_buffer_host
)
filtered_opts = [
"-default-device",
"-Xcicc",
"-Xllc",
"--expt-relaxed-constexpr",
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
for opt in options:
opt = opt.decode("utf-8")
if opt not in filtered_opts and "-arch=sm_" not in opt:
if "--include-path=" in opt:
cmd += " " + opt.replace(
"--include-path=",
"-I",
)
else:
cmd += " " + opt
tempfile.tempdir = "./"
temp = tempfile.NamedTemporaryFile(
prefix="host_func", suffix=".so", delete=True)
cmd += " - -shared -o %s -lcudart -lcuda" % temp.name
os.system(cmd)
host_lib = ctypes.CDLL(temp.name)
return cubin_image, host_lib, temp
def add_module(self, operations, compile_options=None):
"""
Insert a new compiled device module
"""
if compile_options is None:
include_paths = [
CUDA_INSTALL_PATH + "/include",
CUTLASS_PATH + "/include",
CUTLASS_PATH + "/tools/util/include",
CUTLASS_PATH + "/python/cutlass/cpp/include",
]
if device_cc() is not None:
arch = device_cc()
else:
# Find the maximum arch tag among the provided operations and compile for that target.
# Since we are compiling to .cubin files, only one architecture may be specified.
arch = max([op.arch for op in operations])
compile_options = CompilationOptions(
self.default_compile_options, arch, include_paths)
# save the cubin
operation_key = []
operation_list = []
requires_nvcc_hostlib_compilation = False
for operation in operations:
# step 1: get kernel string as key
key = operation.rt_module.emit() + operation.procedural_name() + self.backend
# step 2: check if the operation is in cache
compiled_kernel = self.compiled_cache_device.at(key)
if compiled_kernel is None:
hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {}))
if hit:
compiled_kernel = self.compiled_cache_device.at(key)
assert compiled_kernel is not None
if compiled_kernel is not None:
operation.rt_module.kernel = compiled_kernel
compiled_host_fns = self.compiled_cache_host.at(key)
assert compiled_host_fns is not None
for fn_name in compiled_host_fns.keys():
setattr(operation.rt_module, fn_name, compiled_host_fns[fn_name])
operation.rt_module.initialize()
else:
operation_list.append(operation.rt_module)
operation_key.append(key)
# Creating the Params structures for certain 3.0 kernels currently requires CUDA. For these cases, use NVCC to generate
# the PyCUTLASS host-side library. Otherwise, g++ will be used.
if isinstance(operation, GemmOperationUniversal) and operation.api == ApiVersion.v3x:
if self.backend == "nvrtc":
raise RuntimeError("CUTLASS 3 kernels currently require NVCC for compilation.")
requires_nvcc_hostlib_compilation = True
if len(operation_list) > 0:
cubin_image, host_lib, host_file = self.emit_compile_(
operation_list, compile_options, requires_nvcc_hostlib_compilation)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
operation_name = []
operation_attr = []
for operation, key in zip(operation_list, operation_key):
# get device kernels
err, operation.kernel = cuda.cuModuleGetFunction(
module,
bytes(str.encode(operation.name()))
)
operation_name.append(operation.name())
self.compiled_cache_device.insert(key, operation.kernel)
# get host functions
compiled_host_fns = {}
op_attr = []
# get param size
func_name = operation.name() + "_get_param_size"
func = getattr(host_lib, func_name)
param_size = func()
func_name = operation.name() + "_get_params"
func = getattr(host_lib, func_name)
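# NOTE: ctypes only consults `argtypes`; `argtype` is stored as a plain attribute and is not enforced
func.argtype = operation.argtype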
func.restype = ctypes.POINTER(ctypes.c_char * param_size)
setattr(operation, "get_args", func)
compiled_host_fns["get_args"] = func
# set shared memory size
func_name = operation.name() + "_shared_memory_size"
func = getattr(host_lib, func_name)
setattr(operation, "shared_memory_capacity", func())
compiled_host_fns["shared_memory_capacity"] = func()
# set the maximum dynamic shared size
operation.initialize()
# get extra functions
op_attr.append(param_size)
if hasattr(operation, "extra_funcs"):
for suffix, ret_type in operation.extra_funcs.items():
func_name = operation.name() + "_" + suffix
func = getattr(host_lib, func_name)
if ret_type is not None:
func.restype = ret_type
setattr(operation, suffix, func)
compiled_host_fns[suffix] = func
op_attr.append(suffix)
operation_attr.append(op_attr)
self.compiled_cache_host.insert(key, compiled_host_fns)
for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr):
self.insert_operation(
key, cubin_image, host_file.name, operation_name, operation_attr)
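
The on-disk cache used by `insert_operation` and `load_operation` above can be illustrated in isolation. The following is a minimal sketch, assuming an in-memory SQLite database in place of `CACHE_FILE`; the helper names `open_cache`, `store`, and `load` are hypothetical and not part of CUTLASS. Only the table schema and the `INSERT OR IGNORE` / `SELECT` pattern mirror the code above.

```python
import json
import sqlite3


# Hypothetical helpers for illustration only; they are not CUTLASS APIs.
def open_cache(path=":memory:"):
    """Open the cache and create the table if it does not already exist."""
    connection = sqlite3.connect(path)
    connection.execute(
        """
        CREATE TABLE IF NOT EXISTS compiled_operations(
            op_key TEXT NOT NULL UNIQUE,
            cubin BLOB NOT NULL,
            hostbin BLOB NOT NULL,
            op_name TEXT NOT NULL,
            op_attrs TEXT NOT NULL)
        """
    )
    connection.commit()
    return connection


def store(connection, op_key, cubin, hostbin, op_name, op_attrs):
    """Insert one compiled operation; a row with an existing key is left untouched."""
    connection.execute(
        "INSERT OR IGNORE INTO compiled_operations "
        "(op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)",
        (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)),
    )
    connection.commit()


def load(connection, op_key):
    """Return (cubin, hostbin, op_name, op_attrs) for a key, or None on a miss."""
    row = connection.execute(
        "SELECT cubin, hostbin, op_name, op_attrs "
        "FROM compiled_operations WHERE op_key = ?",
        (op_key,),
    ).fetchone()
    if row is None:
        return None
    cubin, hostbin, op_name, op_attrs = row
    return cubin, hostbin, op_name, json.loads(op_attrs)


cache = open_cache()
# Placeholder bytes stand in for a real cubin image and host library.
store(cache, "key0", b"\x00cubin", b"\x00host", "my_kernel", [128, "get_workspace_size"])
# A second insert with the same key is ignored because op_key is UNIQUE.
store(cache, "key0", b"other", b"other", "ignored", [])
hit = load(cache, "key0")
miss = load(cache, "missing")
```

Because `op_key` is declared `UNIQUE` and rows are written with `INSERT OR IGNORE`, the first compiled artifact stored for a key is the one later lookups return.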


@ -0,0 +1,655 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
# from typeguard import typechecked
import ctypes
from typing import Union
from cuda import cuda
import cutlass_bindings
import numpy as np
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments
from cutlass.backend.library import (
ConvKindNames,
ConvKindTag,
DataTypeNames,
DataTypeSize,
DataTypeTag,
IteratorAlgorithmNames,
IteratorAlgorithmTag,
LayoutTag,
MathOperation,
MathOperationTag,
OpcodeClassNames,
OpcodeClassTag,
OperationKind,
ShortDataTypeNames,
ShortLayoutTypeNames,
StrideSupport,
StrideSupportTag,
TensorDescription,
TileDescription,
get_complex_from_real,
)
from cutlass.backend.memory_manager import device_mem_alloc
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.tensor_ref import TensorRef
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
if CheckPackages().check_torch():
import torch
# @typechecked
class Conv2dArguments(ArgumentBase):
"""
Argument wrapper for Conv2d. It encodes problem information and
user-provided tensors into the kernel's arguments.
:param operation: the Conv2d operation to take the argument
:type operation: :class:`cutlass.backend.Conv2dOperation`
:param problem_size: the Conv2d problem size
:type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param B: tensor B
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param C: tensor C
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param D: tensor D
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param split_k_mode: conv2d split K mode, defaults to cutlass_bindings.conv.SplitKMode.Serial
:type split_k_mode: cutlass_bindings.conv.SplitKMode, optional
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
"""
def __init__(
self,
operation: "Conv2dOperation",
problem_size: "cutlass_bindings.conv.Conv2dProblemSize",
A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
split_k_mode: "cutlass_bindings.conv.SplitKMode" = cutlass_bindings.conv.SplitKMode.Serial,
**kwargs,
) -> None:
self.operation = operation
#: convolution kind
self.conv_kind: cutlass_bindings.conv.Operator = operation.conv_kind
self.layout_A: cutlass_bindings.layout = operation.A.layout
self.layout_B: cutlass_bindings.layout = operation.B.layout
self.layout_C: cutlass_bindings.layout = operation.C.layout
self.element_A = operation.A.element
self.element_B = operation.B.element
self.element_C = operation.C.element
if self.layout_C == cutlass_bindings.TensorNC32HW32:
B = self.reorder_tensor_B(B, problem_size)
super().__init__(A, B, C, D, **kwargs)
# preprocessing output ops
if "output_op" in kwargs.keys() and split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
if "split_k_slices" in kwargs.keys():
self.split_k_mode = split_k_mode
self.split_k_slices = kwargs["split_k_slices"]
else:
self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial
self.split_k_slices = 1
#: problem_size
self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size
self.problem_size.split_k_slices = self.split_k_slices
if hasattr(self, "tensor_c_numel"):
c_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
self.conv_kind, problem_size)
if self.tensor_c_numel == c_coord.at(3) and self.tensor_c_numel < c_coord.size():
self.bias = True
#
# initialize the argument
#
self.initialize()
# @typechecked
def reorder_tensor_B(self, tensor_B: "np.ndarray",
problem_size: "cutlass_bindings.conv.Conv2dProblemSize"):
"""
Reorder tensor_B for interleaved layout
:param tensor_B: input tensor B
:type tensor_B: numpy.ndarray
:param problem_size: Conv2d problem size
:type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize`
:return: reordered tensor B
:rtype: numpy.ndarray
"""
reordered_tensor_B = np.empty_like(tensor_B)
tensor_ref_B = self.get_tensor_ref(
tensor_B, self.element_B, self.layout_B, problem_size, "b")
reordered_tensor_ref_B = self.get_tensor_ref(
reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b")
cutlass_bindings.conv.host.reorder_convK(
reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size)
return reordered_tensor_B
def get_tensor_ref(
self, tensor, dtype, tensor_layout, problem_size, operand):
if operand == "a":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
self.conv_kind, problem_size)
elif operand == "b":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
self.conv_kind, problem_size)
elif operand in ["c", "d"]:
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
self.conv_kind, problem_size)
else:
raise ValueError("unknown operand: " + operand)
# Zero stride trick
if operand == "c" and self.bias:
tensor_coord = cutlass_bindings.Tensor4DCoord(0, 0, 0, 0)
layout = tensor_layout.packed(tensor_coord)
return TensorRef(tensor, dtype, layout).tensor_ref
def get_arguments(self, semaphore):
ref_A = TensorRef_(self.get_tensor_ref(
self.ptr_A, self.element_A, self.layout_A, self.problem_size, "a"))
ref_B = TensorRef_(self.get_tensor_ref(
self.ptr_B, self.element_B, self.layout_B, self.problem_size, "b"))
ref_C = TensorRef_(self.get_tensor_ref(
self.ptr_C, self.element_C, self.layout_C, self.problem_size, "c"))
ref_D = TensorRef_(self.get_tensor_ref(
self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))
self.c_arguments = self.operation.argument_type(
Conv2DProblemSize(self.problem_size),
ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode)
self.semaphore = semaphore
def initialize(self):
# Get launch configuration
self.launch_config = self.operation.rt_module.plan(self)
# Allocate and initialize device workspace
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
if device_workspace_size > 0:
self.workspace_buffer = device_mem_alloc(device_workspace_size)
workspace_ptr = self.workspace_buffer.ptr
err, = cuda.cuMemsetD32(
workspace_ptr, 0, device_workspace_size // 4)
else:
workspace_ptr = None
# Get kernel params as a bytearray
semaphore = 0
if (workspace_ptr is not None
and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel):
self.ptr_D = workspace_ptr
elif (workspace_ptr is not None
and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial):
semaphore = workspace_ptr
self.get_arguments(semaphore)
params_ = self.operation.rt_module.get_args(
ctypes.byref(self.c_arguments), ctypes.c_void_p(int(self.semaphore)))
self.host_workspace = bytearray(params_.contents)
self.device_workspace = None
def sync(self):
"""
Synchronize the arguments. If the input tensor resides on the host,
copy it from device back to host.
"""
return super().sync()
# @typechecked
class Conv2dRT(ExecutableOperation):
"""
Conv2dRT manages the CUTLASS runtime components
"""
KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
${operation_name}${operation_suffix} op;
op(params, *shared_storage);
}
"""
HostTemplate = r"""
extern "C" {
// Get the size of params in bytes
int ${operation_name}_get_param_size(){
return sizeof(${operation_name}${operation_suffix}::Params);
}
// Get the size of dynamic shared memory in bytes
int ${operation_name}_shared_memory_size() {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Arguments* arguments, int *semaphore=nullptr){
typename ${operation_name}${operation_suffix}::Params* params;
params = new ${operation_name}${operation_suffix}::Params(*arguments, semaphore);
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];
return output;
}
}
"""
def __init__(self, operation: "Conv2dOperation"):
super().__init__(operation)
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
self.conv_kind = operation.conv_kind
self.operation: Conv2dOperation = operation
self.emitter = EmitConv2dInstance("_type")
self.threads: int = operation.tile_description.num_threads
self.swizzle_functor = operation.swizzling_functor
def emit(self):
return self.emitter.emit(self.operation)
def get_device_workspace_size(self, arguments: Conv2dArguments):
workspace_bytes = 0
launch_config = arguments.launch_config
self.conv_kind = self.operation.conv_kind
if arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
problem_size = arguments.problem_size
workspace_bytes = DataTypeSize[self.operation.C.element] \
* launch_config.grid[2] * cutlass_bindings.conv.implicit_gemm_tensor_c_size(
self.conv_kind, problem_size
) // 8
elif arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial and \
arguments.split_k_slices > 1:
workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4
return workspace_bytes
# @typechecked
def plan(self, arguments: Conv2dArguments):
tile_size = cutlass_bindings.gemm.GemmCoord(
self.operation.tile_description.threadblock_shape[0],
self.operation.tile_description.threadblock_shape[1],
self.operation.tile_description.threadblock_shape[2],
)
grid = self.swizzle_functor.get_grid_shape(
self.swizzle_functor.get_tiled_shape(
self.conv_kind, arguments.problem_size,
tile_size, arguments.split_k_slices
)
)
return LaunchConfiguration(
[grid.x, grid.y, grid.z], [self.threads, 1, 1],
self.shared_memory_capacity)
def initialize(self):
err, = cuda.cuFuncSetAttribute(
self.kernel,
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
value=self.shared_memory_capacity)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
class Conv2dOperation:
"""
CUTLASS Conv2d operation description.
:param conv_kind: convolution operator
:type conv_kind: :class:`cutlass_bindings.conv.Operator`
:param iterator_algorithm: Selects among several implementation
variants trading off performance with simplicity
:type iterator_algorithm: :class:`cutlass_bindings.conv.IteratorAlgorithm`
:param arch: GPU compute capability (sm_xx)
:type arch: int
:param tile_description: tile description
:type tile_description: :class:`cutlass.backend.TileDescription`
:param A: tensor A description
:type A: :class:`cutlass.backend.TensorDescription`
:param B: tensor B description
:type B: :class:`cutlass.backend.TensorDescription`
:param C: tensor C description
:type C: :class:`cutlass.backend.TensorDescription`
:param D: tensor D description
:type D: :class:`cutlass.backend.TensorDescription`
:param element_epilogue: element type for computation in epilogue \
:type element_epilogue: cutlass_bindings.int8 | cutlass_bindings.int32 | cutlass_bindings.float16 | \
cutlass_bindings.bfloat16 | cutlass_bindings.float32 | cutlass_bindings.float64
:param stride_support: distinguish among partial specializations that \
accelerate certain problems where convolution stride is unit \
:type stride_support: :class:`cutlass_bindings.conv.StrideSupport`
:param epilogue_functor: convolution epilogue functor
:type epilogue_functor: :class:`EpilogueFunctor`
:param swizzling_functor: threadblock swizzling functor
"""
def __init__(
self,
conv_kind: cutlass_bindings.conv.Operator,
iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm,
arch: int,
tile_description: TileDescription,
A: TensorDescription,
B: TensorDescription,
C: TensorDescription,
stride_support,
epilogue_functor,
swizzling_functor=cutlass_bindings.IdentitySwizzle1
):
self.operation_kind: OperationKind = OperationKind.Conv2d
self.arch: int = arch
self.tile_description: TileDescription = tile_description
self.conv_kind = conv_kind
self.A: TensorDescription = A
self.B: TensorDescription = B
self.C: TensorDescription = C
self.epilogue_functor = epilogue_functor
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor()
self.rt_module: Conv2dRT = Conv2dRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
"""
Launch the cuda kernel with input arguments
:param arguments: conv2d arguments
:type arguments: :class:`cutlass.backend.Conv2dArguments`
"""
# launch the kernel
err = self.rt_module.run(
arguments.host_workspace,
arguments.device_workspace,
arguments.launch_config,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
return err
#
# Get function name
#
def procedural_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
return self.configuration_name()
#
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
opcode_class_name = OpcodeClassNames[
self.tile_description.math_instruction.opcode_class
]
threadblock = "%dx%d_%dx%d" % (
self.tile_description.threadblock_shape[0],
self.tile_description.threadblock_shape[1],
self.tile_description.threadblock_shape[2],
self.tile_description.stages,
)
if self.stride_support == StrideSupport.Unity:
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
else:
configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
return SubstituteTemplate(
configuration_name,
{
"arch": str(self.arch),
"opcode_class": opcode_class_name,
"extended_name": self.extended_name(),
"threadblock": threadblock,
"layout": self.layout_name(),
"alignment": "%d" % self.A.alignment
},
)
#
def extended_name(self):
"""Append data types if they differ from compute type."""
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
else:
extended_name = "${core_name}"
extended_name = SubstituteTemplate(extended_name, {
"element_a": DataTypeNames[self.A.element],
"element_c": DataTypeNames[self.C.element],
"core_name": self.core_name(),
})
return extended_name
#
def layout_name(self):
return "%s" % (ShortLayoutTypeNames[self.A.layout])
#
def core_name(self):
"""The basic operation kind is prefixed with a letter indicating the accumulation type."""
intermediate_type = ""
if self.tile_description.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp:
inst_shape = "%dx%dx%d" % tuple(
self.tile_description.math_instruction.instruction_shape)
if self.tile_description.math_instruction.element_a != self.A.element and \
self.tile_description.math_instruction.element_a != self.accumulator_type():
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
else:
inst_shape = ""
return "%s%s%s%s_%s" % (
ShortDataTypeNames[self.accumulator_type()],
inst_shape,
intermediate_type,
ConvKindNames[self.conv_kind],
IteratorAlgorithmNames[self.iterator_algorithm]
)
#
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex_gaussian,
]
return self.tile_description.math_instruction.math_operation in complex_operators
#
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator
if self.is_complex():
return get_complex_from_real(accum)
return accum
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################
class EmitConv2dInstance:
    def __init__(self, operation_suffix=""):
        self.operation_suffix = operation_suffix
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/conv/kernel/default_conv2d_fprop.h",
            "cutlass/conv/kernel/default_conv2d_dgrad.h",
            "cutlass/conv/kernel/default_conv2d_wgrad.h"
        ]
        self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""

    def emit(self, operation):
        warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
                          operation.tile_description.warp_count[idx]) for idx in range(3)]
        epilogue_vector_length = int(min(
            operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
        values = {
            "operation_name": operation.procedural_name(),
            "operation_suffix": self.operation_suffix,
            "conv_kind": ConvKindTag[operation.conv_kind],
            "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
            "element_a": DataTypeTag[operation.A.element],
            "layout_a": LayoutTag[operation.A.layout],
            "element_b": DataTypeTag[operation.B.element],
            "layout_b": LayoutTag[operation.B.layout],
            "element_c": DataTypeTag[operation.C.element],
            "layout_c": LayoutTag[operation.C.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
            "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
            "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "swizzling_functor": operation.swizzling_functor.tag(),
            "stages": str(operation.tile_description.stages),
            "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
            "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
            "stride_support": StrideSupportTag[operation.stride_support],
            "math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation],
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
        }
        return SubstituteTemplate(self.template, values)
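
The `emit` method derives a per-warp shape and an epilogue vector length, then fills the `${...}` placeholders of the C++ template. The sketch below reproduces that arithmetic and substitution in isolation. It uses the standard library's `string.Template` as a stand-in for `SubstituteTemplate`, whose implementation is not shown in this diff; `tiny_template`, the helper names, and the tile sizes are hypothetical and chosen only for illustration.

```python
from string import Template


def warp_shape(threadblock_shape, warp_count):
    # Each warp covers threadblock_shape / warp_count elements in M, N, and K.
    return [tb // wc for tb, wc in zip(threadblock_shape, warp_count)]


def epilogue_vector_length(alignment, element_bits):
    # Cap one vectorized access at 128 bits, then convert back to elements.
    return min(alignment * element_bits, 128) // element_bits


# Hypothetical, heavily shortened template; the real one has many more fields.
tiny_template = Template(
    "cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>"
)

shape = warp_shape([128, 128, 32], [2, 2, 1])
rendered = tiny_template.substitute(
    warp_shape_m=str(shape[0]),
    warp_shape_n=str(shape[1]),
    warp_shape_k=str(shape[2]),
)
print(rendered)
print(epilogue_vector_length(alignment=8, element_bits=16))
```

Integer division is used here because the shapes are assumed to divide evenly; the original code computes the same values with `int(a / b)`.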


@ -0,0 +1,96 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from cuda import cuda
import numpy as np
from cutlass.backend.memory_manager import device_mem_alloc, todevice
from cutlass.backend.utils.software import CheckPackages
if CheckPackages().check_torch():
    import torch

if CheckPackages().check_cupy():
    import cupy as cp


class NumpyFrontend:
    """
    Frontend node for numpy
    """

    @staticmethod
    def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
        """Convert the input numpy tensor to CUDA device pointer

        :param np_tensor: input numpy nd array
        :param is_output: whether the tensor is output

        :return: CUDA device pointer
        """
        # Output tensors only need device memory allocated; input tensors are copied to the device
        if is_output:
            return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
        else:
            return todevice(np_tensor)


class TorchFrontend:
    """
    Frontend node for torch
    """

    @staticmethod
    def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
        """Convert the input torch tensor to CUDA device pointer

        :param torch_tensor: input torch tensor

        :return: CUDA device pointer
        """
        # Move the tensor to the GPU if it is not already there
        if not torch_tensor.is_cuda:
            torch_tensor = torch_tensor.to("cuda")

        return cuda.CUdeviceptr(torch_tensor.data_ptr())


class CupyFrontend:
    """
    Frontend node for cupy
    """

    @staticmethod
    def argument(cupy_ndarray: "cp.ndarray"):
        return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))


@ -0,0 +1,714 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Common data types and string names for them. This file is similar to /tools/library/scripts/library.py,
but uses the Pybind-bound CUTLASS data types as keys in many of the dictionaries.
"""
import enum
import cutlass_bindings
from cutlass import KernelScheduleType
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
try:
    from enum import auto as enum_auto
except ImportError:
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i
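
A minimal sketch of how the fallback is meant to be used: enumeration members receive distinct values from `enum_auto()`, and the members then serve as dictionary keys for tag and name lookups. The names `_auto_counter` and `DemoStride` are hypothetical and exist only for this illustration.

```python
import enum

try:
    from enum import auto as enum_auto
except ImportError:
    # Fallback counter for interpreters whose enum module lacks auto().
    _auto_counter = 0

    def enum_auto() -> int:
        global _auto_counter
        value = _auto_counter
        _auto_counter += 1
        return value


class DemoStride(enum.Enum):
    Strided = enum_auto()
    Unity = enum_auto()


# Enum members are hashable, so they work directly as dictionary keys.
demo_names = {DemoStride.Strided: "", DemoStride.Unity: "unity_stride"}
print(demo_names[DemoStride.Unity])
```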
ShortDataTypeNames = {
cutlass_bindings.int32: "i",
cutlass_bindings.float16: "h",
cutlass_bindings.float32: "s",
cutlass_bindings.float64: "d",
cutlass_bindings.dtype.cf32: "c",
cutlass_bindings.dtype.cf64: "z",
}
DataTypeNames = {
cutlass_bindings.dtype.b1: "b1",
cutlass_bindings.dtype.u4: "u4",
cutlass_bindings.dtype.u8: "u8",
cutlass_bindings.dtype.u16: "u16",
cutlass_bindings.dtype.u32: "u32",
cutlass_bindings.dtype.u64: "u64",
cutlass_bindings.dtype.s4: "s4",
cutlass_bindings.int8: "s8",
cutlass_bindings.dtype.s16: "s16",
cutlass_bindings.int32: "s32",
cutlass_bindings.dtype.s64: "s64",
cutlass_bindings.float16: "f16",
cutlass_bindings.bfloat16: "bf16",
cutlass_bindings.float32: "f32",
cutlass_bindings.tfloat32: "tf32",
cutlass_bindings.float64: "f64",
cutlass_bindings.dtype.cf16: "cf16",
cutlass_bindings.dtype.cbf16: "cbf16",
cutlass_bindings.dtype.cf32: "cf32",
cutlass_bindings.dtype.ctf32: "ctf32",
cutlass_bindings.dtype.cf64: "cf64",
cutlass_bindings.dtype.cu4: "cu4",
cutlass_bindings.dtype.cu8: "cu8",
cutlass_bindings.dtype.cu16: "cu16",
cutlass_bindings.dtype.cu32: "cu32",
cutlass_bindings.dtype.cu64: "cu64",
cutlass_bindings.dtype.cs4: "cs4",
cutlass_bindings.dtype.cs8: "cs8",
cutlass_bindings.dtype.cs16: "cs16",
cutlass_bindings.dtype.cs32: "cs32",
cutlass_bindings.dtype.cs64: "cs64",
}
DataTypeTag = {
cutlass_bindings.dtype.b1: "cutlass::uint1b_t",
cutlass_bindings.dtype.u4: "cutlass::uint4b_t",
cutlass_bindings.dtype.u8: "uint8_t",
cutlass_bindings.dtype.u16: "uint16_t",
cutlass_bindings.dtype.u32: "uint32_t",
cutlass_bindings.dtype.u64: "uint64_t",
cutlass_bindings.dtype.s4: "cutlass::int4b_t",
cutlass_bindings.int8: "int8_t",
cutlass_bindings.dtype.s16: "int16_t",
cutlass_bindings.int32: "int32_t",
cutlass_bindings.dtype.s64: "int64_t",
cutlass_bindings.float16: "cutlass::half_t",
cutlass_bindings.bfloat16: "cutlass::bfloat16_t",
cutlass_bindings.float32: "float",
cutlass_bindings.tfloat32: "cutlass::tfloat32_t",
cutlass_bindings.float64: "double",
cutlass_bindings.dtype.cf16: "cutlass::complex<cutlass::half_t>",
cutlass_bindings.dtype.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
cutlass_bindings.dtype.cf32: "cutlass::complex<float>",
cutlass_bindings.dtype.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
cutlass_bindings.dtype.cf64: "cutlass::complex<double>",
cutlass_bindings.dtype.cu4: "cutlass::complex<cutlass::uint4b_t>",
cutlass_bindings.dtype.cu8: "cutlass::complex<cutlass::uint8_t>",
cutlass_bindings.dtype.cu16: "cutlass::complex<cutlass::uint16_t>",
cutlass_bindings.dtype.cu32: "cutlass::complex<cutlass::uint32_t>",
cutlass_bindings.dtype.cu64: "cutlass::complex<cutlass::uint64_t>",
cutlass_bindings.dtype.cs4: "cutlass::complex<cutlass::int4b_t>",
cutlass_bindings.dtype.cs8: "cutlass::complex<cutlass::int8_t>",
cutlass_bindings.dtype.cs16: "cutlass::complex<cutlass::int16_t>",
cutlass_bindings.dtype.cs32: "cutlass::complex<cutlass::int32_t>",
cutlass_bindings.dtype.cs64: "cutlass::complex<cutlass::int64_t>",
}
DataTypeSize = {
cutlass_bindings.dtype.b1: 1,
cutlass_bindings.dtype.u4: 4,
cutlass_bindings.dtype.u8: 8,
cutlass_bindings.dtype.u16: 16,
cutlass_bindings.dtype.u32: 32,
cutlass_bindings.dtype.u64: 64,
cutlass_bindings.dtype.s4: 4,
cutlass_bindings.int8: 8,
cutlass_bindings.dtype.s16: 16,
cutlass_bindings.int32: 32,
cutlass_bindings.dtype.s64: 64,
cutlass_bindings.float16: 16,
cutlass_bindings.bfloat16: 16,
cutlass_bindings.float32: 32,
cutlass_bindings.tfloat32: 32,
cutlass_bindings.float64: 64,
cutlass_bindings.dtype.cf16: 32,
cutlass_bindings.dtype.cbf16: 32,
cutlass_bindings.dtype.cf32: 64,
cutlass_bindings.dtype.ctf32: 64,
cutlass_bindings.dtype.cf64: 128,
cutlass_bindings.dtype.cu4: 8,
cutlass_bindings.dtype.cu8: 16,
cutlass_bindings.dtype.cu16: 32,
cutlass_bindings.dtype.cu32: 64,
cutlass_bindings.dtype.cu64: 128,
cutlass_bindings.dtype.cs4: 8,
cutlass_bindings.dtype.cs8: 16,
cutlass_bindings.dtype.cs16: 32,
cutlass_bindings.dtype.cs32: 64,
cutlass_bindings.dtype.cs64: 128,
}
class DataTypeSizeBytes:
    """
    Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
    data type key is less than a full byte or a non-integer number of bytes.
    """

    @staticmethod
    def __class_getitem__(datatype):
        """
        Returns the number of bytes in size the data type is. Raises an exception if the data type
        is either less than a full byte or a non-integer number of bytes in size.

        :param datatype: data type to query

        :return: number of bytes the data type occupies
        :rtype: int
        """
        bits = DataTypeSize[datatype]
        if bits < 8:
            raise Exception(
                "Data type {} is less than one byte in size.".format(datatype)
            )
        elif bits % 8 != 0:
            raise Exception(
                "Data type {} is not an integer number of bytes.".format(datatype)
            )
        return bits // 8
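
The subscript-on-a-class pattern can be sketched without the CUTLASS bindings. In the example below, `SIZE_IN_BITS` and `SizeInBytes` are hypothetical stand-ins keyed by short string names instead of `cutlass_bindings` types, and the sketch relies on Python's implicit class-method treatment of `__class_getitem__` instead of an explicit decorator.

```python
# Hypothetical stand-in for DataTypeSize, keyed by short names.
SIZE_IN_BITS = {"b1": 1, "s4": 4, "f16": 16, "f32": 32, "cf64": 128}


class SizeInBytes:
    """Subscriptable class: SizeInBytes["f16"] returns a byte count."""

    def __class_getitem__(cls, name):
        bits = SIZE_IN_BITS[name]
        if bits < 8:
            raise ValueError("Data type {} is less than one byte in size.".format(name))
        if bits % 8 != 0:
            raise ValueError("Data type {} is not an integer number of bytes.".format(name))
        return bits // 8


print(SizeInBytes["f16"])
print(SizeInBytes["cf64"])
```

Sub-byte types such as `"s4"` raise instead of silently rounding to zero bytes, which is the point of wrapping the dictionary.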
ComplexTransformTag = {
cutlass_bindings.complex_transform.none: "cutlass::ComplexTransform::kNone",
cutlass_bindings.complex_transform.conj: "cutlass::ComplexTransform::kConjugate",
}
RealComplexBijection = [
(cutlass_bindings.float16, cutlass_bindings.dtype.cf16),
(cutlass_bindings.float32, cutlass_bindings.dtype.cf32),
(cutlass_bindings.float64, cutlass_bindings.dtype.cf64),
]
def is_complex(data_type):
    for r, c in RealComplexBijection:
        if data_type == c:
            return True
    return False


def get_complex_from_real(real_type):
    for r, c in RealComplexBijection:
        if real_type == r:
            return c
    return cutlass_bindings.dtype.invalid


def get_real_from_complex(complex_type):
    for r, c in RealComplexBijection:
        if complex_type == c:
            return r
    return cutlass_bindings.dtype.invalid
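
A small sketch of the lookup behavior, using string names as hypothetical stand-ins for the bound data types and `"invalid"` in place of `cutlass_bindings.dtype.invalid`. Types without an entry in the pair list fall through to the invalid marker.

```python
# Hypothetical stand-in for RealComplexBijection.
PAIRS = [("f16", "cf16"), ("f32", "cf32"), ("f64", "cf64")]


def complex_from_real(real_type):
    # Linear scan over (real, complex) pairs; the list is tiny.
    for r, c in PAIRS:
        if real_type == r:
            return c
    return "invalid"


def real_from_complex(complex_type):
    for r, c in PAIRS:
        if complex_type == c:
            return r
    return "invalid"


print(complex_from_real("f32"))
print(real_from_complex("cf64"))
print(complex_from_real("bf16"))
```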
class ComplexMultiplyOp(enum.Enum):
    multiply_add = enum_auto()
    gaussian = enum_auto()


class MathOperation(enum.Enum):
    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    xor_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_fast_f32 = enum_auto()
    multiply_add_complex_fast_f32 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()
MathOperationNames = {
MathOperation.multiply_add: "multiply_add",
MathOperation.multiply_add_saturate: "multiply_add_saturate",
MathOperation.xor_popc: "xor_popc",
MathOperation.multiply_add_fast_bf16: "multiply_add_fast_bf16",
MathOperation.multiply_add_fast_f16: "multiply_add_fast_f16",
MathOperation.multiply_add_fast_f32: "multiply_add_fast_f32",
MathOperation.multiply_add_complex_fast_f32: "multiply_add_complex_fast_f32",
MathOperation.multiply_add_complex: "multiply_add_complex",
MathOperation.multiply_add_complex_gaussian: "multiply_add_complex_gaussian",
}
MathOperationTag = {
MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32",
MathOperation.multiply_add_complex_fast_f32: "cutlass::arch::OpMultiplyAddComplexFastF32",
MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
}
LayoutTag = {
cutlass_bindings.ColumnMajor: "cutlass::layout::ColumnMajor",
cutlass_bindings.RowMajor: "cutlass::layout::RowMajor",
cutlass_bindings.layout.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
cutlass_bindings.layout.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
cutlass_bindings.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
cutlass_bindings.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
cutlass_bindings.layout.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
cutlass_bindings.layout.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
cutlass_bindings.TensorNHWC: "cutlass::layout::TensorNHWC",
cutlass_bindings.layout.TensorNDHWC: "cutlass::layout::TensorNDHWC",
cutlass_bindings.layout.TensorNCHW: "cutlass::layout::TensorNCHW",
cutlass_bindings.layout.TensorNGHWC: "cutlass::layout::TensorNGHWC",
cutlass_bindings.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
cutlass_bindings.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
cutlass_bindings.layout.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
cutlass_bindings.layout.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
}
TransposedLayout = {
cutlass_bindings.ColumnMajor: cutlass_bindings.RowMajor,
cutlass_bindings.RowMajor: cutlass_bindings.ColumnMajor,
cutlass_bindings.layout.ColumnMajorInterleaved2: cutlass_bindings.layout.RowMajorInterleaved2,
cutlass_bindings.layout.RowMajorInterleaved2: cutlass_bindings.layout.ColumnMajorInterleaved2,
cutlass_bindings.ColumnMajorInterleaved32: cutlass_bindings.RowMajorInterleaved32,
cutlass_bindings.RowMajorInterleaved32: cutlass_bindings.ColumnMajorInterleaved32,
cutlass_bindings.layout.ColumnMajorInterleaved64: cutlass_bindings.layout.RowMajorInterleaved64,
cutlass_bindings.layout.RowMajorInterleaved64: cutlass_bindings.layout.ColumnMajorInterleaved64,
cutlass_bindings.TensorNHWC: cutlass_bindings.TensorNHWC,
}
ShortLayoutTypeNames = {
cutlass_bindings.ColumnMajor: "n",
cutlass_bindings.layout.ColumnMajorInterleaved2: "n2",
cutlass_bindings.ColumnMajorInterleaved32: "n32",
cutlass_bindings.layout.ColumnMajorInterleaved64: "n64",
cutlass_bindings.RowMajor: "t",
cutlass_bindings.layout.RowMajorInterleaved2: "t2",
cutlass_bindings.RowMajorInterleaved32: "t32",
cutlass_bindings.layout.RowMajorInterleaved64: "t64",
cutlass_bindings.TensorNHWC: "nhwc",
cutlass_bindings.layout.TensorNDHWC: "ndhwc",
cutlass_bindings.layout.TensorNCHW: "nchw",
cutlass_bindings.layout.TensorNGHWC: "nghwc",
cutlass_bindings.TensorNC32HW32: "nc32hw32",
cutlass_bindings.layout.TensorNC64HW64: "nc64hw64",
cutlass_bindings.TensorC32RSK32: "c32rsk32",
cutlass_bindings.layout.TensorC64RSK64: "c64rsk64",
}
ShortComplexLayoutNames = {
(cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.none): "n",
(cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.conj): "c",
(cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.none): "t",
(cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.conj): "h",
}
OpcodeClassNames = {
cutlass_bindings.OpClass.Simt: "simt",
cutlass_bindings.OpClass.TensorOp: "tensorop",
cutlass_bindings.OpClass.WmmaTensorOp: "wmma_tensorop",
cutlass_bindings.OpClass.SparseTensorOp: "sptensorop",
}
OpcodeClassTag = {
cutlass_bindings.OpClass.Simt: "cutlass::arch::OpClassSimt",
cutlass_bindings.OpClass.TensorOp: "cutlass::arch::OpClassTensorOp",
cutlass_bindings.OpClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
cutlass_bindings.OpClass.SparseTensorOp: "cutlass::arch::OpClassSparseTensorOp",
}
class OperationKind(enum.Enum):
    Gemm = enum_auto()
    Conv2d = enum_auto()
    Conv3d = enum_auto()
OperationKindNames = {
OperationKind.Gemm: "gemm",
OperationKind.Conv2d: "conv2d",
OperationKind.Conv3d: "conv3d",
}
ArchitectureNames = {
50: "maxwell",
60: "pascal",
61: "pascal",
70: "volta",
75: "turing",
80: "ampere",
90: "hopper",
}
SharedMemPerCC = {
70: 96 << 10, # 96KB of SMEM
72: 96 << 10, # 96KB of SMEM
75: 64 << 10, # 64KB of SMEM
80: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver
86: 100 << 10, # 100KB of SMEM
87: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver
89: 100 << 10, # 100KB of SMEM
90: 227 << 10, # 228KB of SMEM - 1KB reserved for the driver
}
class GemmKind(enum.Enum):
    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    Grouped = enum_auto()
GemmKindNames = {
GemmKind.Gemm: "gemm",
GemmKind.Sparse: "spgemm",
GemmKind.Universal: "gemm",
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped",
}
class SwizzlingFunctor(enum.Enum):
    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    Horizontal = enum_auto()
    BatchedIdentity1 = enum_auto()
    StridedDgradIdentity1 = enum_auto()
    StridedDgradIdentity4 = enum_auto()
    StridedDgradHorizontal = enum_auto()
SwizzlingFunctorTag = {
cutlass_bindings.IdentitySwizzle1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
SwizzlingFunctor.Horizontal: "cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle",
SwizzlingFunctor.BatchedIdentity1: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.StridedDgradHorizontal: "cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle",
}
class SchedulerMode(enum.Enum):
    Device = enum_auto()
    Host = enum_auto()
SchedulerModeTag = {
SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
ConvKindTag = {
cutlass_bindings.conv.Operator.fprop: "cutlass::conv::Operator::kFprop",
cutlass_bindings.conv.Operator.dgrad: "cutlass::conv::Operator::kDgrad",
cutlass_bindings.conv.Operator.wgrad: "cutlass::conv::Operator::kWgrad",
}
ConvKindNames = {
cutlass_bindings.conv.Operator.fprop: "fprop",
cutlass_bindings.conv.Operator.dgrad: "dgrad",
cutlass_bindings.conv.Operator.wgrad: "wgrad",
}
IteratorAlgorithmTag = {
cutlass_bindings.conv.IteratorAlgorithm.analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
cutlass_bindings.conv.IteratorAlgorithm.optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "cutlass::conv::IteratorAlgorithm::kFixedChannels",
cutlass_bindings.conv.IteratorAlgorithm.few_channels: "cutlass::conv::IteratorAlgorithm::kFewChannels",
}
IteratorAlgorithmNames = {
cutlass_bindings.conv.IteratorAlgorithm.analytic: "analytic",
cutlass_bindings.conv.IteratorAlgorithm.optimized: "optimized",
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "fixed_channels",
cutlass_bindings.conv.IteratorAlgorithm.few_channels: "few_channels",
}
class StrideSupport(enum.Enum):
    Strided = enum_auto()
    Unity = enum_auto()
StrideSupportTag = {
StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
}
StrideSupportNames = {
StrideSupport.Strided: "",
StrideSupport.Unity: "unity_stride",
}
class ConvMode(enum.Enum):
    CrossCorrelation = enum_auto()
    Convolution = enum_auto()
ConvModeTag = {
ConvMode.CrossCorrelation: "cutlass::conv::Mode::kCrossCorrelation",
ConvMode.Convolution: "cutlass::conv::Mode::kConvolution",
}
class MathInstruction:
    """
    Description of the lowest-level matrix-multiply-accumulate operation to be used in a kernel
    """

    def __init__(
        self,
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class=cutlass_bindings.OpClass.Simt,
        math_operation=MathOperation.multiply_add,
    ):
        """
        :param instruction_shape: size of the [M, N, K] dimensions of the instruction
        :type instruction_shape: list or tuple
        :param element_a: data type of operand A
        :param element_b: data type of operand B
        :param element_accumulator: data type used in accumulation
        :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
        :type opcode_class: cutlass_bindings.OpClass
        :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
        :type math_operation: MathOperation
        """
        self.instruction_shape = instruction_shape
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation
class TileDescription:
    """
    Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
    stage count, and math instruction specification
    """

    def __init__(
        self,
        threadblock_shape,
        stages,
        warp_count,
        math_instruction,
        cluster_shape=[1, 1, 1],
        kernel_schedule: KernelScheduleType = None
    ):
        """
        :param threadblock_shape: shape of a threadblock tile
        :type threadblock_shape: list or tuple
        :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
                       number of stages that can be supported for an operation on a given architecture will be computed at a later time
        :type stages: int or None
        :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
        :type warp_count: list, tuple, or None
        :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
        :type math_instruction: MathInstruction
        :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
        :param kernel_schedule: type of kernel schedule to use (only available for SM90+)
        :type kernel_schedule: cutlass.KernelScheduleType
        """
        self.threadblock_shape = threadblock_shape
        self.cluster_shape = cluster_shape
        self.kernel_schedule = kernel_schedule
        self.stages: int = stages

        self.math_instruction = math_instruction

        # Number of warps along x, y, z directions
        self.warp_count = warp_count

    @property
    def num_threads(self):
        """
        Returns the number of threads in the threadblock

        :return: number of threads in the threadblock
        :rtype: int or None (if warp count is None)
        """
        if self.warp_count is not None:
            threads = 32
            for cnt in self.warp_count:
                threads *= cnt
            return threads
        return None

    def procedural_name(self):
        """
        Returns a name identifying the tile description

        :return: name identifying the tile description
        :rtype: str
        """
        emit_stages = 0 if self.stages is None else self.stages
        name = "%dx%dx%d_%dx%d_%dx%d" % (
            self.cluster_shape[0],
            self.cluster_shape[1],
            self.cluster_shape[2],
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            emit_stages
        )

        return name

    def __str__(self):
        """
        Returns a string containing each of the tile description's values

        :return: contents of tile description
        :rtype: str
        """
        schedule = KernelScheduleType.ScheduleAuto
        if self.kernel_schedule is not None:
            schedule = self.kernel_schedule
        return f"""
{{
ClusterShape: {self.cluster_shape}
ThreadblockShape: {self.threadblock_shape}
WarpCount: {self.warp_count}
Stages: {self.stages if self.stages is not None else 'Auto'}
Kernel schedule: {schedule.name}
}}"""
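
The thread-count and naming logic of `TileDescription` can be exercised without the CUTLASS bindings. The sketch below restates both computations as free functions; `num_threads` and `tile_name` are hypothetical helper names, and the warp size of 32 mirrors the constant used in the class.

```python
def num_threads(warp_count, warp_size=32):
    # Threads per threadblock = warp size times the product of the warp counts.
    if warp_count is None:
        return None
    threads = warp_size
    for cnt in warp_count:
        threads *= cnt
    return threads


def tile_name(cluster_shape, threadblock_shape, stages):
    # An unspecified stage count (None) is emitted as 0 in the name.
    emit_stages = 0 if stages is None else stages
    return "%dx%dx%d_%dx%d_%dx%d" % (*cluster_shape, *threadblock_shape, emit_stages)


print(num_threads([2, 2, 1]))
print(tile_name([1, 1, 1], [128, 128, 64], 3))
print(tile_name([2, 1, 1], [128, 256, 64], None))
```

Note that the format string groups the seven numbers as cluster shape, threadblock M and N, then threadblock K joined to the stage count, so a 128x128x64 tile with 3 stages reads `128x128_64x3`.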
class TensorDescription:
    def __init__(self, element, layout, alignment=1,
                 complex_transform=cutlass_bindings.complex_transform.none):
        self.element = element
        self.layout = layout
        self.alignment = min(128 // DataTypeSize[self.element], alignment)
        self.complex_transform = complex_transform
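
`TensorDescription` clamps the requested alignment so that one aligned access never exceeds 128 bits. A minimal sketch of that clamp, with `clamp_alignment` as a hypothetical helper name and element widths passed in bits:

```python
def clamp_alignment(element_bits, requested_alignment):
    # At most 128 bits per access: 8 elements of f16, 4 of f32, 2 of f64.
    return min(128 // element_bits, requested_alignment)


print(clamp_alignment(16, 16))  # f16, request exceeds the 128-bit cap
print(clamp_alignment(32, 1))   # f32, small request is kept as-is
print(clamp_alignment(64, 8))   # f64, capped to two elements
```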
def CalculateSmemUsagePerStage(operation):
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param operation: operation for which the per-stage shared memory usage should be computed.
                      If stages are set via the `operation.tile_description.stages` parameter,
                      this setting is ignored in the present calculation
    :type operation: cutlass.backend.Operation

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int
    """
    m, n, k = operation.tile_description.threadblock_shape

    if operation.operation_kind == OperationKind.Gemm:
        stage_barrier_bytes = 32
        return (
            (DataTypeSize[operation.A.element] * m * k // 8)
            + (DataTypeSize[operation.B.element] * k * n // 8)
            + stage_barrier_bytes
        )
    else:
        raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))


def CalculateSmemUsage(operation):
    """
    Returns the amount of shared memory in bytes consumed by a kernel.

    :param operation: operation for which the shared memory usage should be computed. The stage
                      count is read from `operation.tile_description.stages`, which must be set
    :type operation: cutlass.backend.Operation

    :return: number of bytes of shared memory consumed by the kernel
    :rtype: int
    """
    return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)
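
A worked example of the per-stage formula, combined with the capacities from `SharedMemPerCC`, shows how a stage count could be bounded. This is a sketch under assumptions: `max_stages` is a hypothetical helper that does not appear in this file, the dictionaries are reduced stand-ins keyed by plain values, and a real kernel may need shared memory beyond the mainloop stages counted here.

```python
# Reduced stand-ins for SharedMemPerCC and DataTypeSize.
SMEM_CAPACITY_BYTES = {80: 160 << 10, 90: 227 << 10}
ELEMENT_BITS = {"f16": 16, "f32": 32}


def smem_per_stage(threadblock_shape, element_a, element_b, barrier_bytes=32):
    # One stage holds an m-by-k tile of A, a k-by-n tile of B, and a barrier.
    m, n, k = threadblock_shape
    bytes_a = ELEMENT_BITS[element_a] * m * k // 8
    bytes_b = ELEMENT_BITS[element_b] * k * n // 8
    return bytes_a + bytes_b + barrier_bytes


def max_stages(cc, threadblock_shape, element_a, element_b):
    # Hypothetical helper: largest stage count whose total fits in shared memory.
    return SMEM_CAPACITY_BYTES[cc] // smem_per_stage(threadblock_shape, element_a, element_b)


per_stage = smem_per_stage([128, 128, 64], "f16", "f16")
print(per_stage)
print(max_stages(80, [128, 128, 64], "f16", "f16"))
print(max_stages(90, [128, 128, 64], "f16", "f16"))
```

For a 128x128x64 half-precision tile, each operand tile takes 16384 bytes, so one stage costs 32800 bytes including the 32-byte barrier.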
class ApiVersion(enum.Enum):
    """
    Differentiate between CUTLASS 2.x and 3.x API versions
    """
    v2x = enum_auto()
    v3x = enum_auto()


def api_version(arch, opclass, datatype):
    """
    Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
    or 3.x for code emission.

    :param arch: compute capability of device on which to run
    :type arch: int
    :param opclass: class of the operation being performed
    :type opclass: cutlass_bindings.OpClass
    :param datatype: data type to be used in operation (assumes that ElementA and ElementB are the same)

    :return: API version to be used in code emission
    :rtype: ApiVersion
    """
    if (arch >= 90 and
        opclass == cutlass_bindings.OpClass.TensorOp and
        (datatype != cutlass_bindings.float64)):
        return ApiVersion.v3x
    else:
        return ApiVersion.v2x
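
The selection rule in `api_version` can be tabulated with plain values. In this sketch, `pick_api` is a hypothetical name, and the strings `"tensorop"`, `"simt"`, `"f16"`, and `"f64"` stand in for the corresponding `cutlass_bindings` enumerators.

```python
def pick_api(arch, opclass, datatype):
    # 3.x emission only for SM90+ Tensor Core kernels that are not double precision.
    if arch >= 90 and opclass == "tensorop" and datatype != "f64":
        return "3.x"
    return "2.x"


for case in [(90, "tensorop", "f16"), (90, "tensorop", "f64"),
             (90, "simt", "f16"), (80, "tensorop", "f16")]:
    print(case, pick_api(*case))
```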


@ -0,0 +1,74 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import numpy as np
import rmm
class PoolMemoryManager:
def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
self.pool = rmm.mr.PoolMemoryResource(
rmm.mr.CudaMemoryResource(),
initial_pool_size=init_pool_size,
maximum_pool_size=max_pool_size
)
self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
rmm.mr.set_current_device_resource(self.mr)
def get_allocated_size(self):
return self.mr.get_allocated_bytes()
def pool_size(self):
return self.pool.pool_size()
def todevice(host_data, dtype=np.float32):
"""
Pass the host_data to device memory
"""
if isinstance(host_data, list):
return rmm.DeviceBuffer.to_device(np.array(host_data, dtype=dtype).tobytes())
elif isinstance(host_data, np.ndarray):
return rmm.DeviceBuffer.to_device(host_data.tobytes())
def device_mem_alloc(size):
return rmm.DeviceBuffer(size=size)
def align_size(size, alignment=256):
return ((size + alignment - 1) // alignment) * alignment
def get_allocated_size():
device_resource = rmm.mr.get_current_device_resource()
return device_resource.get_allocated_bytes()
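
The `align_size` helper above rounds a byte count up to the next multiple of `alignment` using integer arithmetic. The following is a minimal sketch of that rounding on its own, without `rmm`; the sample sizes are illustrative values chosen here, not values taken from the commit.

```python
def align_size(size, alignment=256):
    """Round size up to the next multiple of alignment (same formula as above)."""
    return ((size + alignment - 1) // alignment) * alignment


# Illustrative sizes only: values already on a boundary are left unchanged,
# everything else is bumped up to the next boundary.
sizes = [0, 1, 256, 257, 1000]
aligned = [align_size(s) for s in sizes]

for requested, padded in zip(sizes, aligned):
    print(f"{requested:5d} bytes -> {padded:5d} bytes")
```

Adding `alignment - 1` before the floor division is what turns a round-down into a round-up, so no branch is needed for sizes that are already aligned.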


@ -0,0 +1,127 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import ctypes
from cuda import __version__, cuda
from cutlass.backend.utils.device import device_cc
_version_splits = [int(x) for x in __version__.split("rc")[0].split(".")]
supports_cluster_launch = device_cc() >= 90 and (
_version_splits[0] > 11 or (_version_splits[0] == 11 and _version_splits[1] >= 8)
)
class LaunchConfiguration:
def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0):
self.grid = grid
self.block = block
self.shared_memory_capacity = smem
class ExecutableOperation:
def __init__(self, operation):
self.operation = operation
self.module = None
self.kernel = None
def name(self):
return self.operation.procedural_name()
def emit(self):
return ""
def can_implement(self, configuration, arguments):
raise NotImplementedError()
def get_host_workspace_size(self, arguments):
raise NotImplementedError()
def get_device_workspace_size(self, arguments):
raise NotImplementedError()
def plan(self, arguments):
raise NotImplementedError()
def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=cuda.CUstream(0)):
raise NotImplementedError()
def run_with_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)):
if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
attr = cuda.CUlaunchAttribute()
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attrs = [attr]
# Allow for non-portable cluster sizes
err, = cuda.cuFuncSetAttribute(
self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
if err != cuda.CUresult.CUDA_SUCCESS:
return err
else:
attrs = []
config = cuda.CUlaunchConfig()
config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
config.sharedMemBytes = launch_config.shared_memory_capacity
config.hStream = stream
config.attrs = attrs
config.numAttrs = len(attrs)
err, = cuda.cuLaunchKernelEx(
config, f=self.kernel, kernelParams=kernel_params, extra=0)
return err
def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)):
err, = cuda.cuLaunchKernel(
self.kernel,
launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
launch_config.block[0], launch_config.block[1], launch_config.block[2],
launch_config.shared_memory_capacity,
stream,
kernel_params,
0)
return err
def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)):
cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
packed = (ctypes.c_void_p * 1)()
packed[0] = ctypes.addressof(cArg)
if supports_cluster_launch:
return self.run_with_clusters(launch_config, packed, stream)
else:
return self.run_without_clusters(launch_config, packed, stream)
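
The `run` method above wraps the host workspace in a `ctypes` character array and then stores that array's address in a one-element `void*` array that is passed as the kernel parameter list. The sketch below reproduces only that packing step with the standard library, so it runs without CUDA; the four-byte `host_workspace` is a hypothetical stand-in for the real parameter bytes.

```python
import ctypes

# Hypothetical stand-in for the packed kernel parameters.
host_workspace = bytearray(b"\x01\x02\x03\x04")

# View the bytearray as a C char array. from_buffer shares memory, it does not copy.
c_arg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)

# One-element array of void* holding the address of the parameter buffer.
packed = (ctypes.c_void_p * 1)()
packed[0] = ctypes.addressof(c_arg)

# Because the memory is shared, a write through the bytearray is visible
# through the ctypes view as well.
host_workspace[0] = 0x7F
print(hex(packed[0]), c_arg.raw)
```

Since `from_buffer` does not copy, the `bytearray` has to stay alive (and must not be resized) for as long as the address in `packed` is in use.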


@ -0,0 +1,877 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import ast
import ctypes
import inspect
import textwrap
from typing import Generic, TypeVar
from cuda import cuda, cudart
import numpy as np
from treelib import Tree
from cutlass.backend.epilogue import (
AccumulatorOp,
BinaryOp,
ColumnBroadcastOp,
ColumnReductionOp,
RowBroadcastOp,
RowReductionOp,
TensorInputOp,
TensorOutputOp,
UnaryOp,
)
from cutlass.backend.frontend import NumpyFrontend
from cutlass.backend.utils.software import SubstituteTemplate
import cutlass.backend as backend
################################################################################
# Type annotation for input arguments
################################################################################
Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")
class NDArray(np.ndarray, Generic[Ttype, Dtype]):
pass
################################################################################
# Operations
################################################################################
operators = {
ast.Add: "Add",
ast.Div: "Div",
ast.Eq: "Equal",
ast.Mult: "Mult",
}
################################################################################
# AST Node abstractions
################################################################################
class UnaryNode:
cnt = 0
# Concept: created from a Call node in the Python AST, or by lowering a BinOpNode that has one scalar operand
def __init__(
self,
element_accumulator,
element_compute,
elements_per_access,
node,
args,
) -> None:
if isinstance(node, BinOpNode):
self.op = node.op
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
self.op = node.func.id
elif isinstance(node.func, ast.Attribute):
self.op = node.func.value.id
else:
raise TypeError
else:
raise TypeError
self.tag = "Unary" + self.op + str(UnaryNode.cnt)
self.id = self.op + str(UnaryNode.cnt)
self.args = args
UnaryNode.cnt += 1
self.type = "tensor"
self.epilogue_op = getattr(backend, self.op)(element_compute)
# data types
self.element_accumulator = element_accumulator
self.element_compute = element_compute
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = UnaryOp(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
*visitors,
self.epilogue_op,
)
def get_argument(self, visitor_args, kwargs):
epilogue_ops = []
for arg in self.args:
try:
epilogue_ops.append(kwargs[arg])
except KeyError:
epilogue_ops.append(arg) # direct arguments like constant
self.argument = self.epilogue_node.argument_type(
self.epilogue_op.argument_type(*epilogue_ops),
*visitor_args,
)
class BinOpNode:
cnt = 0
# Concept: this is created by the BinOp Node in python ast
def __init__(
self,
element_accumulator,
element_compute,
elements_per_access,
node,
) -> None:
self.op = operators[type(node.op)]
self.tag = "Binary" + self.op + str(BinOpNode.cnt)
self.id = self.op + str(BinOpNode.cnt)
self.args = None
BinOpNode.cnt += 1
self.type = "tensor"
self.epilogue_op = getattr(backend, "Vector" + self.op)(element_compute)
# data types
self.element_accumulator = element_accumulator
self.element_compute = element_compute
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = BinaryOp(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
*visitors,
self.epilogue_op,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
self.epilogue_op.argument_type(self.args),
*visitor_args,
)
class NameNode:
# Concept: this is created by the Name Node in python ast
def __init__(self, node) -> None:
try:
self.id = node.id
except AttributeError:
self.id = node.targets[0].id
self.tag = self.id
class ScalarInputNode(NameNode):
# Concept: scalar
def __init__(self, node) -> None:
super().__init__(node)
self.tag = "Scalar:" + self.tag
self.type = "scalar"
class AccumulatorNode(NameNode):
# Concept: VisitorOpAccumulator
def __init__(
self,
element_accumulator,
elements_per_access,
node,
) -> None:
super().__init__(node)
self.tag = "Accum:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = AccumulatorOp(
self.element_accumulator,
self.elements_per_access,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type()
class TensorInputNode(NameNode):
# Concept: VisitorOpTensorInput
def __init__(self, element_accumulator, node) -> None:
super().__init__(node)
self.tag = "TensorInput:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
def get_epilogue_node(self, *args):
self.epilogue_node = TensorInputOp(self.element_accumulator)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][1],
kwargs["problem_size"][0] * kwargs["problem_size"][1],
)
class RowBroadcastNode(NameNode):
# Concept: VisitorOpRowBroadcast
def __init__(
self,
element_accumulator,
element_fragment,
node,
) -> None:
super().__init__(node)
#
self.tag = "RowBroadcast:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_fragment = element_fragment
def get_epilogue_node(self, *args):
self.epilogue_node = RowBroadcastOp(
self.element_accumulator,
self.element_fragment,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][1],
)
class ColumnBroadcastNode(NameNode):
# Concept: VisitorOpColumnBroadcast
def __init__(
self,
element_accumulator,
element_fragment,
node,
) -> None:
super().__init__(node)
self.tag = "ColumnBroadcast:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_fragment = element_fragment
def get_epilogue_node(self, *args):
self.epilogue_node = ColumnBroadcastOp(
self.element_accumulator,
self.element_fragment,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][0],
)
class TensorOutputNode(NameNode):
# Concept: VisitorOpTensorOutput
def __init__(self, element_accumulator, node) -> None:
super().__init__(node)
self.tag = "TensorOutput:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
def get_epilogue_node(self, visitors):
self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][1],
*visitor_args,
kwargs["problem_size"][0] * kwargs["problem_size"][1],
)
class RowReductionNode:
# Concept: RowReductionOp
def __init__(
self,
element_accumulator,
element_reduction,
element_reduction_accumulator,
id,
factor,
) -> None:
#
self.id = id
self.tag = "RowReduction:" + self.id
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_reduction = element_reduction
self.element_reduction_accumulator = element_reduction_accumulator
self.factor = factor
def get_epilogue_node(self, visitors):
self.epilogue_node = RowReductionOp(
self.element_accumulator,
self.element_reduction,
self.element_reduction_accumulator,
*visitors,
)
def get_batch_stride(self, problem_size):
return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
*visitor_args,
self.get_batch_stride(kwargs["problem_size"]),
)
class ColumnReductionNode:
# Concept: ColumnReductionOp
def __init__(
self,
element_accumulator,
element_reduction,
element_reduction_accumulator,
id,
factor,
) -> None:
#
self.id = id
self.tag = "ColumnReduction:" + self.id
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_reduction = element_reduction
self.element_reduction_accumulator = element_reduction_accumulator
self.factor = factor
def get_epilogue_node(self, visitors):
self.epilogue_node = ColumnReductionOp(
self.element_accumulator,
self.element_reduction,
self.element_reduction_accumulator,
*visitors,
)
def get_batch_stride(self, problem_size):
return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
*visitor_args,
self.get_batch_stride(kwargs["problem_size"]),
)
################################################################################
# Epilogue parser function
################################################################################
class EpilogueAST(ast.NodeVisitor):
def __init__(
self,
epilogue,
tile_description,
element_accumulator,
elements_per_access,
element_compute,
element_output,
) -> None:
#
self.tile_description = tile_description
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
self.element_compute = element_compute
self.element_output = element_output
self.epilogue = epilogue
self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
self.ast_tree = ast.parse(self.source)
self.epilogue_tree = Tree()
# print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose
# input arguments
self.input_args = {}
# return nodes
self.returns = []
# reduction source nodes
self.reduction_source = {}
# stack used to keep the parent node id
self.stack = []
# visit the AST
self.visit(self.ast_tree)
# visit the name node
def visit_Name(self, node):
# append the return ids into self.returns
if self.stack[-1] == "return":
self.returns.append(node.id)
else:
# accum is produced from accumulator node
if node.id == "accum":
name_node = AccumulatorNode(
self.element_accumulator,
self.elements_per_access,
node,
)
else:
# for input nodes
if node.id in self.input_args.keys():
type = self.input_args[node.id][0]
if type == "tensor":
name_node = TensorInputNode(
self.element_accumulator,
node,
)
elif type == "row":
name_node = RowBroadcastNode(
self.element_accumulator,
self.element_compute,
node,
)
elif type == "column":
name_node = ColumnBroadcastNode(
self.element_accumulator,
self.element_compute,
node,
)
elif type == "scalar":
name_node = ScalarInputNode(node)
else:
raise ValueError(type)
# for output nodes
else:
name_node = TensorOutputNode(
self.element_accumulator,
node,
)
self.epilogue_tree.create_node(
name_node.tag,
name_node.id,
data=name_node,
parent=self.stack[-1],
)
def visit_Assign(self, node):
pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
if pre_assign_node is None:
# The assign is to a root node
# skip the reduction nodes
if isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Name):
func_type = node.value.func.id
elif isinstance(node.value.func, ast.Attribute):
func_type = node.value.func.value.id
else:
raise TypeError
if func_type == "reduction_op":
self.reduction_source[node.value.args[0].id] = [
node.value.args[1].value,
node.value.args[2].value,
node.targets[0].id,
]
return
name_node = TensorOutputNode(self.element_accumulator, node)
self.epilogue_tree.create_node(
name_node.tag,
name_node.id,
data=name_node,
)
self.stack.append(name_node.id)
else:
if (
node.targets[0].id in self.returns
or node.targets[0].id in self.reduction_source.keys()
):
self.stack.append(node.targets[0].id)
else:
self.stack.append(
pre_assign_node.predecessor(self.epilogue_tree.identifier)
)
self.epilogue_tree.remove_node(node.targets[0].id)
# get child tag
self.visit(node.value)
self.stack.pop()
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
func_type = node.func.id
elif isinstance(node.func, ast.Attribute):
func_type = node.func.value.id
else:
raise TypeError
if func_type == "reduction_op":
self.visit(node.args[0])
else:
arg_list = []
for idx, arg in enumerate(node.args):
if idx == 0:
continue
if isinstance(arg, ast.Constant):
arg_list.append(arg.value)
elif isinstance(arg, ast.Name):
arg_list.append(arg.id)
else:
raise TypeError
unary_node = UnaryNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node,
arg_list,
)
self.epilogue_tree.create_node(
unary_node.tag,
unary_node.id,
parent=self.stack[-1],
data=unary_node,
)
self.stack.append(unary_node.id)
self.visit(node.args[0])
self.stack.pop()
def visit_BinOp(self, node):
binop = BinOpNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node,
)
self.epilogue_tree.create_node(
binop.tag,
binop.id,
data=binop,
parent=self.stack[-1],
)
self.stack.append(binop.id)
self.visit(node.left)
self.visit(node.right)
self.stack.pop()
def visit_Return(self, node):
self.stack.append("return")
self.visit(node.value)
self.stack.pop()
# A function definition
def visit_FunctionDef(self, node: ast.FunctionDef):
# visit args
for arg in node.args.args:
if arg.arg == "self":
continue
if isinstance(arg.annotation, ast.Constant):
self.input_args[arg.arg] = [
arg.annotation.value,
]
# visit the assign in the reverse order
for idx in range(len(node.body)):
self.visit(node.body[-1 - idx])
#
# Tree optimization pass
#
# pass 1: lower Binary to Unary
def pass_binary_2_unary(self, tree, nid):
node = tree.get_node(nid)
if isinstance(node.data, BinOpNode):
lhs_node = tree.get_node(node.successors(tree.identifier)[0])
left_type = lhs_node.data.type
rhs_node = tree.get_node(node.successors(tree.identifier)[1])
right_type = rhs_node.data.type
if left_type == "scalar" and right_type == "tensor":
node.data = UnaryNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node.data,
[
lhs_node.data.id,
],
)
node.tag = node.data.tag
tree.remove_node(lhs_node.data.id)
self.pass_binary_2_unary(tree, rhs_node.data.id)
elif left_type == "tensor" and right_type == "scalar":
node.data = UnaryNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node.data,
[
rhs_node.data.id,
],
)
node.tag = node.data.tag
tree.remove_node(rhs_node.data.id)
self.pass_binary_2_unary(tree, lhs_node.data.id)
else:
self.pass_binary_2_unary(tree, lhs_node.data.id)
self.pass_binary_2_unary(tree, rhs_node.data.id)
else:
for child in node.successors(tree.identifier):
self.pass_binary_2_unary(tree, child)
# pass 2: inject reduction nodes
def pass_inject_reduction(self, tree, nid):
node = tree.get_node(nid)
if isinstance(node.data, TensorOutputNode):
if node.data.id in self.reduction_source.keys():
direction = self.reduction_source[node.data.id][0]
target = self.reduction_source[node.data.id][-1]
if direction == "row":
reduction_node = RowReductionNode(
self.element_accumulator,
self.element_output,
self.element_accumulator,
target,
self.tile_description.threadblock_shape[1],
)
elif direction == "column":
reduction_node = ColumnReductionNode(
self.element_accumulator,
self.element_output,
self.element_accumulator,
target,
self.tile_description.threadblock_shape[0],
)
else:
raise ValueError(direction)
child_nid = node.successors(tree.identifier)[0]
# if this output node is injected only for reduction
if node.data.id not in self.returns:
# replace this output node's payload with the reduction node
node.data = reduction_node
node.tag = reduction_node.tag
self.pass_inject_reduction(tree, child_nid)
# if this output node is also a tensor output, inject the reduction node as its child
else:
# get child node
tree.create_node(
reduction_node.tag,
reduction_node.id,
data=reduction_node,
parent=node.data.id,
)
tree.move_node(
child_nid,
reduction_node.id,
)
child = tree.get_node(child_nid)
for grand_child in child.successors(tree.identifier):
self.pass_inject_reduction(tree, grand_child)
else:
for child in node.successors(tree.identifier):
self.pass_inject_reduction(tree, child)
else:
for child in node.successors(tree.identifier):
self.pass_inject_reduction(tree, child)
def pass_inject_epilogue_op(self, tree, nid):
node = tree.get_node(nid)
visitors = []
for child in node.successors(tree.identifier):
visitors.append(self.pass_inject_epilogue_op(tree, child))
node.data.get_epilogue_node(visitors)
return node.data.epilogue_node
def get_arguments(self, tree, nid, kwargs):
node = tree.get_node(nid)
visitor_args = []
for child in node.successors(tree.identifier):
visitor_args.append(self.get_arguments(tree, child, kwargs))
node.data.get_argument(visitor_args, kwargs)
return node.data.argument
class EpilogueVisitTree:
KernelTemplate = """
${visitor}
using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""
def __init__(
self,
elementwise_functor,
tile_description,
element_accumulator,
elements_per_access,
element_compute,
element_output,
) -> None:
#
# data types
self.tile_description = tile_description
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
self.element_compute = element_compute
self.element_output = element_output
self.elementwise_functor = elementwise_functor
pass
def initialize(self):
function = EpilogueAST(
self,
self.tile_description,
self.element_accumulator,
self.elements_per_access,
self.element_compute,
self.element_output,
)
#
tree = function.epilogue_tree
self.tree = tree
function.pass_binary_2_unary(self.tree, self.tree.root)
function.pass_inject_reduction(self.tree, self.tree.root)
function.pass_inject_epilogue_op(self.tree, self.tree.root)
visitor = self.tree.get_node(self.tree.root).data.epilogue_node
self.visitor = visitor
class _Argument(ctypes.Structure):
_fields_ = [
(
"visitor_arg",
visitor.argument_type,
)
]
def __init__(self, **kwargs) -> None:
# process input args
_kwargs = {}
for input_key in function.input_args.keys():
if input_key == "accum":
continue
if function.input_args[input_key][0] == "scalar":
continue
# tensor input
else:
setattr(
self,
"buffer_tensor_" + input_key,
NumpyFrontend.argument(
kwargs[input_key],
False,
),
)
setattr(
self,
input_key + "_ptr",
int(
getattr(
self,
"buffer_tensor_" + input_key,
).ptr
),
)
_kwargs[input_key + "_ptr"] = getattr(
self,
input_key + "_ptr",
)
# process the return args
for ret in function.returns:
setattr(
self,
"buffer_tensor_" + ret,
NumpyFrontend.argument(kwargs[ret], True),
)
setattr(
self,
ret + "_ptr",
int(
getattr(
self,
"buffer_tensor_" + ret,
).ptr
),
)
_kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr")
setattr(
self,
"host_tensor_" + ret,
kwargs[ret],
)
_kwargs.update(kwargs)
function.get_arguments(tree, tree.root, _kwargs)
self.visitor_arg = tree.get_node(tree.root).data.argument
def sync(self, stream_sync=True):
if stream_sync:
(err,) = cudart.cudaDeviceSynchronize()
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError("CUDA Error %s" % str(err))
for ret in function.returns:
(err,) = cuda.cuMemcpyDtoH(
getattr(
self,
"host_tensor_" + ret,
),
cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
getattr(
self,
"host_tensor_" + ret,
).size
* getattr(
self,
"host_tensor_" + ret,
).itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
pass
self.epilogue_type = _Argument
def emit(self, operation):
values = {
"visitor": self.visitor.emit(operation),
"operation_name": operation.procedural_name(),
"visitor_name": self.visitor.instance_name,
}
return SubstituteTemplate(self.KernelTemplate, values)
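
`EpilogueAST` above reads the string annotations of the epilogue's arguments, visits the function body in reverse order, and maps each `ast.BinOp` operator to a name through the `operators` table. The sketch below shows only that traversal pattern with the standard `ast` module. `BinOpCollector` and the `SOURCE` string are hypothetical names made up for this illustration; it builds no epilogue tree and does not call any CUTLASS code.

```python
import ast

# Same operator table as in the parser above.
OPERATORS = {ast.Add: "Add", ast.Div: "Div", ast.Eq: "Equal", ast.Mult: "Mult"}


class BinOpCollector(ast.NodeVisitor):
    """Hypothetical visitor: records argument annotations and binary operators."""

    def __init__(self):
        self.annotations = {}
        self.ops = []

    def visit_FunctionDef(self, node):
        # String annotations such as "tensor" appear as ast.Constant nodes.
        for arg in node.args.args:
            if arg.arg == "self":
                continue
            if isinstance(arg.annotation, ast.Constant):
                self.annotations[arg.arg] = arg.annotation.value
        # Visit statements last-to-first, so the return is seen before
        # the assignments that produce its value.
        for stmt in reversed(node.body):
            self.visit(stmt)

    def visit_BinOp(self, node):
        self.ops.append(OPERATORS[type(node.op)])
        self.visit(node.left)
        self.visit(node.right)


# Hypothetical epilogue body, given as a string so the example is self-contained.
SOURCE = '''
def __call__(self, accum: "tensor", c: "tensor", alpha: "scalar"):
    t = alpha * accum
    d = t + c
    return d
'''

collector = BinOpCollector()
collector.visit(ast.parse(SOURCE))
print(collector.annotations)
print(collector.ops)
```

Visiting the body in reverse is what lets the real parser start from the returned names and grow the tree downward toward the inputs.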


@ -0,0 +1,462 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from typing import Union
import ctypes
from cuda import cuda, cudart
import cutlass_bindings
import numpy as np
from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass.backend.library import (
DataTypeNames,
DataTypeSize,
DataTypeTag,
TensorDescription,
)
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
if CheckPackages().check_torch():
import torch
class ReductionOperation:
pass
class ReductionArguments:
"""
Arguments of reduction
"""
def __init__(
self,
operation: ReductionOperation,
problem_size: "list[int]",
partitions: int,
workspace: cuda.CUdeviceptr,
destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
**kwargs,
) -> None:
# tensor_C can be interpreted as the bias with bias=True in keyword args
if "bias" in kwargs.keys():
self.bias = kwargs["bias"]
else:
# by default, tensor_C is not bias
self.bias = False
self.operation = operation
#: pointer to the workspace
self.ptr_workspace = workspace
#: number of split-k partitions
self.partitions = partitions
if isinstance(destination, np.ndarray):
self.host_D = destination
self.destination_buffer = NumpyFrontend.argument(destination, True)
self.source_buffer = NumpyFrontend.argument(source, False)
self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
elif CheckPackages().check_torch() and isinstance(destination, torch.Tensor):
self.ptr_destination = TorchFrontend.argument(destination)
self.ptr_source = TorchFrontend.argument(source)
elif isinstance(destination, cuda.CUdeviceptr):
self.ptr_destination = destination
self.ptr_source = source
else:
raise TypeError("unknown Type")
self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])
self.partition_stride = (
problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
)
if "output_op" in kwargs.keys():
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
# get arguments
self.get_arguments()
@staticmethod
def get_tensor_ref(
extent: "tuple[int]",
device_ptr: cuda.CUdeviceptr,
layout: cutlass_bindings.layout,
):
if layout == cutlass_bindings.RowMajor:
return TensorRef2D_(int(device_ptr), extent[1])
else:
raise ValueError("unknown layout type")
def get_arguments(self):
ref_workspace = ReductionArguments.get_tensor_ref(
extent=[
self.problem_size.row,
self.problem_size.column,
],
device_ptr=self.ptr_workspace,
layout=cutlass_bindings.RowMajor,
)
if self.bias:
ref_source = ReductionArguments.get_tensor_ref(
extent=[0, 0],
device_ptr=self.ptr_source,
layout=cutlass_bindings.RowMajor,
)
else:
ref_source = ReductionArguments.get_tensor_ref(
extent=[
self.problem_size.row,
self.problem_size.column,
],
device_ptr=self.ptr_source,
layout=cutlass_bindings.RowMajor,
)
ref_destination = ReductionArguments.get_tensor_ref(
extent=[
self.problem_size.row,
self.problem_size.column,
],
device_ptr=self.ptr_destination,
layout=cutlass_bindings.RowMajor,
)
self.c_arguments = self.operation.argument_type(
self.problem_size,
self.partitions,
self.partition_stride,
ref_workspace,
ref_destination,
ref_source,
self.output_op,
)
params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
self.host_workspace = bytearray(params_.contents)
def sync(self):
(err,) = cudart.cudaDeviceSynchronize()
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError("CUDA Error %s" % str(err))
if hasattr(self, "host_D"):
(err,) = cuda.cuMemcpyDtoH(
self.host_D,
self.ptr_destination,
self.host_D.size * self.host_D.itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
def free(self):
if hasattr(self, "destination_buffer"):
del self.destination_buffer
if hasattr(self, "source_buffer"):
del self.source_buffer
class ReductionRT(ExecutableOperation):
"""
ReductionRT manages the CUTLASS runtime components for reduction
"""
KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
${operation_name}${operation_suffix} op;
op(params, *shared_storage);
}
"""
HostTemplate = r"""
extern "C" {
// Get the size of params in bytes
int ${operation_name}_get_param_size(){
return sizeof(${operation_name}${operation_suffix}::Params);
}
// Get the size of dynamic shared memory in bytes
int ${operation_name}_shared_memory_size() {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];
return output;
}
}
"""
def __init__(self, operation: ReductionOperation):
super().__init__(operation)
self.operation: ReductionOperation = operation
self.emitter = EmitReductionInstance("_type")
self.elements_per_access = self.operation.count
(
self.argument_type,
self.epilogue_type,
) = get_reduction_params(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type)]
def emit(self):
return self.emitter.emit(self.operation)
def plan(self, arguments: ReductionArguments):
block_shape = [
self.operation.shape.column() // self.elements_per_access,
self.operation.shape.row(),
1,
]
grid_shape = [
(arguments.problem_size.row + self.operation.shape.row() - 1)
// self.operation.shape.row(),
(arguments.problem_size.column + self.operation.shape.column() - 1)
// self.operation.shape.column(),
1,
]
return LaunchConfiguration(
grid_shape,
block_shape,
self.shared_memory_capacity,
)
def initialize(self):
(err,) = cuda.cuFuncSetAttribute(
self.kernel,
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
value=self.shared_memory_capacity,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error: {}".format(err))
class ReductionOperation:
"""
CUTLASS Reduction Operation
shape: shape of CTA
outputop: output operator
"""
def __init__(
self,
shape: cutlass_bindings.MatrixCoord,
C: TensorDescription,
element_accumulator,
element_workspace=None,
element_compute=None,
epilogue_functor=None,
count: int = 1,
partitions_per_stage: int = 4,
) -> None:
"""Constructor"""
self.shape = shape
#: epilogue functor (default: LinearCombination)
self.epilogue_functor = epilogue_functor
#: datatype of accumulator
self.element_accumulator = element_accumulator
if element_workspace is None:
#: datatype of workspace
self.element_workspace = element_accumulator
else:
#: datatype of workspace
self.element_workspace = element_workspace
if element_compute is None:
#: datatype of compute
self.element_compute = element_accumulator
else:
#: datatype of compute
self.element_compute = element_compute
#: datatype of output
self.element_output = C.element
#: operand C
self.C: TensorDescription = C
#: reduce op processing size
self.count: int = count
#: number of partitions to reduce per stage
self.partitions_per_stage: int = partitions_per_stage
self.rt_module: ReductionRT = ReductionRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
#
def extended_name(self):
extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
return SubstituteTemplate(
extend_name,
{
"element_workspace": DataTypeNames[self.element_workspace],
"element_accumulator": DataTypeNames[self.element_accumulator],
"element_compute": DataTypeNames[self.element_compute],
"element_output": DataTypeNames[self.element_output],
},
)
#
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size"""
configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
threadblock = "%dx%d" % (
self.shape.row(),
self.shape.column(),
)
return SubstituteTemplate(
configuration_name,
{
"extended_name": self.extended_name(),
"threadblock": threadblock,
},
)
#
def procedural_name(self):
"""The full procedural name indicates architecture, extended name, tile size"""
return self.configuration_name()
def run(self, arguments: ReductionArguments) -> cuda.CUresult:
"""
Configure and launch the CUDA kernel with input arguments
"""
# get launch configuration
launch_config = self.rt_module.plan(arguments)
# get the host and device workspace
host_workspace = arguments.host_workspace
device_workspace = None
# launch the kernel
err = self.rt_module.run(
host_workspace,
device_workspace,
launch_config,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
return err
class EmitReductionInstance:
def __init__(self, operation_suffix="") -> None:
self.operation_suffix = operation_suffix
self.includes = [
"cutlass/cutlass.h",
"cutlass/numeric_types.h",
"cutlass/arch/arch.h",
"cutlass/arch/mma.h",
"cutlass/layout/matrix.h",
"cutlass/gemm/device/gemm.h",
"cutlass/gemm/device/gemm_universal_adapter.h",
"cutlass/gemm/kernel/default_gemm_universal.h",
"cutlass/reduction/kernel/reduce_split_k.h",
"cutlass/reduction/thread/reduction_operators.h",
]
self.template = """
// Reduction kernel instance
using ${operation_name}_base =
typename cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
${epilogue_functor},
cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_output},
${count}>,
${partition_per_stage}>;
struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""
def emit(self, operation: ReductionOperation):
epilogue_vector_length = int(
min(
operation.C.alignment * DataTypeSize[operation.C.element],
128,
)
/ DataTypeSize[operation.C.element]
)
values = {
"operation_name": operation.configuration_name(),
"operation_suffix": self.operation_suffix,
"shape_row": str(operation.shape.row()),
"shape_column": str(operation.shape.column()),
"epilogue_functor": operation.epilogue_functor.emit(),
"element_output": DataTypeTag[operation.element_output],
"epilogue_vector_length": str(epilogue_vector_length),
"element_accumulator": DataTypeTag[operation.element_accumulator],
"element_compute": DataTypeTag[operation.element_compute],
"element_workspace": DataTypeTag[operation.element_workspace],
"count": str(operation.count),
"partition_per_stage": str(operation.partitions_per_stage),
}
return SubstituteTemplate(self.template, values)
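
The launch configuration computed in `ReductionRT.plan` can be illustrated without any CUDA dependency. The following is a minimal sketch, not the library's code: `ceil_div` and `plan_launch` are hypothetical helper names, and the tile shape (4 x 256 with 8 elements per access) is an assumed example that matches `MatrixCoord(4, 32 * alignment)` for an alignment of 8.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return (a + b - 1) // b


def plan_launch(problem_rows, problem_cols, tile_rows, tile_cols, elements_per_access):
    """Hypothetical stand-in for the grid/block computation in ReductionRT.plan."""
    # Each thread covers `elements_per_access` contiguous columns of the tile.
    block_shape = [tile_cols // elements_per_access, tile_rows, 1]
    # Launch enough tiles to cover the whole rows x columns problem.
    grid_shape = [
        ceil_div(problem_rows, tile_rows),
        ceil_div(problem_cols, tile_cols),
        1,
    ]
    return grid_shape, block_shape


grid, block = plan_launch(
    problem_rows=187,
    problem_cols=160,
    tile_rows=4,
    tile_cols=256,
    elements_per_access=8,
)
print("grid:", grid, "block:", block)
```

The ceiling division matters when the problem extent is not a multiple of the tile extent: the last tile in each dimension is only partially filled, so the kernel itself is expected to guard out-of-range accesses.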

@ -0,0 +1,69 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from cuda import cuda
import cutlass_bindings
import numpy as np
from cutlass.backend.utils.software import CheckPackages
cupy_available = CheckPackages().check_cupy()
if cupy_available:
import cupy as cp
torch_available = CheckPackages().check_torch()
if torch_available:
import torch
class TensorRef:
"""
Python Wrapper for cutlass_bindings.TensorRef
"""
def __init__(self, tensor, dtype, layout) -> None:
if isinstance(tensor, np.ndarray):
ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
elif torch_available and isinstance(tensor, torch.Tensor):
ptr = cuda.CUdeviceptr(tensor.data_ptr())
elif cupy_available and isinstance(tensor, cp.ndarray):
ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
elif isinstance(tensor, cuda.CUdeviceptr):
ptr = tensor
elif isinstance(tensor, int):
ptr = cuda.CUdeviceptr(tensor)
else:
raise NotImplementedError(tensor)
# the dtype(0) is used to overload between different data types
# with the same layout
self.tensor_ref = cutlass_bindings.get_tensor_ref(int(ptr), dtype(0), layout)
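
`TensorRef.__init__` only touches an optional package inside a branch guarded by that package's availability flag. The sketch below shows one way such a guard can work using only the standard library. It is an assumption about the pattern, not the actual `CheckPackages` implementation, and `package_available`, `classify`, and the package name `no_such_package_for_this_sketch` are hypothetical.

```python
import importlib.util


def package_available(name: str) -> bool:
    """Return True if a top-level package can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


# Hypothetical stand-ins for the torch_available / cupy_available flags.
json_available = package_available("json")
ghost_available = package_available("no_such_package_for_this_sketch")


def classify(tensor) -> str:
    # Each optional branch is guarded by the flag of the package it references.
    # Short-circuit evaluation keeps the missing package's name from ever
    # being looked up.
    if json_available and isinstance(tensor, (dict, list)):
        return "json-like"
    if ghost_available and isinstance(tensor, no_such_package_for_this_sketch.Thing):
        return "ghost"
    if isinstance(tensor, int):
        return "raw address"
    raise NotImplementedError(type(tensor).__name__)


print(classify(0x1000))
print(classify([1, 2]))
```

Pairing each `isinstance` check with the flag for the same package is what keeps the dispatch safe: guarding a CuPy check with the PyTorch flag would raise a `NameError` on a machine that has PyTorch but not CuPy.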

@ -0,0 +1,36 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from cutlass.backend.test.conv2d_testbed import *
from cutlass.backend.test.gemm_grouped_testbed import *
from cutlass.backend.test.gemm_testbed import *
from cutlass.backend.test.profiler import *

@ -0,0 +1,783 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import os
import re
import subprocess
from time import sleep
from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np
from cutlass.backend.compiler import ArtifactManager
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.backend.test.profiler import GpuTimer
from cutlass.backend.utils.software import SubstituteTemplate
def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand):
ptr = tensor.__array_interface__["data"][0]
if operand == "a":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
conv_kind, problem_size
)
elif operand == "b":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
conv_kind, problem_size
)
elif operand in ["c", "d"]:
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
)
else:
raise ValueError("unknown operand: " + operand)
layout = tensor_layout.packed(tensor_coord)
if tensor.dtype == np.float64:
return cutlass_bindings.TensorRefF64NHWC(ptr, layout)
elif tensor.dtype == np.float32:
return cutlass_bindings.TensorRefF32NHWC(ptr, layout)
elif tensor.dtype == np.float16:
return cutlass_bindings.TensorRefF16NHWC(ptr, layout)
elif tensor.dtype == bfloat16:
return cutlass_bindings.TensorRefBF16NHWC(ptr, layout)
elif tensor.dtype == np.int32:
return cutlass_bindings.TensorRefS32NHWC(ptr, layout)
elif tensor.dtype == np.int8:
if tensor_layout == cutlass_bindings.TensorNC32HW32:
return cutlass_bindings.TensorRefS8NC32HW32(ptr, layout)
elif tensor_layout == cutlass_bindings.TensorC32RSK32:
return cutlass_bindings.TensorRefS8C32RSK32(ptr, layout)
else:
return cutlass_bindings.TensorRefS8NHWC(ptr, layout)
else:
raise ValueError("unsupported data type")
def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand):
tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand)
if operand == "a":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
conv_kind, problem_size
)
elif operand == "b":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
conv_kind, problem_size
)
elif operand in ["c", "d"]:
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
)
else:
raise ValueError("unknown operand: " + operand)
if tensor.dtype == np.float64:
return cutlass_bindings.TensorViewF64NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.float32:
return cutlass_bindings.TensorViewF32NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.float16:
return cutlass_bindings.TensorViewF16NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == bfloat16:
return cutlass_bindings.TensorViewBF16NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.int32:
return cutlass_bindings.TensorViewS32NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.int8:
if tensor_layout == cutlass_bindings.TensorNC32HW32:
return cutlass_bindings.TensorViewS8NC32HW32(tensor_ref, tensor_coord)
elif tensor_layout == cutlass_bindings.TensorC32RSK32:
return cutlass_bindings.TensorViewS8C32RSK32(tensor_ref, tensor_coord)
else:
return cutlass_bindings.TensorViewS8NHWC(tensor_ref, tensor_coord)
else:
raise ValueError("unsupported data type")
# @typechecked
class Conv2dLauncher:
"""
Launcher that runs the operation on given problem size
"""
def __init__(
self,
operation: "Conv2dOperation",
seed: int = 2080,
interleaved=False,
verification=True,
profiling=False,
warmup_iterations=500,
iterations=500,
**kwargs,
) -> None:
self.enable_cached_results = True
self.interleaved = interleaved
# create the reduction kernel
self.reduction_operation = ReductionOperation(
shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
#: verify the output result
self.verification = verification
#: profile the kernel's runtime
self.profiling = profiling
self.timer = GpuTimer()
self.warmup_iterations = warmup_iterations
self.iterations = iterations
if "sleep" in kwargs.keys():
self.sleep_time = kwargs["sleep"]
else:
self.sleep_time = 0
#
# Compile the operator
#
ArtifactManager().add_module([operation, self.reduction_operation])
self.operation = operation
self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element)
self.layout_A = operation.A.layout
self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element)
self.layout_B = operation.B.layout
self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element)
self.layout_C = operation.C.layout
self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element)
self.layout_D = operation.C.layout
accumulator_size = DataTypeSize[
operation.tile_description.math_instruction.element_accumulator
]
element_size = DataTypeSize[operation.A.element]
if element_size <= 8:
self.scope = 1
elif element_size == 16:
if accumulator_size <= 16:
self.scope = 2
else:
self.scope = 4
else:
self.scope = 7
# Seed
self.seed = seed
self.conv_kind = operation.conv_kind
#
# Get the host reference function
#
self.element_compute = operation.epilogue_functor.element_epilogue
self.host_conv2d = cutlass_bindings.test.conv.host.conv2d
@staticmethod
def numpy_type(type):
if type == cutlass_bindings.float64:
return np.float64
elif type == cutlass_bindings.float32:
return np.float32
elif type == cutlass_bindings.float16:
return np.float16
elif type == cutlass_bindings.bfloat16:
return bfloat16
elif type == cutlass_bindings.int32:
return np.int32
elif type == cutlass_bindings.int8:
return np.int8
else:
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
def print_problem_size(self, p, split_k_mode=1):
print(
"nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d"
% (
p.N,
p.H,
p.W,
p.C,
p.K,
p.R,
p.S,
p.C,
p.pad_h,
p.pad_w,
p.stride_h,
p.stride_w,
p.dilation_h,
p.dilation_w,
p.split_k_slices,
split_k_mode,
)
)
def uniform_init(self, size, dtype):
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=-self.scope - 0.5, high=self.scope - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=-self.scope - 1, high=self.scope + 1, size=size
).astype(dtype)
def eq_gemm_size(self, problem_size):
n = problem_size.N
p = problem_size.P
q = problem_size.Q
k = problem_size.K
r = problem_size.R
s = problem_size.S
c = problem_size.C
h = problem_size.H
w = problem_size.W
if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
return cutlass_bindings.gemm.GemmCoord(n * p * q, k, r * s * c)
elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
return cutlass_bindings.gemm.GemmCoord(n * h * w, c, k * r * s)
else:
return cutlass_bindings.gemm.GemmCoord(k, r * s * c, n * p * q)
def bytes(self, problem_size, alpha, beta):
mnk = self.eq_gemm_size(problem_size)
bytes_ = (
(DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k()
+ (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k()
+ (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
)
if beta != 0:
bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
return bytes_
def flops(self, problem_size):
mnk = self.eq_gemm_size(problem_size)
flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2
flops_epilogue_ = mnk.m() * mnk.n() * 2
# Adjust mainloop flop for dgrad stride
if self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
flops_mainloop_ = flops_mainloop_ // (
problem_size.stride_h * problem_size.stride_w
)
flops_total_ = flops_mainloop_ + flops_epilogue_
return flops_total_
def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
if self.element_compute == cutlass_bindings.float16:
alpha = cutlass_bindings.float16(alpha)
beta = cutlass_bindings.float16(beta)
elif self.element_compute == cutlass_bindings.int32:
alpha = int(alpha)
beta = int(beta)
else:
alpha = alpha
beta = beta
# if cached result is loaded
cached_result_loaded = False
if self.enable_cached_results:
# get problem key
cached_test_key = cutlass_bindings.test.conv.host.CreateCachedConv2dTestKey(
self.conv_kind,
problem_size,
alpha,
beta,
getTensorView(
tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
),
getTensorView(
tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
),
getTensorView(
tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
),
)
cached_test_result = cutlass_bindings.test.conv.host.CachedTestResult()
conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (
self.operation.arch,
self.seed,
)
cached_results = cutlass_bindings.test.conv.host.CachedTestResultListing(
conv2d_result_cache_name
)
# CachedTestResultListing cached_results(conv2d_result_cache_name);
cached = cached_results.find(cached_test_key)
cached_result_loaded = cached[0]
if cached_result_loaded:
cached_test_result = cached[1]
if not cached_result_loaded:
# compute the conv2d on host
tensor_D_ref = np.ones_like(tensor_C)
tensor_ref_A = getTensorRef(
tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
)
tensor_ref_B = getTensorRef(
tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
)
tensor_ref_C = getTensorRef(
tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
)
tensor_ref_D_ref = getTensorRef(
tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
)
self.host_conv2d(
self.conv_kind,
problem_size,
tensor_ref_A,
tensor_ref_B,
tensor_ref_C,
tensor_ref_D_ref,
alpha,
beta,
)
tensor_view_D_ref = getTensorView(
tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
)
if self.enable_cached_results:
cached_test_result.D = cutlass_bindings.test.conv.host.TensorHash(
tensor_view_D_ref
)
cached_results = (
cutlass_bindings.test.conv.host.CachedTestResultListing(
conv2d_result_cache_name
)
)
cached_results.append(cached_test_key, cached_test_result)
cached_results.write(conv2d_result_cache_name)
else:
return tensor_D_ref
return cached_test_result.D
def equal(self, tensor_D, tensor_D_ref, problem_size):
if self.enable_cached_results:
tensor_view_D = getTensorView(
tensor_D, self.layout_D, self.conv_kind, problem_size, "d"
)
tensor_D_hash = cutlass_bindings.test.conv.host.TensorHash(tensor_view_D)
return tensor_D_hash == tensor_D_ref
else:
tensor_view_D = getTensorView(
tensor_D, self.layout_D, self.conv_kind, problem_size, "d"
)
tensor_view_D_ref = getTensorView(
tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
)
return cutlass_bindings.test.conv.host.equals(
tensor_view_D, tensor_view_D_ref
)
def run_cutlass_profiler(
self,
problem_size,
split_k_mode=cutlass_bindings.conv.SplitKMode.Serial,
alpha=1.0,
beta=0.0,
):
if split_k_mode == cutlass_bindings.conv.SplitKMode.Serial:
split_k_mode_ = "serial"
else:
split_k_mode_ = "parallel"
cutlass_path = os.getenv("CUTLASS_PATH")
assert (
cutlass_path is not None
), "Environment variable 'CUTLASS_PATH' is not defined."
values = {
"profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
"kernel_name": self.operation.procedural_name(),
"verification_providers": "device",
"provider": "cutlass",
"n": str(problem_size.N),
"h": str(problem_size.H),
"w": str(problem_size.W),
"c": str(problem_size.C),
"k": str(problem_size.K),
"r": str(problem_size.R),
"s": str(problem_size.S),
"p": str(problem_size.P),
"q": str(problem_size.Q),
"pad_h": str(problem_size.pad_h),
"pad_w": str(problem_size.pad_w),
"stride_h": str(problem_size.stride_h),
"stride_w": str(problem_size.stride_w),
"dilation_h": str(problem_size.dilation_h),
"dilation_w": str(problem_size.dilation_w),
"split_k_slices": str(problem_size.split_k_slices),
"split_k_mode": split_k_mode_,
"alpha": str(alpha),
"beta": str(beta),
"warmup": str(self.warmup_iterations),
"profile": str(self.iterations),
}
cmd_template = (
"${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}"
" --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}"
" --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h=${stride_h} --stride_w=${stride_w}"
" --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}"
" --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}"
)
cmd = SubstituteTemplate(cmd_template, values)
result = subprocess.getoutput(cmd)
m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result)
runtime = float(m.group("runtime"))
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
bytes = int(m.group("bytes"))
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
flops = int(m.group("flops"))
# check if the problem size matches
assert bytes == self.bytes(problem_size, alpha, beta)
assert flops == self.flops(problem_size)
return runtime
def run(
self,
problem_size,
split_k_mode=cutlass_bindings.conv.SplitKMode.Serial,
alpha=1.0,
beta=0.0,
):
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released in previous run"
% get_allocated_size()
)
#
# Initialize input and output tensors
#
tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(
self.conv_kind, problem_size
)
tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(
self.conv_kind, problem_size
)
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
self.conv_kind, problem_size
)
np.random.seed(self.seed)
tensor_A = self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A)
tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B)
tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C)
tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D)
#
# Launch kernel
#
arguments = Conv2dArguments(
operation=self.operation,
problem_size=problem_size,
A=tensor_A,
B=tensor_B,
C=tensor_C,
D=tensor_D,
output_op=self.operation.epilogue_type(alpha, beta),
split_k_slices=problem_size.split_k_slices,
split_k_mode=split_k_mode,
)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(
self.operation.conv_kind, arguments.problem_size
)
reduction_arguments = ReductionArguments(
self.reduction_operation,
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
partitions=problem_size.split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op=self.reduction_operation.epilogue_type(alpha, beta),
)
self.operation.run(arguments)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
self.reduction_operation.run(reduction_arguments)
passed = True
if self.verification:
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
reduction_arguments.sync()
else:
arguments.sync()
tensor_D_ref = self.host_reference(
problem_size, tensor_A, tensor_B, tensor_C, alpha, beta
)
passed = self.equal(tensor_D, tensor_D_ref, problem_size)
try:
assert passed
except AssertionError:
self.print_problem_size(problem_size, split_k_mode)
if self.profiling:
sleep(self.sleep_time)
for _ in range(self.warmup_iterations):
self.operation.run(arguments)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
self.reduction_operation.run(reduction_arguments)
self.timer.start()
for _ in range(self.iterations):
self.operation.run(arguments)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
self.reduction_operation.run(reduction_arguments)
self.timer.stop_and_wait()
runtime = self.timer.duration(self.iterations)
# free memory
del arguments
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
del reduction_arguments
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released after current run"
% get_allocated_size()
)
if self.profiling:
return runtime
return passed
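
The parsing step in `run_cutlass_profiler` assumes every field is present in the profiler output and fails with an `AttributeError` on `None` if one is missing. The sketch below shows a more defensive variant of the same regex-based parsing. The sample text is invented for illustration and is not real `cutlass_profiler` output, and `parse_profiler_output` is a hypothetical helper name.

```python
import re

# Invented sample text; the real profiler output layout may differ.
SAMPLE_OUTPUT = """
        Runtime: 0.0123  ms
          Bytes: 1048576  bytes
          FLOPs: 2097152  flops
"""

# Field name -> (pattern with one capture group, conversion function).
FIELDS = {
    "runtime": (r"Runtime:\s+(\d+\.\d+)", float),
    "bytes": (r"Bytes:\s+(\d+)", int),
    "flops": (r"FLOPs:\s+(\d+)", int),
}


def parse_profiler_output(text: str) -> dict:
    """Extract runtime, bytes, and FLOPs, raising if any field is missing."""
    parsed = {}
    for name, (pattern, convert) in FIELDS.items():
        match = re.search(pattern, text)
        if match is None:
            raise ValueError("field %r not found in profiler output" % name)
        parsed[name] = convert(match.group(1))
    return parsed


result = parse_profiler_output(SAMPLE_OUTPUT)
print(result)
```

Raising a `ValueError` that names the missing field makes a failed profiler invocation (for example, a wrong binary path) easier to diagnose than a failure inside `m.group(...)`.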
########################################################################################################
# TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
# TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
# Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and a blacklist of sizes
# (conv_blacklist_sizes)
############################################################################################################
def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=False):
passed = True
#
# Testbed object
#
testbed = Conv2dLauncher(operation, interleaved=interleaved)
#
# Get conv problem sizes to run conv operator
#
conv_problems = cutlass_bindings.test.conv.TestbedConv2dProblemSizes(64)
# Vector of conv2d problem sizes to avoid duplicate runs
conv_tested_sizes = []
# Flatten 2D problem_vectors into a 1D problem sizes
problem_sizes = conv_problems.conv2d_default_sizes
problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes
# Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0)
for conv_problem in problem_sizes:
if conv_problem in conv_tested_sizes:
continue
# skip channel dimension % 32 != 0 for interleaved case
if interleaved:
if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0:
continue
#
# Procedurally disable certain cases
#
# CUTLASS DGRAD's *unity* stride specialization only supports stride {1, 1}
if (
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
and operation.stride_support == StrideSupport.Unity
):
if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)):
continue
if not interleaved:
# Fixed channels algorithm requires channel count to match access size
if (
operation.iterator_algorithm
== cutlass_bindings.conv.IteratorAlgorithm.fixed_channels
):
if conv_problem.C != operation.A.alignment:
continue
# Few channels algorithm requires channel count to be a multiple of the access size
if (
operation.iterator_algorithm
== cutlass_bindings.conv.IteratorAlgorithm.few_channels
):
if conv_problem.C % operation.A.alignment:
continue
# CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
# Although strided dgrad works for all stride combinations, we are only going
# to run strided dgrad for non-unity strides
if (
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
and operation.stride_support == StrideSupport.Strided
):
if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1):
continue
#
# Test
#
# push back tested problem size to avoid re-running duplicates
conv_tested_sizes.append(conv_problem)
passed = testbed.run(conv_problem)
if not passed:
return False
if interleaved:
return True
#
# filter the cases for split K
#
# Small-channels convolution can't run here.
if operation.iterator_algorithm in [
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels,
cutlass_bindings.conv.IteratorAlgorithm.few_channels,
]:
return True
# CUTLASS DGRAD's *strided* specialization does not support split-k mode
if (
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
and operation.stride_support == StrideSupport.Strided
):
conv_problem = cutlass_bindings.conv.Conv2dProblemSize(
cutlass_bindings.Tensor4DCoord(1, 56, 56, 8),
cutlass_bindings.Tensor4DCoord(8, 1, 1, 8),
cutlass_bindings.Tensor4DCoord(0, 0, 0, 0),
cutlass_bindings.MatrixCoord(2, 2),
cutlass_bindings.MatrixCoord(1, 1),
cutlass_bindings.conv.Mode.cross_correlation,
1,
1,
)
passed = testbed.run(conv_problem)
return passed
# Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for
# a single conv2d problem size. Convolution unit tests take a long time to run, so only sweep parameters
# that are absolutely necessary to catch functional bugs. The code below provides the option to sweep
# alpha and beta for local testing, but only runs one value each for alpha and beta.
conv2d_split_k_test_size = cutlass_bindings.conv.Conv2dProblemSize(
cutlass_bindings.Tensor4DCoord(1, 17, 11, 288),
cutlass_bindings.Tensor4DCoord(160, 3, 3, 288),
cutlass_bindings.Tensor4DCoord(1, 1, 1, 1),
cutlass_bindings.MatrixCoord(1, 1),
cutlass_bindings.MatrixCoord(1, 1),
cutlass_bindings.conv.Mode.cross_correlation,
1,
1,
)
split_k_modes = [
cutlass_bindings.conv.SplitKMode.Parallel,
cutlass_bindings.conv.SplitKMode.Serial,
]
split_k_slices = [1, 2, 3, 4, 201]
problem_alpha = [
2.0,
]
problem_beta = [
2.0,
]
for split_k_mode in split_k_modes:
for split_k_slice in split_k_slices:
for alpha in problem_alpha:
for beta in problem_beta:
passed = testbed.run(
conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
split_k_mode,
alpha,
beta,
)
if not passed:
return False
return passed
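The DGRAD stride filters applied in the sweep above can be summarized as a single predicate. This is only an illustrative sketch: `should_run_dgrad` is a hypothetical helper, and the strings `"unity"` and `"strided"` are stand-ins for `StrideSupport.Unity` and `StrideSupport.Strided`, not part of the testbed.

```python
def should_run_dgrad(stride_support: str, stride_h: int, stride_w: int) -> bool:
    """Return True if a dgrad problem with these strides should be tested.

    Hypothetical helper: "unity" and "strided" stand in for
    StrideSupport.Unity and StrideSupport.Strided.
    """
    unit_stride = stride_h == 1 and stride_w == 1
    if stride_support == "unity":
        # The unity specialization only supports stride {1, 1}.
        return unit_stride
    if stride_support == "strided":
        # Strided dgrad works for every stride combination, but the sweep
        # only exercises it on non-unity strides.
        return not unit_stride
    raise ValueError("unknown stride support: " + stride_support)


print(should_run_dgrad("unity", 1, 1))    # True
print(should_run_dgrad("strided", 1, 1))  # False
print(should_run_dgrad("strided", 2, 2))  # True
```

Taken together, the two filters mean each stride configuration is covered by exactly one of the two specializations.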

@ -0,0 +1,276 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmGroupedArguments, GemmOperationGrouped
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.test.gemm_testbed import getTensorRef, getTensorView, transpose
class TestbedGrouped:
def __init__(self, operation: GemmOperationGrouped, seed: int = 2080) -> None:
compiler.add_module([operation])
self.seed = seed
self.operation = operation
element_size = DataTypeSize[operation.A.element]
self.dtype_A = self.numpy_type(operation.A.element)
self.dtype_B = self.numpy_type(operation.B.element)
self.dtype_C = self.numpy_type(operation.C.element)
self.dtype_D = self.numpy_type(operation.C.element)
if element_size == 1:
self.scope_max = 1
self.scope_min = 0
elif element_size <= 8:
self.scope_max = 1
self.scope_min = -1
elif element_size == 16:
self.scope_max = 4
self.scope_min = -4
else:
self.scope_max = 8
self.scope_min = -8
#: compute type
self.compute_type = operation.epilogue_functor.element_epilogue
self.accumulator_type = (
operation.tile_description.math_instruction.element_accumulator
)
@staticmethod
def numpy_type(type):
if type == cutlass_bindings.float64:
return np.float64
elif type == cutlass_bindings.float32:
return np.float32
elif type == cutlass_bindings.float16:
return np.float16
elif type == cutlass_bindings.bfloat16:
return bfloat16
elif type == cutlass_bindings.int32:
return np.int32
elif type == cutlass_bindings.int8:
return np.int8
else:
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
def uniform_init(self, size, dtype):
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=self.scope_min - 1, high=self.scope_max + 1, size=size
).astype(dtype)
def print_problem_size(self, p):
problem_size = "problem: %d, %d, %d\n" % (p.m(), p.n(), p.k())
print(problem_size)
def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool:
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released in previous run"
% get_allocated_size()
)
# initialize
passed = False
np.random.seed(self.seed)
# generate the problem sizes
problem_sizes = []
tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
tensor_D_refs = []
for i in range(problem_count):
if self.dtype_A == np.int8:
if i == 0:
problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 32)
else:
problem_size = cutlass_bindings.gemm.GemmCoord(
16 * np.random.randint(0, 64) + 48,
16 * np.random.randint(0, 64) + 48,
16 * np.random.randint(0, 64) + 48,
)
else:
if i == 0:
problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 8)
else:
problem_size = cutlass_bindings.gemm.GemmCoord(
8 * np.random.randint(0, 64) + 24,
8 * np.random.randint(0, 64) + 24,
8 * np.random.randint(0, 64) + 24,
)
tensor_As.append(
self.uniform_init(
size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A
)
)
tensor_Bs.append(
self.uniform_init(
size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B
)
)
tensor_Cs.append(
self.uniform_init(
size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C
)
)
tensor_Ds.append(
np.zeros(
shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D
)
)
tensor_D_refs.append(
np.ones(
shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D
)
)
problem_sizes.append(problem_size)
arguments = GemmGroupedArguments(
operation=self.operation,
problem_sizes=problem_sizes,
A=tensor_As,
B=tensor_Bs,
C=tensor_Cs,
D=tensor_Ds,
output_op=self.operation.epilogue_type(alpha, beta),
)
self.operation.run(arguments)
arguments.sync()
#
# Reference check
#
alpha = self.compute_type(alpha).value()
beta = self.compute_type(beta).value()
init_acc = self.accumulator_type(0).value()
for idx, problem_size in enumerate(problem_sizes):
if self.operation.switched:
tensor_ref_A = getTensorRef(
tensor_As[idx],
problem_size,
"a",
transpose(self.operation.B.layout),
)
tensor_ref_B = getTensorRef(
tensor_Bs[idx],
problem_size,
"b",
transpose(self.operation.A.layout),
)
tensor_ref_C = getTensorRef(
tensor_Cs[idx],
problem_size,
"c",
transpose(self.operation.C.layout),
)
tensor_ref_D_ref = getTensorRef(
tensor_D_refs[idx],
problem_size,
"d",
transpose(self.operation.C.layout),
)
else:
tensor_ref_A = getTensorRef(
tensor_As[idx], problem_size, "a", self.operation.A.layout
)
tensor_ref_B = getTensorRef(
tensor_Bs[idx], problem_size, "b", self.operation.B.layout
)
tensor_ref_C = getTensorRef(
tensor_Cs[idx], problem_size, "c", self.operation.C.layout
)
tensor_ref_D_ref = getTensorRef(
tensor_D_refs[idx], problem_size, "d", self.operation.C.layout
)
tensor_view_D_ref = getTensorView(
tensor_D_refs[idx], problem_size, "d", self.operation.C.layout
)
cutlass_bindings.test.gemm.host.gemm(
problem_size,
alpha,
tensor_ref_A,
tensor_ref_B,
beta,
tensor_ref_C,
tensor_ref_D_ref,
init_acc,
)
tensor_view_D = getTensorView(
tensor_Ds[idx], problem_size, "d", self.operation.C.layout
)
passed = cutlass_bindings.test.gemm.host.equals(
tensor_view_D, tensor_view_D_ref
)
if not passed:
self.print_problem_size(problem_size)
del arguments
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released after current run"
% get_allocated_size()
)
return passed
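The floating-point branch of `uniform_init` above draws from `[scope_min - 0.5, scope_max - 0.5]` and rounds up, so every element holds an integer value between `scope_min` and `scope_max`. Presumably this keeps the reference comparison exact, though the source does not say so. The sketch below reproduces the idea with the standard library only; `uniform_int_valued` is a hypothetical name, and the real method uses NumPy.

```python
import math
import random


def uniform_int_valued(scope_min: int, scope_max: int, count: int, seed: int = 2080):
    """Draw floats that hold integer values in [scope_min, scope_max].

    Hypothetical pure-Python analogue of the testbed's NumPy-based
    initializer; not part of the testbed itself.
    """
    rng = random.Random(seed)
    # Sampling from [scope_min - 0.5, scope_max - 0.5] and rounding up maps
    # each draw onto an integer between scope_min and scope_max.
    return [
        float(math.ceil(rng.uniform(scope_min - 0.5, scope_max - 0.5)))
        for _ in range(count)
    ]


values = uniform_int_valued(-4, 4, 1000)
print(min(values) >= -4, max(values) <= 4)       # True True
print(all(v == int(v) for v in values))          # True
```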

@ -0,0 +1,758 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import os
import re
import subprocess
from time import sleep
from bfloat16 import bfloat16
from cuda import cuda, cudart
import cutlass_bindings
import numpy as np
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import (
DataTypeSize,
DataTypeSizeBytes,
MathOperation,
ShortDataTypeNames,
)
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.backend.test.profiler import GpuTimer
from cutlass.backend.utils.datatypes import to_cutlass
from cutlass.backend.utils.software import SubstituteTemplate
def transpose(layout):
if layout == cutlass_bindings.RowMajor:
return cutlass_bindings.ColumnMajor
elif layout == cutlass_bindings.ColumnMajor:
return cutlass_bindings.RowMajor
elif layout == cutlass_bindings.ColumnMajorInterleaved32:
return cutlass_bindings.RowMajorInterleaved32
elif layout == cutlass_bindings.RowMajorInterleaved32:
return cutlass_bindings.ColumnMajorInterleaved32
else:
raise ValueError("unsupported layout")
def getTensorRef(
tensor: np.ndarray,
problem_size: cutlass_bindings.gemm.GemmCoord,
operand: str,
layout: cutlass_bindings.layout,
batch_offset: int = 0,
):
ptr = tensor.__array_interface__["data"][0]
if operand == "a":
tensor_coord = problem_size.mk()
batch_stride = problem_size.m() * problem_size.k()
elif operand == "b":
tensor_coord = problem_size.kn()
batch_stride = problem_size.k() * problem_size.n()
elif operand in ["c", "d"]:
tensor_coord = problem_size.mn()
batch_stride = problem_size.m() * problem_size.n()
else:
raise ValueError("Unknown operand: " + operand)
elt_size = DataTypeSizeBytes[to_cutlass(tensor.dtype)]
ptr += batch_offset * batch_stride * elt_size
if layout == cutlass_bindings.RowMajor:
layout = cutlass_bindings.RowMajor.packed(tensor_coord)
layout_tag = "RowMajor"
elif layout == cutlass_bindings.ColumnMajor:
layout = cutlass_bindings.ColumnMajor.packed(tensor_coord)
layout_tag = "ColumnMajor"
elif layout == cutlass_bindings.ColumnMajorInterleaved32:
layout = cutlass_bindings.ColumnMajorInterleaved32.packed(tensor_coord)
layout_tag = "ColumnMajorInterleaved32"
elif layout == cutlass_bindings.RowMajorInterleaved32:
layout = cutlass_bindings.RowMajorInterleaved32.packed(tensor_coord)
layout_tag = "RowMajorInterleaved32"
else:
raise ValueError("unsupported layout")
if tensor.dtype == np.float32:
ref_name = "TensorRefF32" + layout_tag
elif tensor.dtype == np.float64:
ref_name = "TensorRefF64" + layout_tag
elif tensor.dtype == np.float16:
ref_name = "TensorRefF16" + layout_tag
elif tensor.dtype == bfloat16:
ref_name = "TensorRefBF16" + layout_tag
elif tensor.dtype == np.int8:
ref_name = "TensorRefS8" + layout_tag
elif tensor.dtype == np.int32:
ref_name = "TensorRefS32" + layout_tag
else:
raise ValueError("unsupported datatype %s" % tensor.dtype)
return getattr(cutlass_bindings, ref_name)(ptr, layout)
def getTensorView(
tensor: np.ndarray,
problem_size: cutlass_bindings.gemm.GemmCoord,
operand: str,
layout: cutlass_bindings.layout,
batch_offset: int = 0,
):
tensor_ref = getTensorRef(tensor, problem_size, operand, layout, batch_offset)
if operand == "a":
tensor_coord = problem_size.mk()
elif operand == "b":
tensor_coord = problem_size.kn()
elif operand in ["c", "d"]:
tensor_coord = problem_size.mn()
else:
raise ValueError("Unknown operand: " + operand)
if layout == cutlass_bindings.RowMajor:
layout_tag = "RowMajor"
elif layout == cutlass_bindings.ColumnMajor:
layout_tag = "ColumnMajor"
elif layout == cutlass_bindings.ColumnMajorInterleaved32:
layout_tag = "ColumnMajorInterleaved32"
elif layout == cutlass_bindings.RowMajorInterleaved32:
layout_tag = "RowMajorInterleaved32"
else:
raise ValueError("unsupported layout")
if tensor.dtype == np.float32:
ref_name = "TensorViewF32" + layout_tag
elif tensor.dtype == np.float64:
ref_name = "TensorViewF64" + layout_tag
elif tensor.dtype == np.float16:
ref_name = "TensorViewF16" + layout_tag
elif tensor.dtype == bfloat16:
ref_name = "TensorViewBF16" + layout_tag
elif tensor.dtype == np.int32:
ref_name = "TensorViewS32" + layout_tag
elif tensor.dtype == np.int8:
ref_name = "TensorViewS8" + layout_tag
else:
raise ValueError("unsupported datatype")
return getattr(cutlass_bindings, ref_name)(tensor_ref, tensor_coord)
class GemmUniversalLauncher:
def __init__(
self,
operation: "GemmOperationUniversal",
seed: int = 2080,
interleaved=False,
verification=True,
profiling=False,
warmup_iterations=500,
iterations=500,
**kwargs,
) -> None:
# create the reduction kernel
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
self.math_operation = operation.tile_description.math_instruction.math_operation
#: verify the output result
self.verification = verification
#: profile the kernel's runtime
self.profiling = profiling
self.timer = GpuTimer()
self.warmup_iterations = warmup_iterations
self.iterations = iterations
self.sleep_time = kwargs.get("sleep", 0)
#
# Compile the operator
#
op_list = [operation]
if operation.arch < 90:
# Split K via Python is currently only supported for pre-SM90 kernels
op_list.append(self.reduction_operation)
compiler.add_module(op_list)
self.operation = operation
self.dtype_A = GemmUniversalLauncher.numpy_type(operation.A.element)
self.dtype_B = GemmUniversalLauncher.numpy_type(operation.B.element)
self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element)
self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element)
accumulator_size = DataTypeSize[
operation.tile_description.math_instruction.element_accumulator
]
element_size = DataTypeSize[operation.A.element]
if element_size == 1:
self.scope_max = 1
self.scope_min = 0
elif element_size <= 8:
self.scope_max = 1
self.scope_min = -1
elif element_size == 16:
self.scope_max = 4
self.scope_min = -4
else:
self.scope_max = 8
self.scope_min = -8
#: seed
self.seed: int = seed
#: whether the layout is interleaved
self.interleaved = interleaved
#: compute type
self.compute_type = operation.epilogue_functor.element_epilogue
self.accumulator_type = (
operation.tile_description.math_instruction.element_accumulator
)
def print_problem_size(self, p, mode, batch_count):
if mode == cutlass_bindings.gemm.Mode.Gemm:
mode = "Gemm"
elif mode == cutlass_bindings.gemm.Mode.Batched:
mode = "GemmBatched"
elif mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
mode = "GemmSplitKParallel"
problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % (
p.m(),
p.n(),
p.k(),
batch_count,
mode,
)
print(problem_size)
@staticmethod
def numpy_type(type):
if type == cutlass_bindings.float64:
return np.float64
elif type == cutlass_bindings.float32:
return np.float32
elif type == cutlass_bindings.float16:
return np.float16
elif type == cutlass_bindings.bfloat16:
return bfloat16
elif type == cutlass_bindings.int32:
return np.int32
elif type == cutlass_bindings.int8:
return np.int8
else:
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
def uniform_init(self, size, dtype):
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=self.scope_min - 1, high=self.scope_max + 1, size=size
).astype(dtype)
def reorder_tensor_B(self, tensor_B, problem_size):
reordered_tensor_B = np.empty_like(tensor_B)
tensor_ref_B = getTensorRef(
tensor_B, problem_size, "b", self.operation.B.layout
)
reordered_tensor_ref_B = getTensorRef(
reordered_tensor_B, problem_size, "b", self.operation.B.layout
)
cutlass_bindings.gemm.host.reorder_column(
tensor_ref_B, reordered_tensor_ref_B, problem_size
)
return reordered_tensor_B
def host_reference(self, problem_size, batch_count, tensor_A, tensor_B, tensor_C, alpha, beta):
tensor_D_ref = np.ones_like(tensor_C)
alpha = self.numpy_type(self.compute_type)(alpha)
beta = self.numpy_type(self.compute_type)(beta)
init_acc = 0
alpha = self.compute_type(alpha).value()
beta = self.compute_type(beta).value()
init_acc = self.accumulator_type(init_acc).value()
for i in range(batch_count):
if self.operation.switched:
tensor_ref_A = getTensorRef(
tensor_A,
problem_size,
"a",
transpose(self.operation.B.layout),
batch_offset=i,
)
tensor_ref_B = getTensorRef(
tensor_B,
problem_size,
"b",
transpose(self.operation.A.layout),
batch_offset=i,
)
tensor_ref_C = getTensorRef(
tensor_C,
problem_size,
"c",
transpose(self.operation.C.layout),
batch_offset=i,
)
tensor_ref_D_ref = getTensorRef(
tensor_D_ref,
problem_size,
"d",
transpose(self.operation.C.layout),
batch_offset=i,
)
else:
tensor_ref_A = getTensorRef(
tensor_A, problem_size, "a", self.operation.A.layout, batch_offset=i
)
tensor_ref_B = getTensorRef(
tensor_B, problem_size, "b", self.operation.B.layout, batch_offset=i
)
tensor_ref_C = getTensorRef(
tensor_C, problem_size, "c", self.operation.C.layout, batch_offset=i
)
tensor_ref_D_ref = getTensorRef(
tensor_D_ref,
problem_size,
"d",
self.operation.C.layout,
batch_offset=i,
)
if self.math_operation in [MathOperation.multiply_add_saturate]:
cutlass_bindings.test.gemm.host.gemm_saturate(
problem_size,
alpha,
tensor_ref_A,
tensor_ref_B,
beta,
tensor_ref_C,
tensor_ref_D_ref,
init_acc,
)
else:
cutlass_bindings.test.gemm.host.gemm(
problem_size,
alpha,
tensor_ref_A,
tensor_ref_B,
beta,
tensor_ref_C,
tensor_ref_D_ref,
init_acc,
)
return tensor_D_ref
def equal(self, tensor_D, tensor_D_ref, problem_size, batch_count):
for i in range(batch_count):
tensor_view_D = getTensorView(
tensor_D, problem_size, "d", self.operation.C.layout, batch_offset=i
)
tensor_view_D_ref = getTensorView(
tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i
)
if not cutlass_bindings.test.gemm.host.equals(
tensor_view_D, tensor_view_D_ref
):
return False
return True
def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0):
m = problem_size.m()
n = problem_size.n()
k = problem_size.k()
bytes = (
(DataTypeSize[self.operation.A.element] * m // 8) * k
+ (DataTypeSize[self.operation.B.element] * n // 8) * k
+ (DataTypeSize[self.operation.C.element] * m // 8) * n
)
if beta != 0:
bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
bytes *= batch_count
return bytes
def flops(self, problem_size, batch_count=1):
m = problem_size.m()
n = problem_size.n()
k = problem_size.k()
flops_ = (m * n * k) * 2 * batch_count
return flops_
def run_cutlass_profiler(
self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0
):
cutlass_path = os.getenv("CUTLASS_PATH")
assert (
cutlass_path is not None
), "Environment variable 'CUTLASS_PATH' is not defined."
values = {
"profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
"kernel_name": self.operation.procedural_name(),
"verification_providers": "device",
"provider": "cutlass",
"m": str(problem_size.m()),
"n": str(problem_size.n()),
"k": str(problem_size.k()),
"split_k_slices": str(batch_count),
"alpha": str(alpha),
"beta": str(beta),
"warmup": str(self.warmup_iterations),
"profile": str(self.iterations),
}
cmd_template = (
"${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}"
" --providers=${provider} --m=${m} --n=${n} --k=${k}"
)
cmd = SubstituteTemplate(cmd_template, values)
result = subprocess.getoutput(cmd)
m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result)
runtime = float(m.group("runtime"))
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
bytes = int(m.group("bytes"))
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
flops = int(m.group("flops"))
# check if the problem size matches
assert bytes == self.bytes(problem_size, alpha=alpha, beta=beta)
assert flops == self.flops(problem_size)
return runtime
def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released in previous run"
% get_allocated_size()
)
np.random.seed(self.seed)
# Assign an actual batch count in cases where we are not running in batched mode.
# This is to differentiate between the number of split K slices and the batch count,
# which are overloaded within the single `batch_count` variable.
true_batch_count = (
batch_count if mode == cutlass_bindings.gemm.Mode.Batched else 1
)
tensor_A = self.uniform_init(
size=(problem_size.m() * problem_size.k() * true_batch_count,),
dtype=self.dtype_A,
)
tensor_B = self.uniform_init(
size=(problem_size.n() * problem_size.k() * true_batch_count,),
dtype=self.dtype_B,
)
tensor_C = self.uniform_init(
size=(problem_size.m() * problem_size.n() * true_batch_count,),
dtype=self.dtype_C,
)
tensor_D = np.zeros(
shape=(problem_size.m() * problem_size.n() * true_batch_count,),
dtype=self.dtype_D,
)
#
# Launch kernel
#
arguments = GemmArguments(
operation=self.operation,
problem_size=problem_size,
A=tensor_A,
B=tensor_B,
C=tensor_C,
D=tensor_D,
output_op=self.operation.epilogue_type(alpha, beta),
gemm_mode=mode,
split_k_slices=split_k_slices,
batch=batch_count,
)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
reduction_arguments = ReductionArguments(
self.reduction_operation,
problem_size=[problem_size.m(), problem_size.n()],
partitions=split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op=self.reduction_operation.epilogue_type(alpha, beta),
)
self.operation.run(arguments)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
self.reduction_operation.run(reduction_arguments)
passed = True
if self.verification:
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
reduction_arguments.sync()
else:
arguments.sync()
tensor_D_ref = self.host_reference(
problem_size,
true_batch_count,
tensor_A,
tensor_B,
tensor_C,
alpha,
beta,
)
passed = self.equal(tensor_D, tensor_D_ref, problem_size, true_batch_count)
if not passed:
self.print_problem_size(problem_size, mode, batch_count)
if self.profiling:
sleep(self.sleep_time)
for _ in range(self.warmup_iterations):
self.operation.run(arguments)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
self.reduction_operation.run(reduction_arguments)
self.timer.start()
for _ in range(self.iterations):
self.operation.run(arguments)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
self.reduction_operation.run(reduction_arguments)
self.timer.stop_and_wait()
runtime = self.timer.duration(self.iterations)
# free memory and clear buffers
del arguments
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
del reduction_arguments
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released after current run"
% get_allocated_size()
)
if self.profiling:
return runtime
return passed
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"):
passed = True
minimum_operand_element_size = min(
DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]
)
opcode_class = operation.tile_description.math_instruction.opcode_class
if opcode_class == cutlass_bindings.OpClass.Simt:
alignment = 1
else:
alignment = 128 // minimum_operand_element_size
# int8_t gemm alignment constraints
if opcode_class == cutlass_bindings.OpClass.Simt and operation.A.element == cutlass_bindings.int8 and operation.A.layout == cutlass_bindings.ColumnMajor:
alignment_m = 4
else:
alignment_m = alignment
if (
opcode_class == cutlass_bindings.OpClass.Simt
and operation.B.element == cutlass_bindings.int8
and operation.B.layout == cutlass_bindings.RowMajor
):
alignment_n = 4
else:
alignment_n = alignment
if (
opcode_class == cutlass_bindings.OpClass.Simt
and operation.A.element == cutlass_bindings.int8
and operation.B.element == cutlass_bindings.int8
and (
operation.A.layout == cutlass_bindings.RowMajor
or operation.B.layout == cutlass_bindings.ColumnMajor
)
):
alignment_k = 4
else:
alignment_k = alignment
threadblock_k = operation.tile_description.threadblock_shape[2]
if testcase == "interleaved":
if operation.A.layout in [
cutlass_bindings.ColumnMajorInterleaved32,
cutlass_bindings.RowMajorInterleaved32,
]:
interleavedk = 32
else:
raise ValueError("Unknown layout")
# Split K mode via Python is currently only supported pre-SM90, and when stream K is not used.
# Stream K enables split-k functionality with mode `Gemm` and a non-unit batch count.
supports_split_k = operation.arch < 90 and not isinstance(
operation.swizzling_functor, cutlass_bindings.ThreadblockSwizzleStreamK
)
if testcase == "interleaved":
modes = [
cutlass_bindings.gemm.Mode.Gemm,
]
problem_size_m = [interleavedk, 512 + interleavedk]
problem_size_n = [interleavedk, 512 + interleavedk]
problem_size_k = [
interleavedk,
threadblock_k * operation.tile_description.stages + interleavedk,
]
problem_alpha = [1.0]
problem_beta = [0.0]
batch_counts = [
1,
]
elif testcase == "multistage":
modes = [
cutlass_bindings.gemm.Mode.Gemm,
]
problem_size_m = [16, 528]
problem_size_n = [16, 528]
problem_size_k = [
threadblock_k,
threadblock_k * operation.tile_description.stages
+ operation.tile_description.math_instruction.instruction_shape[2],
]
problem_alpha = [1.0]
problem_beta = [0.0]
batch_counts = [
1,
]
else: # universal
modes = [cutlass_bindings.gemm.Mode.Gemm]
batch_counts = [1, 2, 3, 5, 7]
if supports_split_k:
modes.append(cutlass_bindings.gemm.Mode.GemmSplitKParallel)
problem_size_m = [alignment_m, 512 - 3 * alignment_m]
problem_size_n = [alignment_n, 512 - 2 * alignment_n]
if operation.tile_description.stages is None:
stages_for_k_calc = 7
else:
stages_for_k_calc = operation.tile_description.stages
problem_size_k = [
alignment_k,
threadblock_k * stages_for_k_calc - alignment_k,
threadblock_k * stages_for_k_calc * 3 - alignment_k,
]
problem_alpha = [1.0]
problem_beta = [2.0]
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"))
for mode in modes:
for m in problem_size_m:
for n in problem_size_n:
for k in problem_size_k:
for batch_count in batch_counts:
for alpha in problem_alpha:
for beta in problem_beta:
# skip very small K problems
if testcase == "universal":
if k // batch_count < 2 * threadblock_k:
continue
problem_size = cutlass_bindings.gemm.GemmCoord(m, n, k)
if supports_split_k:
split_k_slices = batch_count
else:
split_k_slices = 1
overridden_mode = mode
if (
mode == cutlass_bindings.gemm.Mode.Gemm
and batch_count > 1
):
overridden_mode = cutlass_bindings.gemm.Mode.Batched
passed = testbed.run(
overridden_mode,
problem_size,
batch_count,
split_k_slices,
alpha,
beta,
)
(err,) = cudart.cudaDeviceSynchronize()
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError("CUDA Error %s" % str(err))
if not passed:
return False
return passed
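The `flops` and `bytes` accounting used by `GemmUniversalLauncher` above can be restated as two free functions. This is a sketch under stated assumptions: `gemm_flops` and `gemm_bytes` are hypothetical names, and element widths are passed in bits as plain integers rather than looked up in `DataTypeSize`.

```python
def gemm_flops(m: int, n: int, k: int, batch_count: int = 1) -> int:
    # One multiply and one add per (m, n, k) triple.
    return 2 * m * n * k * batch_count


def gemm_bytes(
    m: int,
    n: int,
    k: int,
    bits_a: int,
    bits_b: int,
    bits_c: int,
    beta: float = 0.0,
    batch_count: int = 1,
) -> int:
    # A (m x k) and B (k x n) are read once; the m x n output is written once.
    total = (
        (bits_a * m // 8) * k
        + (bits_b * n // 8) * k
        + (bits_c * m // 8) * n
    )
    if beta != 0:
        # The source tensor C is also read when it contributes to the output.
        total += (bits_c * m // 8) * n
    return total * batch_count


print(gemm_flops(128, 128, 64))                        # 2097152
print(gemm_bytes(128, 128, 64, 16, 16, 16))            # 65536
print(gemm_bytes(128, 128, 64, 16, 16, 16, beta=2.0))  # 98304
```

Passing `beta` by keyword, as in the last call, avoids accidentally binding it to an earlier positional parameter.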

@ -0,0 +1,69 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cuda import cuda, cudart


class GpuTimer:
    def __init__(self) -> None:
        # One event marks the start of the timed region, the other the end
        self.events = [
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
        ]

    def start(self, stream=cuda.CUstream(0)):
        (err,) = cuda.cuEventRecord(self.events[0], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))

    def stop(self, stream=cuda.CUstream(0)):
        (err,) = cuda.cuEventRecord(self.events[1], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))

    def stop_and_wait(self, stream=cuda.CUstream(0)):
        self.stop(stream)
        if stream:
            (err,) = cuda.cuStreamSynchronize(stream)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))
        else:
            # The runtime API reports errors as cudaError_t, not CUresult
            (err,) = cudart.cudaDeviceSynchronize()
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError("CUDA Error %s" % str(err))

    def duration(self, iterations=1):
        # Elapsed time between the two events in milliseconds, averaged over `iterations`
        err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))
        return duration / float(iterations)
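
The start/stop/duration pattern of `GpuTimer` can be tried without a GPU. The following is a minimal sketch, assuming a hypothetical host-side stand-in (`HostTimer` and `time_callable` are illustrative names, not part of CUTLASS) that swaps CUDA events for `time.perf_counter`. It only illustrates how `duration(iterations)` averages one timed region over several launches.

```python
import time


class HostTimer:
    """Hypothetical host-side stand-in for GpuTimer (not part of CUTLASS)."""

    def __init__(self) -> None:
        # marks[0] is the start timestamp, marks[1] the stop timestamp (seconds)
        self.marks = [0.0, 0.0]

    def start(self) -> None:
        self.marks[0] = time.perf_counter()

    def stop(self) -> None:
        self.marks[1] = time.perf_counter()

    def duration(self, iterations: int = 1) -> float:
        # Milliseconds between start and stop, averaged over `iterations`
        return (self.marks[1] - self.marks[0]) * 1e3 / float(iterations)


def time_callable(fn, iterations: int = 10) -> float:
    """Run `fn` `iterations` times inside one timed region; return average ms."""
    timer = HostTimer()
    timer.start()
    for _ in range(iterations):
        fn()
    timer.stop()
    return timer.duration(iterations)


avg_ms = time_callable(lambda: sum(range(1000)), iterations=10)
```

With the real `GpuTimer`, the loop body would launch a kernel and `stop_and_wait()` would replace `stop()`, so that the stop event has completed before the elapsed time is read.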


@ -0,0 +1,131 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import cutlass_bindings

from cutlass import KernelScheduleSuffixes
from cutlass.backend import library
from cutlass.backend.utils.software import SubstituteTemplate


class Layout:
    """
    Utility class to map transpose and non-transpose terminology to row- and column-major terminology
    """

    T = cutlass_bindings.RowMajor
    N = cutlass_bindings.ColumnMajor


class LayoutCombination:
    """
    Utility class defining all combinations of row- and column-major layouts for operands to a GEMM
    """

    NNN = (Layout.N, Layout.N, Layout.N)
    NNT = (Layout.N, Layout.N, Layout.T)
    NTN = (Layout.N, Layout.T, Layout.N)
    NTT = (Layout.N, Layout.T, Layout.T)
    TNN = (Layout.T, Layout.N, Layout.N)
    TNT = (Layout.T, Layout.N, Layout.T)
    TTN = (Layout.T, Layout.T, Layout.N)
    TTT = (Layout.T, Layout.T, Layout.T)


def get_name(
    layouts,
    alignments,
    element_output,
    element_accumulator,
    element_epilogue,
    cluster_shape,
    threadblock_shape,
    stages,
    element_a,
    element_b,
    arch,
    opclass,
    kernel_schedule=None,
    suffix="",
):
    """
    Generates a procedural name for a test case.

    :param layouts: indexable container of layouts of A, B, and C operands
    :param alignments: indexable container of alignments of A, B, and C operands
    :param element_output: data type of the output element
    :param element_accumulator: data type used in accumulation
    :param element_epilogue: data type used in computing the epilogue
    :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched
    :param threadblock_shape: indexable container of dimensions of threadblock tiles
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param element_a: data type of operand A
    :param element_b: data type of operand B
    :param arch: compute capability of kernel being generated
    :type arch: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass_bindings.OpClass
    :param kernel_schedule: kernel schedule type
    :type kernel_schedule: cutlass.KernelScheduleType
    :param suffix: additional string to append to the end of the name
    :type suffix: str

    :return: procedurally generated name of the test case
    :rtype: str
    """
    name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${suffix}"
    return SubstituteTemplate(
        name_format,
        {
            "arch": str(arch),
            "eA": library.DataTypeNames[element_a],
            "eB": library.DataTypeNames[element_b],
            "eC": library.DataTypeNames[element_output],
            "lA": library.ShortLayoutTypeNames[layouts[0]],
            "lB": library.ShortLayoutTypeNames[layouts[1]],
            "lC": library.ShortLayoutTypeNames[layouts[2]],
            "opclass": library.OpcodeClassNames[opclass],
            "acc": library.DataTypeNames[element_accumulator],
            "cM": str(cluster_shape[0]),
            "cN": str(cluster_shape[1]),
            "cK": str(cluster_shape[2]),
            "tbM": str(threadblock_shape[0]),
            "tbN": str(threadblock_shape[1]),
            "tbK": str(threadblock_shape[2]),
            "stages": str(stages) if stages is not None else "auto",
            "aA": str(alignments[0]),
            "aB": str(alignments[1]),
            "aC": str(alignments[2]),
            "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
            "suffix": "" if suffix is None else suffix,
        },
    )
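
To see what kind of string `get_name` produces, the template can be filled in without the CUTLASS lookup tables. This is a sketch only: it uses the standard library's `string.Template` (which happens to accept the same `${key}` syntax) in place of `SubstituteTemplate`, and the short names such as `"f16"`, `"t"`, `"n"`, and `"tensorop"` are assumed illustrative values, not values read from `library`.

```python
from string import Template

# Same format string as in get_name, split across two literals for readability
NAME_FORMAT = Template(
    "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}"
    "_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${suffix}"
)


def example_name(stages=None) -> str:
    """Fill the template with assumed values for an FP16 Tensor Core GEMM."""
    values = {
        "arch": "80",
        "eA": "f16", "lA": "t",  # operand A: assumed short names
        "eB": "f16", "lB": "n",
        "eC": "f16", "lC": "t",
        "opclass": "tensorop",
        "acc": "f32",
        "tbM": "128", "tbN": "128", "tbK": "32",
        "cM": "1", "cN": "1", "cK": "1",
        # get_name writes "auto" when no stage count is given
        "stages": str(stages) if stages is not None else "auto",
        "aA": "8", "aB": "8", "aC": "8",
        "k": "",       # no kernel schedule suffix
        "suffix": "",
    }
    return NAME_FORMAT.substitute(values)


print(example_name(3))
# test_SM80_Device_Gemm_f16t_f16n_f16t_tensorop_f32_128x128x32_1x1x1_3_align8-8-8
```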


@ -0,0 +1,35 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"
Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"


@ -0,0 +1,41 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from cutlass.backend.utils.datatypes import *
from cutlass.backend.utils.device import check_cuda_errors, device_cc
from cutlass.backend.utils.reference_model import ReferenceModule
from cutlass.backend.utils.software import (
    CheckPackages,
    SubstituteTemplate,
    device_sm_count,
    get_memory_pool,
)


@ -0,0 +1,129 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""

import cutlass_bindings

from cutlass.backend.utils.software import CheckPackages

numpy_available = CheckPackages().check_numpy()
if numpy_available:
    import numpy as np

    numpy_to_cutlass_dict = {
        np.float16: cutlass_bindings.float16,
        np.float32: cutlass_bindings.float32,
        np.float64: cutlass_bindings.float64,
        np.int8: cutlass_bindings.int8,
        np.int32: cutlass_bindings.int32,
        np.dtype('float16'): cutlass_bindings.float16,
        np.dtype('float32'): cutlass_bindings.float32,
        np.dtype('float64'): cutlass_bindings.float64,
        np.dtype('int8'): cutlass_bindings.int8,
        np.dtype('int32'): cutlass_bindings.int32,
    }


def numpy_to_cutlass(inp):
    numpy_available = CheckPackages().check_numpy()
    if numpy_available:
        return numpy_to_cutlass_dict.get(inp, None)


cupy_available = CheckPackages().check_cupy()
if cupy_available:
    import cupy as cp

    cupy_to_cutlass_dict = {
        cp.float16: cutlass_bindings.float16,
        cp.float32: cutlass_bindings.float32,
        cp.float64: cutlass_bindings.float64,
    }


def cupy_to_cutlass(inp):
    cupy_available = CheckPackages().check_cupy()
    if cupy_available:
        return cupy_to_cutlass_dict.get(inp, None)


torch_available = CheckPackages().check_torch()
if torch_available:
    import torch

    torch_to_cutlass_dict = {
        torch.half: cutlass_bindings.float16,
        torch.float16: cutlass_bindings.float16,
        torch.float: cutlass_bindings.float32,
        torch.float32: cutlass_bindings.float32,
        torch.double: cutlass_bindings.float64,
        torch.float64: cutlass_bindings.float64,
    }


def torch_to_cutlass(inp):
    if torch_available:
        return torch_to_cutlass_dict.get(inp, None)


try:
    import bfloat16

    bfloat16_available = True
    numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = cutlass_bindings.bfloat16
except ImportError:
    bfloat16_available = False


def bfloat16_to_cutlass(inp):
    if bfloat16_available:
        if inp == bfloat16.bfloat16:
            return cutlass_bindings.bfloat16


def to_cutlass(inp):
    # Try each frontend converter in turn; the first non-None result wins
    for cvt_fn in [
        bfloat16_to_cutlass,
        cupy_to_cutlass,
        numpy_to_cutlass,
        torch_to_cutlass,
    ]:
        out = cvt_fn(inp)
        if out is not None:
            return out

    raise Exception(
        "No available conversion from type {} to a CUTLASS type.".format(inp)
    )
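
The dispatch in `to_cutlass` follows a "first converter that returns something other than `None` wins" pattern. A minimal standalone sketch of that pattern, assuming hypothetical lookup tables keyed on built-in Python types (the names `INT_TABLE`, `FLOAT_TABLE`, and `to_name` are illustrative and do not exist in CUTLASS):

```python
from typing import Optional

# Hypothetical stand-ins for the per-framework dictionaries
INT_TABLE = {int: "int32"}
FLOAT_TABLE = {float: "float64"}


def int_to_name(inp) -> Optional[str]:
    # dict.get returns None on a miss, which signals "not my type"
    return INT_TABLE.get(inp)


def float_to_name(inp) -> Optional[str]:
    return FLOAT_TABLE.get(inp)


def to_name(inp) -> str:
    """Try each converter in order and return the first non-None result."""
    for convert in (int_to_name, float_to_name):
        out = convert(inp)
        if out is not None:
            return out
    raise TypeError("No available conversion from type {}.".format(inp))


print(to_name(int))    # int32
print(to_name(float))  # float64
```

Because a miss is encoded as `None`, a converter in such a chain can never legitimately map a type to `None`; this holds for the tables above.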


@ -0,0 +1,76 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for interacting with the device
"""

from cuda import cudart


def check_cuda_errors(result: list):
    """
    Checks whether `result` contains a CUDA error and, if so, raises the error as an exception.
    Otherwise, returns the result contained in the remaining fields of `result`.

    :param result: the results of the `cudart` method, consisting of an error code and any method results
    :type result: list

    :return: non-error-code results from the `result` parameter
    """
    # `result` is of the format : (cudaError_t, result...)
    err = result[0]
    if err.value:
        raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))

    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]


def device_cc(device: int = 0) -> int:
    """
    Returns the compute capability of the device with ID `device`.

    :param device: ID of the device to query
    :type device: int

    :return: compute capability of the queried device (e.g., 80 for SM80)
    :rtype: int
    """
    deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
    major = str(deviceProp.major)
    minor = str(deviceProp.minor)
    return int(major + minor)
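
The return convention of `check_cuda_errors` (nothing, a single value, or a tuple of values) and the digit concatenation in `device_cc` can be exercised without a GPU. The sketch below assumes a hypothetical `FakeError` enum in place of `cudaError_t`; `check_errors` and `compute_capability` are illustrative names, not CUTLASS functions.

```python
from enum import IntEnum


class FakeError(IntEnum):
    """Hypothetical stand-in for cudart.cudaError_t."""
    SUCCESS = 0
    INVALID_VALUE = 1


def check_errors(result):
    """Mirror of the unpacking logic: result is (error_code, value...)."""
    err = result[0]
    if err.value:
        raise RuntimeError("error: {}".format(err.name))
    if len(result) == 1:
        return None        # call produced no values
    if len(result) == 2:
        return result[1]   # single value is returned unwrapped
    return result[1:]      # several values are returned as a tuple


def compute_capability(major: int, minor: int) -> int:
    # Concatenate the digits as strings, e.g. (8, 6) -> "86" -> 86
    return int(str(major) + str(minor))


print(check_errors((FakeError.SUCCESS, 7)))     # 7
print(check_errors((FakeError.SUCCESS, 1, 2)))  # (1, 2)
print(compute_capability(8, 0))                 # 80
```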


@ -0,0 +1,317 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from typing import Union

from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np

from cutlass.backend.library import TensorDescription
from cutlass.backend.utils.software import CheckPackages

torch_available = CheckPackages().check_torch()
if torch_available:
    import torch


class ReferenceModule:
    def __init__(
        self, A: TensorDescription, B: TensorDescription, C: TensorDescription
    ) -> None:
        self.layout_A = A.layout
        self.layout_B = B.layout
        self.layout_C = C.layout

    def run(
        self,
        A: np.ndarray,
        B: np.ndarray,
        C: np.ndarray,
        problem_size: cutlass_bindings.gemm.GemmCoord,
        alpha: float = 1.0,
        beta: float = 0.0,
        bias=False,
        batch=1,
    ):
        """
        Compute the reference result on CPU
        Args:
            A: dense operator with shape (M, K) in row-major and (K, M) in column-major
            B: dense operator with shape (K, N) in row-major and (N, K) in column-major
            C: dense operator with shape (M, N) in row-major and (N, M) in column-major
        """
        M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
        if isinstance(A, np.ndarray):
            if self.layout_A == cutlass_bindings.RowMajor:
                A_row = np.reshape(A, newshape=(batch, M, K))
            else:
                A_col = np.reshape(A, newshape=(batch, K, M))
                A_row = np.transpose(A_col, axes=(0, 2, 1))

            if self.layout_B == cutlass_bindings.RowMajor:
                B_row = np.reshape(B, newshape=(batch, K, N))
            else:
                B_col = np.reshape(B, newshape=(batch, N, K))
                B_row = np.transpose(B_col, axes=(0, 2, 1))

            if self.layout_C == cutlass_bindings.RowMajor:
                if bias:
                    C_row = np.reshape(C, newshape=(batch, 1, N))
                else:
                    C_row = np.reshape(C, newshape=(batch, M, N))
            else:
                if bias:
                    C_row = np.reshape(C, newshape=(batch, M, 1))
                else:
                    C_col = np.reshape(C, newshape=(batch, N, M))
                    C_row = np.transpose(C_col, axes=(0, 2, 1))

            if A_row.dtype == bfloat16:
                # numpy's einsum doesn't support bfloat16
                out_row = (
                    np.einsum(
                        "bik,bkj->bij",
                        A_row.astype(np.float32),
                        B_row.astype(np.float32),
                    )
                    * alpha
                    + C_row * beta
                )
                out_row = out_row.astype(C_row.dtype)
            else:
                out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta

            if self.layout_C == cutlass_bindings.ColumnMajor:
                out = np.transpose(out_row, axes=(0, 2, 1))
            else:
                out = out_row

            return out.ravel()

        elif isinstance(A, torch.Tensor):
            if self.layout_A == cutlass_bindings.RowMajor:
                A_row = A.view((M, K))
            else:
                A_col = A.view((K, M))
                A_row = torch.permute(A_col, (1, 0))

            if self.layout_B == cutlass_bindings.RowMajor:
                B_row = B.view((K, N))
            else:
                B_col = B.view((N, K))
                B_row = torch.permute(B_col, (1, 0))

            if self.layout_C == cutlass_bindings.RowMajor:
                C_row = C.view((M, N))
            else:
                C_col = C.view((N, M))
                C_row = torch.permute(C_col, (1, 0))

            out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta

            if self.layout_C == cutlass_bindings.ColumnMajor:
                out = torch.permute(out_row, (1, 0))
            else:
                out = out_row

            return torch.flatten(out)
#####################################################################################################
# Conv2d
#####################################################################################################
if torch_available:
    import torch

    class Conv2dReferenceModule:
        def __init__(
            self,
            A: TensorDescription,
            B: TensorDescription,
            C: TensorDescription,
            kind: cutlass_bindings.conv.Operator.fprop,
        ) -> None:
            self.layout_A = A.layout
            self.layout_B = B.layout
            self.layout_C = C.layout
            self.kind = kind

        def run(
            self,
            A: Union[np.ndarray, torch.Tensor],
            B: Union[np.ndarray, torch.Tensor],
            C: Union[np.ndarray, torch.Tensor],
            problem_size,
            alpha=1.0,
            beta=0.0,
            bias=False,
        ) -> np.ndarray:
            """
            Compute the reference result on CPU
            """
            n = problem_size.N
            h = problem_size.H
            w = problem_size.W
            c = problem_size.C

            k = problem_size.K
            r = problem_size.R
            s = problem_size.S

            p = problem_size.P
            q = problem_size.Q

            stride_h = problem_size.stride_h
            stride_w = problem_size.stride_w

            pad_h = problem_size.pad_h
            pad_w = problem_size.pad_w

            dilation_h = problem_size.dilation_h
            dilation_w = problem_size.dilation_w

            groups = problem_size.groups

            if isinstance(A, np.ndarray):
                # the pytorch activation layout is NCHW
                # weight layout is Cout Cin Kh Kw (also NCHW)
                if self.layout_A == cutlass_bindings.TensorNHWC:
                    A_nhwc = np.reshape(A, newshape=(n, h, w, c))
                    A_torch_nhwc = torch.from_numpy(A_nhwc).to("cuda")
                    A_torch_nchw = torch.permute(A_torch_nhwc, (0, 3, 1, 2))

                if self.layout_B == cutlass_bindings.TensorNHWC:
                    B_nhwc = np.reshape(B, newshape=(k, r, s, c))
                    B_torch_nhwc = torch.from_numpy(B_nhwc).to("cuda")
                    B_torch_nchw = torch.permute(B_torch_nhwc, (0, 3, 1, 2))

                if self.layout_C == cutlass_bindings.TensorNHWC:
                    C_nhwc = np.reshape(C, newshape=(n, p, q, k))
                    C_torch_nhwc = torch.from_numpy(C_nhwc).to("cuda")
                    C_torch_nchw = torch.permute(C_torch_nhwc, (0, 3, 1, 2))

            elif isinstance(A, torch.Tensor):
                if self.kind == cutlass_bindings.conv.Operator.wgrad:
                    if self.layout_A == cutlass_bindings.TensorNHWC:
                        A_nhwc = A.view((n, p, q, k))
                        A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))

                    if self.layout_B == cutlass_bindings.TensorNHWC:
                        B_nhwc = B.view((n, h, w, c))
                        B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

                    if self.layout_C == cutlass_bindings.TensorNHWC:
                        if bias:
                            C_nhwc = C.view((1, 1, 1, c))
                        else:
                            C_nhwc = C.view((k, r, s, c))
                        C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
                elif self.kind == cutlass_bindings.conv.Operator.dgrad:
                    if self.layout_A == cutlass_bindings.TensorNHWC:
                        A_nhwc = A.view((n, p, q, k))
                        A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))

                    if self.layout_B == cutlass_bindings.TensorNHWC:
                        B_nhwc = B.view((k, r, s, c))
                        B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

                    if self.layout_C == cutlass_bindings.TensorNHWC:
                        if bias:
                            C_nhwc = C.view((1, 1, 1, c))
                        else:
                            C_nhwc = C.view((n, h, w, c))
                        C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
                else:
                    if self.layout_A == cutlass_bindings.TensorNHWC:
                        A_nhwc = A.view((n, h, w, c))
                        A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))

                    if self.layout_B == cutlass_bindings.TensorNHWC:
                        B_nhwc = B.view((k, r, s, c))
                        B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

                    if self.layout_C == cutlass_bindings.TensorNHWC:
                        if bias:
                            C_nhwc = C.view((1, 1, 1, k))
                        else:
                            C_nhwc = C.view((n, p, q, k))
                        C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))

            if self.kind == cutlass_bindings.conv.Operator.fprop:
                D_torch_nchw = (
                    alpha
                    * torch.nn.functional.conv2d(
                        A_torch_nchw,
                        B_torch_nchw,
                        stride=(stride_h, stride_w),
                        padding=(pad_h, pad_w),
                        dilation=(dilation_h, dilation_w),
                        groups=groups,
                    )
                    + beta * C_torch_nchw
                )
            elif self.kind == cutlass_bindings.conv.Operator.dgrad:
                D_torch_nchw = (
                    alpha
                    * torch.nn.grad.conv2d_input(
                        (n, c, h, w),
                        B_torch_nchw,
                        A_torch_nchw,
                        padding=(pad_h, pad_w),
                        stride=(stride_h, stride_w),
                    ).to(torch.float32)
                    + beta * C_torch_nchw
                )
            elif self.kind == cutlass_bindings.conv.Operator.wgrad:
                D_torch_nchw = (
                    alpha
                    * torch.nn.grad.conv2d_weight(
                        B_torch_nchw,
                        (k, c, r, s),
                        A_torch_nchw,
                        padding=(pad_h, pad_w),
                        stride=(stride_h, stride_w),
                    ).to(torch.float32)
                    + beta * C_torch_nchw
                )

            if self.layout_C == cutlass_bindings.TensorNHWC:
                if isinstance(A, np.ndarray):
                    D_torch_out = (
                        torch.permute(D_torch_nchw, (0, 2, 3, 1)).detach().cpu().numpy()
                    )
                elif isinstance(A, torch.Tensor):
                    D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1))

            return D_torch_out.flatten()
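
The layout handling in `ReferenceModule.run` treats a column-major operand as a flat buffer that is reshaped with its dimensions swapped and then transposed. The sketch below restates that index mapping in plain Python lists so it runs without NumPy or PyTorch; `to_row_major` and `reference_gemm` are illustrative names, and the sketch covers only a single batch with a row-major C.

```python
def to_row_major(flat, rows, cols, column_major):
    """Build a rows x cols nested list from a flat buffer."""
    if column_major:
        # Element (i, j) lives at flat[j * rows + i]; this matches reshaping
        # the buffer to (cols, rows) and then transposing.
        return [[flat[j * rows + i] for j in range(cols)] for i in range(rows)]
    # Row-major: element (i, j) lives at flat[i * cols + j]
    return [[flat[i * cols + j] for j in range(cols)] for i in range(rows)]


def reference_gemm(A, B, C, M, N, K, alpha=1.0, beta=0.0,
                   a_col=False, b_col=False):
    """Compute alpha * (A @ B) + beta * C and return it as a flat row-major list."""
    A_row = to_row_major(A, M, K, a_col)
    B_row = to_row_major(B, K, N, b_col)
    C_row = to_row_major(C, M, N, False)
    return [
        alpha * sum(A_row[i][k] * B_row[k][j] for k in range(K)) + beta * C_row[i][j]
        for i in range(M)
        for j in range(N)
    ]


# The same 2x2 matrix [[1, 2], [3, 4]] stored in both layouts
A_row_major = [1, 2, 3, 4]
A_col_major = [1, 3, 2, 4]
identity = [1, 0, 0, 1]
ones = [1, 1, 1, 1]

d_row = reference_gemm(A_row_major, identity, ones, 2, 2, 2, alpha=1.0, beta=1.0)
d_col = reference_gemm(A_col_major, identity, ones, 2, 2, 2, alpha=1.0, beta=1.0,
                       a_col=True)
print(d_row)  # [2.0, 3.0, 4.0, 5.0]
```

Both calls describe the same logical matrix, so `d_row` and `d_col` agree.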


@ -0,0 +1,111 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import re
import sys

from cutlass.backend.memory_manager import PoolMemoryManager


class CheckPackages:
    def __init__(self) -> None:
        pass

    def check_cupy(self):
        if "cupy" in sys.modules:
            return True
        else:
            try:
                import cupy

                return True
            except ImportError:
                print("cupy is not loaded.")
                return False

    def check_numpy(self):
        if "numpy" in sys.modules:
            return True
        else:
            try:
                import numpy

                return True
            except ImportError:
                print("numpy is not loaded.")
                return False

    def check_torch(self):
        if "torch" in sys.modules:
            return True
        else:
            try:
                import torch

                return True
            except ImportError:
                print("torch is not loaded.")
                return False


def SubstituteTemplate(template, values):
    # Substitute repeatedly so that placeholders introduced by a value are also expanded
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text


# this._device_sm_count = None
def device_sm_count():
    # Query the number of SMs, if needed
    # if this._device_sm_count is None:
    from cuda import cuda

    _device = 0
    err, _device_sm_count = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, _device
    )
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(
            "Failed to retrieve SM count. "
            f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
        )

    return _device_sm_count


def get_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
    memory_pool = PoolMemoryManager(
        init_pool_size=init_pool_size, max_pool_size=max_pool_size
    )
    return memory_pool
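
The `while changed` loop in `SubstituteTemplate` means a substituted value may itself contain `${...}` placeholders, which are expanded on a later pass. A minimal standalone sketch of that behavior follows; `substitute_template` is a local copy written for illustration, and the keys `kernel` and `arch` are made-up examples.

```python
import re


def substitute_template(template: str, values: dict) -> str:
    """Local copy of the substitution loop, for illustration only."""
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Matches the literal text ${key}
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text


# The value of "kernel" introduces a new placeholder that is expanded as well
result = substitute_template(
    "${kernel}_align8", {"kernel": "gemm_${arch}", "arch": "sm80"}
)
print(result)  # gemm_sm80_align8
```

Two caveats follow from this design: a value that contains its own placeholder (for example `{"a": "${a}x"}`) would never reach a fixed point, and because values are passed to `re.sub` as replacement strings, backslashes in a value are interpreted as escape sequences.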


@ -0,0 +1,75 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
   \brief In-memory compiled artifact cache
*/

#include <pybind11/pybind11.h>

#include <cstdint>
#include <string>
#include <unordered_map>

namespace py = pybind11;

namespace cutlass {

struct CompileCache {
 public:
  CompileCache() = default;
  ~CompileCache() = default;

  using Cache = std::unordered_map<std::string, py::object>;

  /// Check if the kernel has already been compiled
  py::object at(const std::string &kernel) {
    auto item = cache_.find(kernel);

    if (item != cache_.end()) {
      return item->second;
    }
    return py::none();
  }

  /// Insert a new compiled kernel for new configuration
  void insert(const std::string &kernel, const py::object &compiled_kernel) {
    cache_.emplace(kernel, compiled_kernel);
  }

  /// Number of kernels currently cached
  int64_t size() const { return static_cast<int64_t>(cache_.size()); }

  /// Clear the cache
  void clear() { cache_.clear(); }

 private:
  Cache cache_;
};

}  // namespace cutlass
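
One detail of `CompileCache` worth noting is that `std::unordered_map::emplace` leaves an existing entry untouched, so inserting under a key that is already cached does not replace the stored kernel. The following Python sketch mirrors those semantics with a plain dictionary; `CompileCacheSketch` is a hypothetical illustration, not the class exposed by the bindings.

```python
class CompileCacheSketch:
    """Hypothetical Python analog of the C++ CompileCache."""

    def __init__(self) -> None:
        self._cache = {}

    def at(self, kernel: str):
        # Returns None on a miss, like py::none() in the C++ version
        return self._cache.get(kernel)

    def insert(self, kernel: str, compiled_kernel) -> None:
        # setdefault keeps an existing entry, matching unordered_map::emplace
        self._cache.setdefault(kernel, compiled_kernel)

    def size(self) -> int:
        return len(self._cache)

    def clear(self) -> None:
        self._cache.clear()


cache = CompileCacheSketch()
cache.insert("gemm_f16", "artifact_v1")
cache.insert("gemm_f16", "artifact_v2")  # ignored: the key is already present
print(cache.at("gemm_f16"))  # artifact_v1
print(cache.at("conv_f16"))  # None
```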


@ -0,0 +1,182 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Binding CUTLASS C++ APIs to Python
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "builtin_types.h"
#include "device_launch_parameters.h"
#include "stddef.h"
#include "cutlass/cutlass.h"
#include "include/conv/convolution.h"
#include "include/gemm/gemm.h"
#include "include/types.h"
#include "include/layout/layout.h"
#include "include/tensor_coord.h"
#include "include/arch.h"
#include "include/tensor_ref_view.h"
#include "include/swizzling.h"
#include "test/conv/convolution.h"
#include "test/gemm/gemm.h"
// Data Types
#include "library.h"
// compiler
#include "compiler.h"
namespace py = pybind11;
PYBIND11_MODULE(cutlass_bindings, m) {
// module doc
m.doc() = "CUTLASS C++ binding";
//
// Bind data type
//
bind_cutlass_types(m);
//
// Bind layout
//
bind_layout(m);
//
// Bind tensor coord
//
bind_tensor_coord(m);
//
// Bind tensor ref
//
bind_tensor_refs_and_views(m);
//
// Bind opcode
//
bind_opcode(m);
//
// Bind convolution
//
py::module_ conv_submodule = m.def_submodule("conv");
bind_convolution(conv_submodule);
//
// Bind gemm
//
py::module_ gemm_submodule = m.def_submodule("gemm");
bind_gemm(gemm_submodule);
//
// Bind swizzling
//
bind_threadblock_swizzle(m);
//
// Bind test units
//
py::module_ test = m.def_submodule("test");
py::module_ test_conv = test.def_submodule("conv");
bind_convolution_test(test_conv);
py::module_ test_gemm = test.def_submodule("gemm");
bind_gemm_test(test_gemm);
// data types
py::enum_<cutlass::DataType>(m, "dtype")
.value("b1", cutlass::DataType::kB1)
.value("u2", cutlass::DataType::kU2)
.value("u4", cutlass::DataType::kU4)
.value("u8", cutlass::DataType::kU8)
.value("u16", cutlass::DataType::kU16)
.value("u32", cutlass::DataType::kU32)
.value("u64", cutlass::DataType::kU64)
.value("s2", cutlass::DataType::kS2)
.value("s4", cutlass::DataType::kS4)
.value("s16", cutlass::DataType::kS16)
.value("s64", cutlass::DataType::kS64)
.value("cf16", cutlass::DataType::kCF16)
.value("cbf16", cutlass::DataType::kCBF16)
.value("cf32", cutlass::DataType::kCF32)
.value("ctf32", cutlass::DataType::kCTF32)
.value("cf64", cutlass::DataType::kCF64)
.value("cs2", cutlass::DataType::kCS2)
.value("cs4", cutlass::DataType::kCS4)
.value("cs8", cutlass::DataType::kCS8)
.value("cs16", cutlass::DataType::kCS16)
.value("cs32", cutlass::DataType::kCS32)
.value("cs64", cutlass::DataType::kCS64)
.value("cu2", cutlass::DataType::kCU2)
.value("cu4", cutlass::DataType::kCU4)
.value("cu8", cutlass::DataType::kCU8)
.value("cu16", cutlass::DataType::kCU16)
.value("cu32", cutlass::DataType::kCU32)
.value("cu64", cutlass::DataType::kCU64)
.value("invalid", cutlass::DataType::kInvalid);
// layout types
py::enum_<cutlass::LayoutType>(m, "layout")
.value("ColumnMajorInterleaved2", cutlass::LayoutType::kColumnMajorInterleaved2)
.value("RowMajorInterleaved2", cutlass::LayoutType::kRowMajorInterleaved2)
.value("ColumnMajorInterleaved64", cutlass::LayoutType::kColumnMajorInterleaved64)
.value("RowMajorInterleaved64", cutlass::LayoutType::kRowMajorInterleaved64)
.value("TensorNDHWC", cutlass::LayoutType::kTensorNDHWC)
.value("TensorNCHW", cutlass::LayoutType::kTensorNCHW)
.value("TensorNGHWC", cutlass::LayoutType::kTensorNGHWC)
.value("TensorNC64HW64", cutlass::LayoutType::kTensorNC64HW64)
.value("TensorC64RSK64", cutlass::LayoutType::kTensorC64RSK64);
// transform types
py::enum_<cutlass::ComplexTransform>(m, "complex_transform")
.value("none", cutlass::ComplexTransform::kNone)
.value("conj", cutlass::ComplexTransform::kConjugate);
//
// Compiler
//
py::class_<cutlass::CompileCache>(m, "CompileCache")
.def(py::init<>())
.def("at", &cutlass::CompileCache::at)
.def("insert", &cutlass::CompileCache::insert)
.def("size", &cutlass::CompileCache::size)
.def("clear", &cutlass::CompileCache::clear);
}

@ -0,0 +1,59 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind opcode classes to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/arch/mma.h"
namespace py = pybind11;
namespace cutlass {
enum class OpcodeClass {
kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp
};
}
void bind_opcode(py::module &m) {
py::enum_<cutlass::OpcodeClass>(m, "OpClass",
R"pbdoc(classification of math operators)pbdoc")
.value("Simt", cutlass::OpcodeClass::kSimt,
R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc")
.value("TensorOp", cutlass::OpcodeClass::kTensorOp,
R"pbdoc(Tag classifying operators as Tensor Core operations)pbdoc")
.value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp,
R"pbdoc(Tag classifying operators as WMMA Tensor Core operations)pbdoc")
.value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp,
R"pbdoc(Tag classifying operators as Sparse Tensor Core operations)pbdoc");
}

@ -0,0 +1,102 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind convolution problem sizes to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/conv/conv2d_problem_size.h"
namespace py = pybind11;
void bind_conv_problem_size(py::module &m) {
//
// Conv2d Problem Size:
// include/cutlass/conv/conv2d_problem_size.h
//
py::class_<cutlass::conv::Conv2dProblemSize>(m, "Conv2dProblemSize")
// constructors
.def(py::init<int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, cutlass::conv::Mode, int, int>())
.def(py::init<cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::MatrixCoord, cutlass::MatrixCoord, cutlass::conv::Mode, int, int>())
// attribute accessors
.def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N)
.def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H)
.def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W)
.def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C)
.def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P)
.def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q)
.def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K)
.def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R)
.def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S)
.def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h)
.def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w)
.def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h)
.def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w)
.def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h)
.def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w)
.def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode)
.def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices)
.def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups)
// functions
.def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices)
.def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent)
.def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent)
.def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent)
.def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size)
.def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size)
.def("output_size", &cutlass::conv::Conv2dProblemSize::output_size);
// Get tensor size
m.def("implicit_gemm_tensor_a_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_a_size));
m.def("implicit_gemm_tensor_b_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_b_size));
m.def("implicit_gemm_tensor_c_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_c_size));
// Get tensor extent
m.def("implicit_gemm_tensor_a_extent",
py::overload_cast<
cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
>(&cutlass::conv::implicit_gemm_tensor_a_extent));
m.def("implicit_gemm_tensor_b_extent",
py::overload_cast<
cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
>(&cutlass::conv::implicit_gemm_tensor_b_extent));
m.def("implicit_gemm_tensor_c_extent",
py::overload_cast<
cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
>(&cutlass::conv::implicit_gemm_tensor_c_extent));
m.def("implicit_gemm_problem_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize &>(&cutlass::conv::implicit_gemm_problem_size));
}
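
For orientation, the mapping from a 2-D convolution to an implicit GEMM shape can be sketched in plain C++. Everything below is hypothetical illustration rather than the CUTLASS API: `ConvOp`, `ToyConv2d`, `ToyGemmShape`, and `toy_implicit_gemm_shape` are invented names, and the mapping is the commonly described one under the assumption `groups == 1`. The library's own `implicit_gemm_problem_size` remains the authority.

```cpp
#include <cassert>

// Hypothetical plain types for illustration; these are not the CUTLASS types.
enum class ConvOp { Fprop, Dgrad, Wgrad };

struct ToyConv2d {
  int N, H, W, C;  // activation: batch, height, width, channels
  int K, R, S;     // filter: output channels, filter height, filter width
  int P, Q;        // output: height, width
};

struct ToyGemmShape {
  int m, n, k;
};

// Commonly described implicit-GEMM mapping (assumes groups == 1).
inline ToyGemmShape toy_implicit_gemm_shape(ConvOp op, const ToyConv2d &p) {
  switch (op) {
    case ConvOp::Fprop:
      // Output pixels x output channels, reduced over the filter window and input channels.
      return {p.N * p.P * p.Q, p.K, p.R * p.S * p.C};
    case ConvOp::Dgrad:
      // Activation pixels x input channels, reduced over the filter window and output channels.
      return {p.N * p.H * p.W, p.C, p.R * p.S * p.K};
    case ConvOp::Wgrad:
      // Output channels x filter elements, reduced over the output pixels.
      return {p.K, p.R * p.S * p.C, p.N * p.P * p.Q};
  }
  return {0, 0, 0};
}
```

With N=1, H=W=8, C=16, K=32, R=S=3, and P=Q=6, the forward-propagation case gives a GEMM of shape 36 x 32 x 144.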

@ -0,0 +1,91 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind convolution-related enum types to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "conv_problem_size.h"
#include "host.h"
#include "cutlass/conv/convolution.h"
namespace py = pybind11;
void bind_convolution(py::module &m) {
//
// Enumerate types
// cutlass/include/cutlass/conv/convolution.h
//
/// Convolutional operator
py::enum_<cutlass::conv::Operator>(m, "Operator", R"pbdoc(Convolutional operator)pbdoc")
.value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation")
.value("dgrad", cutlass::conv::Operator::kDgrad, "Activation grad")
.value("wgrad", cutlass::conv::Operator::kWgrad, "Weight grad");
/// Distinguishes convolution from cross correlation
py::enum_<cutlass::conv::Mode>(m, "Mode")
.value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation)
.value("convolution", cutlass::conv::Mode::kConvolution);
/// Selects among several implementation variants trading off performance with simplicity
py::enum_<cutlass::conv::IteratorAlgorithm>(m, "IteratorAlgorithm",
R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc")
.value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(functionally correct in all cases but lower performance)pbdoc")
.value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc")
.value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc")
.value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc");
/// Distinguishes among partial specializations that accelerate certain problems where convolution
/// stride is unit.
py::enum_<cutlass::conv::StrideSupport>(m, "StrideSupport",
R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution
stride is unit.)pbdoc")
.value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc")
.value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc");
/// Identifies split-K mode
py::enum_<cutlass::conv::SplitKMode>(m, "SplitKMode")
.value("None", cutlass::conv::SplitKMode::kNone)
.value("Serial", cutlass::conv::SplitKMode::kSerial)
.value("Parallel", cutlass::conv::SplitKMode::kParallel);
// Conv problem sizes
bind_conv_problem_size(m);
//
// host helper functions
//
py::module_ host_submodule = m.def_submodule("host");
bind_conv_host_helper(host_submodule);
}

@ -0,0 +1,54 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind conv host helpers to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
void bind_conv_host_helper(py::module &m) {
/// reorder operand B for interleaved layout
m.def("reorder_convK", [](
cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> dest,
cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> src,
cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize & problem_size) {
cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size);
cutlass::reorder_convK<32>(dest, src, implicit_problem_size);
});
}

@ -0,0 +1,222 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A generic wrapper around an epilogue visitor operation
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "epilogue_visitor_op/visitor_op_linear_combination.h"
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
#include "epilogue_visitor_op/visitor_op_accumulator.h"
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
#include "epilogue_visitor_op/visitor_op_unary.h"
#include "epilogue_visitor_op/visitor_op_binary.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic Epilogue Visitor.
template <
typename OutputOp_
>
class EpilogueVisitorGeneric {
public:
using OutputOp = OutputOp_;
using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
using ElementOutput = typename OutputOp::ElementOutput;
using OutputTileIterator = typename OutputOp::OutputTileIterator;
static int const kIterations = OutputTileIterator::kIterations;
///
/// End Epilogue Tree
///
/// Additional SMEM buffer is not required in the broadcast epilogue visitor
struct SharedStorage {
typename OutputOp::SharedStorage output_smem;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
public:
/// Argument structure
struct Arguments {
typename OutputOp::Arguments output_op_args;
//
// Methods
//
Arguments() { }
Arguments(
typename OutputOp::Arguments output_op_args
):
output_op_args(output_op_args)
{
}
};
struct Params {
typename OutputOp::Params output_op_params;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
output_op_params(args.output_op_args)
{
}
};
private:
OutputOp output_op;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueVisitorGeneric(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord threadblock_offset,
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
MatrixCoord problem_size
):
output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
{ }
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
output_op.set_batch_index(batch_idx);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
output_op.begin_epilogue();
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
output_op.begin_step(step_idx);
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
output_op.begin_row(row_idx);
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum) {
output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
output_op.end_row(row_idx);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
output_op.end_step(step_idx);
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
output_op.end_epilogue();
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
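
The wrapper above only forwards each callback to its `OutputOp`. A host-only sketch of that delegation, and of the call order implied by the callback comments, is given below. All names (`LoggingOp`, `ToyVisitor`, `toy_run_epilogue`) are hypothetical, and the driver loop is an assumption inferred from those comments, not the actual `EpilogueWithVisitor` implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical output op that records the name of every callback it receives.
struct LoggingOp {
  std::vector<std::string> log;
  void begin_epilogue() { log.push_back("begin_epilogue"); }
  void begin_step(int) { log.push_back("begin_step"); }
  void begin_row(int) { log.push_back("begin_row"); }
  void visit(int, int, int, int, float) { log.push_back("visit"); }
  void end_row(int) { log.push_back("end_row"); }
  void end_step(int) { log.push_back("end_step"); }
  void end_epilogue() { log.push_back("end_epilogue"); }
};

// Thin forwarding wrapper, mirroring the delegation style of EpilogueVisitorGeneric.
template <typename OutputOp>
struct ToyVisitor {
  OutputOp output_op;
  void begin_epilogue() { output_op.begin_epilogue(); }
  void begin_step(int step_idx) { output_op.begin_step(step_idx); }
  void begin_row(int row_idx) { output_op.begin_row(row_idx); }
  void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, float accum) {
    output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
  }
  void end_row(int row_idx) { output_op.end_row(row_idx); }
  void end_step(int step_idx) { output_op.end_step(step_idx); }
  void end_epilogue() { output_op.end_epilogue(); }
};

// Assumed driver loop: steps contain rows, rows contain per-column visits.
template <typename Visitor>
void toy_run_epilogue(Visitor &visitor, int steps, int rows, int columns) {
  visitor.begin_epilogue();
  for (int step = 0; step < steps; ++step) {
    visitor.begin_step(step);
    for (int row = 0; row < rows; ++row) {
      visitor.begin_row(row);
      for (int col = 0; col < columns; ++col) {
        int frag_idx = row * columns + col;
        visitor.visit(step, row, col, frag_idx, 0.0f);
      }
      visitor.end_row(row);
    }
    visitor.end_step(step);
  }
  visitor.end_epilogue();
}
```

Keeping the wrapper free of logic means a new epilogue behavior only requires a new `OutputOp` type; the driver and the wrapper stay unchanged.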

@ -0,0 +1,84 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Binary ops used by the epilogue visitor
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Elementwise vector addition
template <typename T, int N>
struct VectorAdd {
struct Arguments {
int tmp;
CUTLASS_HOST_DEVICE
Arguments():tmp(0){ }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
struct Params {
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
CUTLASS_HOST_DEVICE
VectorAdd(
Params const &params
) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
cutlass::plus<Array<T, N>> add_op;
return add_op(lhs, rhs);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
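
As a host-only illustration of what `VectorAdd::operator()` computes, the sketch below adds two fixed-size arrays element by element. It is a hypothetical analogue: `toy_vector_add` is an invented name, and `std::array` stands in for `cutlass::Array`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical host-only analogue of VectorAdd; std::array replaces cutlass::Array.
template <typename T, std::size_t N>
std::array<T, N> toy_vector_add(const std::array<T, N> &lhs, const std::array<T, N> &rhs) {
  std::array<T, N> result{};
  for (std::size_t i = 0; i < N; ++i) {
    result[i] = lhs[i] + rhs[i];  // elementwise sum, no scalar factor involved
  }
  return result;
}
```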

@ -0,0 +1,233 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Unary ops used by the epilogue visitor
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Scalar multiplication
template <typename T, int N>
struct Mult {
struct Arguments {
T alpha;
CUTLASS_HOST_DEVICE
Arguments():alpha(T(1.0)){ }
CUTLASS_HOST_DEVICE
Arguments(T alpha): alpha(alpha) { }
};
struct Params {
T alpha; ///< scales accumulators
CUTLASS_HOST_DEVICE
Params():alpha(T(1.0)){ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): alpha(args.alpha) { }
};
T alpha_;
CUTLASS_HOST_DEVICE
Mult(
Params const &params
):
alpha_(params.alpha)
{ }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &source) const {
cutlass::multiplies<Array<T, N>> multiply_op;
return multiply_op(source, alpha_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return alpha_ != T(0);
}
};
/// ReLU
template <typename T, int N>
struct ReLUVisitor {
struct Arguments {
T threshold;
CUTLASS_HOST_DEVICE
Arguments():threshold(T(0.0)) { }
CUTLASS_HOST_DEVICE
Arguments(T threshold): threshold(threshold) { }
};
struct Params {
T threshold;
CUTLASS_HOST_DEVICE
Params():threshold(T(0.0)) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): threshold(args.threshold) { }
};
T threshold_;
CUTLASS_HOST_DEVICE
ReLUVisitor(Params const &params):
threshold_(params.threshold) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
maximum<Array<T, N>> mx;
return mx(frag, threshold_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/// leakyReLU
template <typename T, int N>
struct LeakyReLUVisitor {
struct Arguments {
T leaky_alpha;
CUTLASS_HOST_DEVICE
Arguments():leaky_alpha(T(0.0)) { }
CUTLASS_HOST_DEVICE
Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
};
struct Params {
T leaky_alpha;
CUTLASS_HOST_DEVICE
Params():leaky_alpha(T(0.0)) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
};
T leaky_alpha_;
CUTLASS_HOST_DEVICE
LeakyReLUVisitor(Params const &params):
leaky_alpha_(params.leaky_alpha) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
return leaky_op(frag, leaky_alpha_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/// Tanh
template <typename T, int N>
struct TanhVisitor {
/// Argument
struct Arguments {
// a placeholder argument to ensure correctness of ctypes
int tmp;
CUTLASS_HOST_DEVICE
Arguments(): tmp(0) { }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
/// Param
struct Params {
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
/// Constructor
CUTLASS_HOST_DEVICE
TanhVisitor(Params const &params) { }
// scalar operator
CUTLASS_HOST_DEVICE
T tanh_op(T const &scalar) const {
return fast_tanh(scalar);
}
/// vector operator
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
Array<T, N> y;
CUTLASS_PRAGMA_UNROLL
for (int i=0; i < N; ++i) {
y[i] = tanh_op(frag[i]);
}
return y;
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
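The unary ops above all follow the same three-layer shape: a host-side `Arguments` struct, a `Params` struct built from it, and a functor constructed from `Params`. The following is a minimal host-only sketch of that shape, not the CUTLASS implementation. `Fragment`, `ReluSketch`, and `relu4` are hypothetical names introduced for illustration, `std::array` stands in for `cutlass::Array`, and the element-wise maximum is written out by hand under the assumption that this is what `maximum<Array<T, N>>` computes against a scalar threshold.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for cutlass::Array<T, N>; std::array keeps the sketch host-only.
template <typename T, std::size_t N>
using Fragment = std::array<T, N>;

// Mirrors the layering of ReLUVisitor: Arguments (host) -> Params (kernel) -> functor.
template <typename T, std::size_t N>
struct ReluSketch {
  struct Arguments {
    T threshold;
    Arguments() : threshold(T(0)) {}
    explicit Arguments(T threshold_) : threshold(threshold_) {}
  };

  struct Params {
    T threshold;
    Params() : threshold(T(0)) {}
    explicit Params(Arguments const &args) : threshold(args.threshold) {}
  };

  T threshold_;

  explicit ReluSketch(Params const &params) : threshold_(params.threshold) {}

  // Element-wise max(frag[i], threshold_).
  Fragment<T, N> operator()(Fragment<T, N> const &frag) const {
    Fragment<T, N> out{};
    for (std::size_t i = 0; i < N; ++i) {
      out[i] = frag[i] < threshold_ ? threshold_ : frag[i];
    }
    return out;
  }
};

// Hypothetical helper: runs the full Arguments -> Params -> functor chain with threshold 0.
inline Fragment<float, 4> relu4(Fragment<float, 4> const &x) {
  ReluSketch<float, 4>::Arguments args(0.0f);
  ReluSketch<float, 4>::Params params(args);
  ReluSketch<float, 4> op(params);
  return op(x);
}
```

Keeping `Arguments` and `Params` as separate types lets the host-facing layout stay simple while the kernel-facing struct can be changed independently.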

@ -0,0 +1,148 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that returns the accumulator
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following Computation
///
/// ElementAccumulator accum;
/// return accum;
///
/// It can only be the leaf node of the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
int kElementsPerAccess_ ///< Number of elements computed per operation
>
class VisitorOpAccumulator{
public:
using ElementAccumulator = ElementAccumulator_;
static int const kElementsPerAccess = kElementsPerAccess_;
/// Fragment type for Accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type returned by this visitor
using VisitAccessType = AccumulatorAccessType;
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
// Placeholder member: ctypes reports an issue when the arguments struct is empty
int tmp;
CUTLASS_HOST_DEVICE
Arguments(): tmp(0) { }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
/// Parameter structure
struct Params {
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpAccumulator(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) { }
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) { }
CUTLASS_DEVICE
void begin_row(int row_idx) { }
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
return accum;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

@ -0,0 +1,245 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that applies a binary op
*/
#pragma once
#include "cutlass/cutlass.h"
#include "binary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute alpha;
/// ElementCompute beta;
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
/// return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename VisitorA_, ///< Child node A
typename VisitorB_, ///< Child node B
template<typename T, int N> typename BinaryOp_
>
class VisitorOpBinary{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using VisitorA = VisitorA_;
using VisitorB = VisitorB_;
/// Fragment type returned from VisitorA.visit
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
using ElementA = typename VisitAccessTypeA::Element;
/// Fragment type returned from VisitorB.visit
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
using ElementB = typename VisitAccessTypeB::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename VisitorA::SharedStorage storage_a;
typename VisitorB::SharedStorage storage_b;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
typename BinaryOp::Arguments binary_arg;
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():binary_arg() { }
CUTLASS_HOST_DEVICE
Arguments(
typename BinaryOp::Arguments binary_arg,
typename VisitorA::Arguments visitor_a_arg,
typename VisitorB::Arguments visitor_b_arg
):
binary_arg(binary_arg),
visitor_a_arg(visitor_a_arg),
visitor_b_arg(visitor_b_arg)
{ }
};
/// Parameter structure
struct Params {
typename BinaryOp::Params binary_param;
typename VisitorA::Params visitor_a_param; ///< Params type for visitor_a
typename VisitorB::Params visitor_b_param; ///< Params type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
binary_param(args.binary_arg),
visitor_a_param(args.visitor_a_arg),
visitor_b_param(args.visitor_b_arg)
{ }
};
private:
//
// Data members
//
BinaryOp binary_op;
VisitorA visitor_a_op;
VisitorB visitor_b_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpBinary(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
binary_op(params.binary_param),
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void begin_epilogue() {
visitor_a_op.begin_epilogue();
visitor_b_op.begin_epilogue();
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
visitor_a_op.set_batch_index(batch_idx);
visitor_b_op.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_a_op.begin_step(step_idx);
visitor_b_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_a_op.begin_row(row_idx);
visitor_b_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor A and visitor B
VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
/// Type conversion
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
return binary_op(
source_converter_A(result_A),
source_converter_B(result_B)
);
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_a_op.end_row(row_idx);
visitor_b_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_a_op.end_step(step_idx);
visitor_b_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_a_op.end_epilogue();
visitor_b_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
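The binary visitor composes two child visitors into a tree: it visits both children with the same accumulator fragment, converts each result to the compute type, and applies the binary op. The sketch below models only that composition on the host. It is not the CUTLASS class; `Frag`, `AccumulatorLeaf`, `AddOp`, `BinaryNode`, and `accum_plus_accum` are hypothetical names, the `visit` signature is reduced to the accumulator argument, and `static_cast` stands in for `NumericArrayConverter`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for cutlass::Array<T, N>.
template <typename T, std::size_t N>
using Frag = std::array<T, N>;

// Leaf node: returns the accumulator fragment unchanged (cf. VisitorOpAccumulator).
template <typename ElementAccumulator, std::size_t N>
struct AccumulatorLeaf {
  Frag<ElementAccumulator, N> visit(Frag<ElementAccumulator, N> const &accum) const {
    return accum;
  }
};

// Hypothetical element-wise addition used as the binary op.
template <typename T, std::size_t N>
struct AddOp {
  Frag<T, N> operator()(Frag<T, N> const &a, Frag<T, N> const &b) const {
    Frag<T, N> out{};
    for (std::size_t i = 0; i < N; ++i) {
      out[i] = a[i] + b[i];
    }
    return out;
  }
};

// Interior node: visit both children, convert to ElementCompute, then combine.
template <typename ElementAccumulator, typename ElementCompute, std::size_t N,
          typename VisitorA, typename VisitorB,
          template <typename, std::size_t> class BinaryOp>
struct BinaryNode {
  VisitorA visitor_a;
  VisitorB visitor_b;
  BinaryOp<ElementCompute, N> binary_op;

  Frag<ElementCompute, N> visit(Frag<ElementAccumulator, N> const &accum) const {
    auto result_a = visitor_a.visit(accum);
    auto result_b = visitor_b.visit(accum);
    Frag<ElementCompute, N> a{}, b{};
    for (std::size_t i = 0; i < N; ++i) {
      a[i] = static_cast<ElementCompute>(result_a[i]);
      b[i] = static_cast<ElementCompute>(result_b[i]);
    }
    return binary_op(a, b);
  }
};

// Hypothetical helper: builds the tree Add(Accum, Accum) with int accumulators
// and float compute, then evaluates it for one fragment.
inline Frag<float, 2> accum_plus_accum(Frag<int, 2> const &accum) {
  using Leaf = AccumulatorLeaf<int, 2>;
  BinaryNode<int, float, 2, Leaf, Leaf, AddOp> tree{};
  return tree.visit(accum);
}
```

Because each node only needs a `visit` member on its children, interior nodes can be nested to arbitrary depth with leaves such as the accumulator node at the bottom.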

@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that broadcasts a vector to all columns
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementVector T[i][j] <- device-memory Td[i]
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementFragment_, ///< Data type used to cache vector in register
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpColumnBroadcast {
public:
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementAccumulator = ElementAccumulator_;
using ElementVector = typename InputTileIterator::Element;
using ElementFragment = ElementFragment_;
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
/// Thread map used by input tile iterators
using ThreadMap = typename InputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementFragment, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
// /// Number of iterations (accesses) the threadblock takes to reduce a row
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
};
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Methods
CUTLASS_HOST_DEVICE
Arguments():
broadcast_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementVector *broadcast_ptr,
int64_t batch_stride
):
broadcast_ptr(broadcast_ptr),
batch_stride(batch_stride) { }
};
/// Param structure
struct Params {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
broadcast_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
broadcast_ptr(args.broadcast_ptr),
batch_stride(args.batch_stride) { }
};
private:
ElementVector *broadcast_ptr;
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
MatrixCoord threadblock_offset_;
int thread_idx_;
MatrixCoord problem_size;
int thread_start_row_;
int state_[3];
int thread_offset_row_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpColumnBroadcast(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
broadcast_ptr(params.broadcast_ptr),
threadblock_offset_(threadblock_offset),
thread_idx_(thread_idx),
problem_size(problem_size),
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
batch_stride_(params.batch_stride)
{
state_[0] = state_[1] = state_[2] = 0;
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
broadcast_ptr += batch_idx * batch_stride_;
}
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) {}
CUTLASS_DEVICE
void begin_row(int row_idx) {}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
// get pointer
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));
broadcast_fragment.fill(broadcast_data);
return broadcast_fragment;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) {
// run operator ++
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
}
}
}
}
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
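The column broadcast visitor documents its computation as `T[i][j] <- Td[i]`: every element in row `i` of the tile receives the `i`-th entry of a device vector, after the pointer has been advanced by `batch_idx * batch_stride`. A host-side reference of those semantics can be useful when checking kernel output. The function below is such a sketch; `column_broadcast_reference` and its parameters are hypothetical names, and it ignores the thread map entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics only: tile[i][j] = vec[batch_idx * batch_stride + i].
// The tile is returned in row-major order with `rows * cols` elements.
inline std::vector<float> column_broadcast_reference(
    std::vector<float> const &vec,
    std::size_t rows,
    std::size_t cols,
    std::size_t batch_idx,
    std::size_t batch_stride) {
  std::vector<float> tile(rows * cols);
  // Same pointer adjustment that set_batch_index applies in the visitor.
  float const *ptr = vec.data() + batch_idx * batch_stride;
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      tile[i * cols + j] = ptr[i];
    }
  }
  return tile;
}
```

The caller is responsible for making `vec` hold at least `batch_idx * batch_stride + rows` elements; the sketch does no bounds checking.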

@ -0,0 +1,341 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that reduces over columns within a CTA
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
/// device memory <- ElementReduction(R[j])
///
template <
typename ThreadblockShape_, ///< Threadblock shape
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementReduction_, ///< Data type of the output reduction in device memory
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
typename OutputTileIterator_, ///< Tile Iterator type
typename Visitor_ ///< preceding visitor op
>
class VisitorOpColumnReduction {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementReductionAccumulator = ElementReductionAccumulator_;
using ElementReduction = ElementReduction_;
using OutputTileIterator = OutputTileIterator_;
using ThreadblockShape = ThreadblockShape_;
using Visitor = Visitor_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
using ElementOutput = typename OutputTileIterator::Element;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of reduction
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Used for the reduction
struct ReductionDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
static int const kThreadRows = kThreadCount / kThreadsPerRow;
/// Number of iterations (accesses) the threadblock takes to reduce a row
static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);
using StorageShape = MatrixShape<
kThreadRows,
ThreadblockShape::kN
>;
};
using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;
/// Shared storage
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Argument structure
struct Arguments {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Arguments(): reduction_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementReduction *reduction_ptr,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
reduction_ptr(reduction_ptr),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Params visitor_param; ///< Params type of visitor
/// Method
CUTLASS_HOST_DEVICE
Params(): reduction_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
reduction_ptr(args.reduction_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory
ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction
Visitor visitor_; ///< visitor
int thread_idx_;
MatrixCoord threadblock_offset;
MatrixCoord problem_size_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpColumnReduction(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor,
thread_idx, threadblock_offset, problem_size),
reduction_smem_ptr_(shared_storage.reduction.data()),
reduction_output_ptr_(params.reduction_ptr),
thread_idx_(thread_idx),
threadblock_offset(threadblock_offset),
problem_size_(problem_size),
batch_stride_(params.batch_stride)
{ }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
reduction_output_ptr_ += batch_idx * batch_stride_;
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
// clear the reduction fragment
reduction_fragment.clear();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
ReductionOp reduction_op;
ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast<ReductionAccumulatorAccessType*>(&reduction_fragment);
reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result));
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
//
// Store the partially reduced value to SMEM
//
// Guard against uses of the existing SMEM tile
__syncthreads();
using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;
//
// Determine a compact thread arrangement to store to SMEM
//
MatrixCoord thread_offset(
thread_idx_ / ReductionDetail::kThreadsPerRow,
(thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
);
//
// Each thread stores its fragment to SMEM
//
AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
&reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
&reduction_fragment
);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;
aligned_reduction_ptr[col_idx] = frag_ptr[column];
}
__syncthreads();
//
// Now, threads are assigned several columns of the output. They fetch over all rows from
// the compacted SMEM tile and perform a reduction.
//
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;
ReductionOpScalar reduction_op;
ElementReductionAccumulator reduction_element = ElementReductionAccumulator();
int output_column_idx = threadblock_offset.column() + column_idx;
if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
if (row) {
auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
reduction_element = reduction_op(reduction_element, frag);
}
else {
reduction_element = reduction_smem_ptr_[column_idx];
}
}
// Store
reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,266 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with Linear Combination
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute alpha;
/// ElementCompute beta;
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
/// Return C;
///
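/// A minimal scalar sketch of the arithmetic above, assuming BinaryOp is addition
/// (CombinationOp is cutlass::plus in this class). The function name is hypothetical
/// and not part of CUTLASS; it only illustrates the per-element result.
///
/// \code
/// template <typename ElementCompute, typename ElementA, typename ElementB>
/// ElementCompute linear_combination_reference(
///     ElementCompute alpha, ElementA a, ElementCompute beta, ElementB b) {
///   // Convert both child results to the compute type, scale them, then add.
///   return alpha * ElementCompute(a) + beta * ElementCompute(b);
/// }
/// \endcode
///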
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename VisitorA_, ///< Child node A
typename VisitorB_ ///< Child node B
>
class VisitorOpLinearCombination{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using VisitorA = VisitorA_;
using VisitorB = VisitorB_;
/// Fragment type returned from VisitorA.visit
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
using ElementA = typename VisitAccessTypeA::Element;
/// Fragment type returned from VisitorB.visit
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
using ElementB = typename VisitAccessTypeB::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Combination Op
using CombinationOp = cutlass::plus<VisitAccessType>;
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename VisitorA::SharedStorage storage_a;
typename VisitorB::SharedStorage storage_b;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
ElementCompute alpha; ///< scales the result of visitor A
ElementCompute beta; ///< scales the result of visitor B
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():
alpha(ElementCompute(1)),
beta(ElementCompute(0))
{ }
CUTLASS_HOST_DEVICE
Arguments(
ElementCompute alpha,
ElementCompute beta,
typename VisitorA::Arguments visitor_a_arg,
typename VisitorB::Arguments visitor_b_arg
):
alpha(alpha),
beta(beta),
visitor_a_arg(visitor_a_arg),
visitor_b_arg(visitor_b_arg)
{ }
};
/// Parameter structure
struct Params {
ElementCompute alpha; ///< scales the result of visitor A
ElementCompute beta; ///< scales the result of visitor B
typename VisitorA::Params visitor_a_param; ///< Params for visitor_a
typename VisitorB::Params visitor_b_param; ///< Params for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
alpha(args.alpha),
beta(args.beta),
visitor_a_param(args.visitor_a_arg),
visitor_b_param(args.visitor_b_arg)
{ }
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
VisitorA visitor_a_op;
VisitorB visitor_b_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpLinearCombination(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
alpha_(params.alpha),
beta_(params.beta),
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void begin_epilogue() {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor A and visitor B
VisitAccessTypeA result_A;
VisitAccessTypeB result_B;
if (alpha_ != ElementCompute(0)) {
result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
// Fill the result A with zeros
result_A.clear();
}
if (beta_ != ElementCompute(0)) {
result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
// Fill the result B with zeros
result_B.clear();
}
/// Type conversion
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
CombinationOp combination_op;
cutlass::multiplies<VisitAccessType> multiply_op;
return combination_op(
multiply_op(alpha_, source_converter_A(result_A)),
multiply_op(beta_, source_converter_B(result_B))
);
}
CUTLASS_DEVICE
void end_row(int row_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,258 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op that broadcasts a vector to all rows
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementVector T[i][j] <- device-memory Td[j]
///
/// It can only be a leaf node in the epilogue tree
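///
/// A host-side sketch of the broadcast above, assuming a row-major M x N result and a
/// length-N vector. The function name is hypothetical and not part of CUTLASS.
///
/// \code
/// #include <vector>
/// std::vector<float> row_broadcast_reference(std::vector<float> const &Td, int M) {
///   int N = static_cast<int>(Td.size());
///   std::vector<float> T(M * N);
///   for (int i = 0; i < M; ++i)
///     for (int j = 0; j < N; ++j)
///       T[i * N + j] = Td[j];  // every row i receives the same vector
///   return T;
/// }
/// \endcode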
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementFragment_, ///< Data type used to cache vector in register
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpRowBroadcast {
public:
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementAccumulator = ElementAccumulator_;
using ElementVector = typename InputTileIterator::Element;
using ElementFragment = ElementFragment_;
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
/// Thread map used by input tile iterators
using ThreadMap = typename InputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementFragment,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
// /// Number of iterations (accesses) the threadblock takes to reduce a row
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
};
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Methods
CUTLASS_HOST_DEVICE
Arguments():
broadcast_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementVector *broadcast_ptr,
int64_t batch_stride
):
broadcast_ptr(broadcast_ptr),
batch_stride(batch_stride) { }
};
/// Param structure
struct Params {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
broadcast_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
broadcast_ptr(args.broadcast_ptr),
batch_stride(args.batch_stride) { }
};
private:
ElementVector *broadcast_ptr;
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
MatrixCoord threadblock_offset_;
int thread_idx_;
MatrixCoord problem_size;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpRowBroadcast(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
threadblock_offset_(threadblock_offset),
thread_idx_(thread_idx),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
broadcast_ptr += batch_idx * batch_stride_;
}
CUTLASS_DEVICE
void begin_epilogue() {
// load broadcast fragment
load_broadcast_fragment_();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {}
CUTLASS_DEVICE
void begin_row(int row_idx) {}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
VisitAccessType* broadcast_fragment_ = reinterpret_cast<VisitAccessType*>(&broadcast_fragment);
return broadcast_fragment_[column_idx];
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
private:
CUTLASS_DEVICE
void load_broadcast_fragment_() {
broadcast_fragment.clear();
// If no pointer is supplied, set with all zeros and avoid memory accesses
if (!broadcast_ptr) {
return;
}
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
broadcast_ptr += thread_initial_column;
NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;
AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
AccessType loaded;
loaded.clear();
if (thread_column_idx < problem_size.column()) {
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
}
AccessFragmentType cvt = converter(loaded);
frag_ptr[j] = cvt;
thread_column_idx += ThreadMap::Delta::kColumn;
broadcast_ptr += ThreadMap::Delta::kColumn;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,319 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with reduction over rows in CTA
*/
#pragma once
#include "cutlass/cutlass.h"
#include "stdio.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementReductionAccumulator R[i] = \sum_j ElementReductionAccumulator(T[i][j])
/// device memory <- ElementReduction(R[i])
///
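/// A host-side sketch of the reduction above, assuming a row-major M x N input. The
/// function name is hypothetical and not part of CUTLASS. Note that the device code
/// below stores one partial sum per threadblock column (see end_row), so obtaining the
/// full R[i] across threadblocks requires a further sum over those partials.
///
/// \code
/// #include <vector>
/// std::vector<float> row_reduction_reference(std::vector<float> const &T, int M, int N) {
///   std::vector<float> R(M, 0.0f);
///   for (int i = 0; i < M; ++i)
///     for (int j = 0; j < N; ++j)
///       R[i] += T[i * N + j];  // sum over the columns j of row i
///   return R;
/// }
/// \endcode
///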
template <
typename ThreadblockShape_, ///< Threadblock shape
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementReduction_, ///< Data type of the output reduction in device memory
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
typename OutputTileIterator_, ///< Tile Iterator type
typename Visitor_ ///< preceding visitor op
>
class VisitorOpRowReduction {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementReductionAccumulator = ElementReductionAccumulator_;
using ElementReduction = ElementReduction_;
using OutputTileIterator = OutputTileIterator_;
using ThreadblockShape = ThreadblockShape_;
using Visitor = Visitor_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
using ElementOutput = typename OutputTileIterator::Element;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of reduction
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Used for the reduction
struct ReductionDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
/// Half number of threads per row used for cross-thread reduction
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
static int const kThreadRows = kThreadCount / kThreadsPerRow;
};
/// Shared storage
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Arguments(): reduction_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementReduction *reduction_ptr,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
reduction_ptr(reduction_ptr),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Params visitor_param; ///< Params of visitor
/// Method
CUTLASS_HOST_DEVICE
Params(): reduction_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
reduction_ptr(args.reduction_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
ElementReductionAccumulator reduction_accum;
Visitor visitor_; ///< visitor
int thread_idx_;
MatrixCoord threadblock_offset;
MatrixCoord problem_size_;
int thread_start_row_; ///< first output row handled by this thread in the current step
int state_[3]; ///< used to track row iterator
int thread_offset_row_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpRowReduction(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor,
thread_idx, threadblock_offset, problem_size),
reduction_output_ptr_(params.reduction_ptr),
thread_idx_(thread_idx),
threadblock_offset(threadblock_offset),
problem_size_(problem_size),
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
batch_stride_(params.batch_stride)
{
state_[0] = state_[1] = state_[2] = 0;
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
reduction_output_ptr_ += batch_idx * batch_stride_;
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
reduction_accum = ElementReductionAccumulator(0);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
ReductionOpScalar reduction_op;
ElementReductionAccumulator reduction_accum_ = reduction(result);
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
CUTLASS_PRAGMA_UNROLL
for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
}
reduction_accum = reduction_op(reduction_accum, reduction_accum_);
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();
ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;
arch::global_store<ElementReduction, sizeof(ElementReduction)>(
output_converter(reduction_accum),
(void *)curr_ptr_reduction,
is_write_thread);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
// Advance the row-tracking state, mirroring the tile iterator's operator++
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
}
}
}
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
}
private:
CUTLASS_DEVICE
ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) {
ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);
ReductionOpScalar reduction_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
sum_ = reduction_op(sum_, result[i]);
}
return sum_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,188 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with Tensor Input
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementInput C <- device memory
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename InputTileIterator_ ///< Tile iterator type to read the tensor
>
class VisitorOpTensorInput {
public:
using ElementAccumulator = ElementAccumulator_;
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementInput = typename InputTileIterator::Element;
using VisitAccessType = Array<ElementInput, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementInput *input_ptr; ///< Pointer to the input tensor in device memory
int ldt; ///< Leading dimension of the input tensor operand
int64_t batch_stride; ///< batch stride for batched GEMM
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): input_ptr(nullptr), ldt(0), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementInput *input_ptr,
int ldt, int64_t batch_stride
):
input_ptr(input_ptr),
ldt(ldt),
batch_stride(batch_stride)
{ }
};
/// Param structure
struct Params {
typename InputTileIterator::Params params_input;
ElementInput *input_ptr;
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
input_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
params_input(args.ldt),
input_ptr(args.input_ptr),
batch_stride(args.batch_stride)
{ }
};
private:
InputTileIterator iterator_T_;
typename InputTileIterator::Fragment fragment_T_;
MatrixCoord problem_size;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpTensorInput(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
iterator_T_(
InputTileIterator(
params.params_input,
params.input_ptr,
problem_size,
thread_idx,
threadblock_offset
)
),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
}
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_T_.clear();
iterator_T_.load(fragment_T_);
++iterator_T_;
}
CUTLASS_DEVICE
void begin_row(int row_idx) { }
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
return source;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
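The leaf visitor above loads one fragment per step and then hands back the `frag_idx`-th vector of `kElementsPerAccess` elements on each `visit()`. A minimal host-side sketch of that indexing follows. `FragmentModel` and `batch_pointer_offset` are hypothetical stand-ins for illustration only; they are not CUTLASS types or functions.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side stand-in for a per-thread fragment (not a CUTLASS type).
template <typename T, std::size_t kElementsPerAccess, std::size_t kAccesses>
struct FragmentModel {
  std::array<T, kElementsPerAccess * kAccesses> data{};

  // Mirrors the reinterpret_cast indexing in visit(): return the
  // frag_idx-th group of kElementsPerAccess consecutive elements.
  std::array<T, kElementsPerAccess> access(std::size_t frag_idx) const {
    std::array<T, kElementsPerAccess> out{};
    for (std::size_t i = 0; i < kElementsPerAccess; ++i) {
      out[i] = data[frag_idx * kElementsPerAccess + i];
    }
    return out;
  }
};

// Mirrors set_batch_index(): the element offset added to the iterator
// for a given batch (assumed helper name, same arithmetic as the source).
inline std::int64_t batch_pointer_offset(int batch_idx, std::int64_t batch_stride) {
  return batch_idx * batch_stride;
}
```

For a fragment holding two accesses of four elements each, `access(1)` would return elements 4 through 7, which is what the visitor forwards to its parent node.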


@ -0,0 +1,240 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with Tensor Output
*/
#pragma once
#include "cutlass/cutlass.h"
#include "stdio.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementOutput T = ElementOutput(Visitor)
/// T-> device memory
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename OutputTileIterator_, ///< Tile iterator type to write the tensor
typename Visitor_ ///< Child visitor that produces the output tensor
>
class VisitorOpTensorOutput {
public:
using ElementAccumulator = ElementAccumulator_;
using OutputTileIterator = OutputTileIterator_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using Visitor = Visitor_;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of output
using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory
int ldt; ///< Leading dimension of the output tensor operand
int64_t batch_stride; ///< batch stride
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): output_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementOutput *output_ptr,
int ldt,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
output_ptr(output_ptr),
ldt(ldt),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
typename OutputTileIterator::Params params_output;
ElementOutput *output_ptr;
int64_t batch_stride;
typename Visitor::Params visitor_param;
/// Method
CUTLASS_HOST_DEVICE
Params():
output_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
params_output(args.ldt),
output_ptr(args.output_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
OutputTileIterator iterator_T_;
typename OutputTileIterator::Fragment fragment_T_;
MatrixCoord problem_size;
Visitor visitor_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpTensorOutput(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
iterator_T_(
OutputTileIterator(
params.params_output,
params.output_ptr,
problem_size,
thread_idx,
threadblock_offset
)
),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_T_.clear();
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
// Column guard
MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
bool column_guard = (thread_offset_.column() < problem_size.column());
if (column_guard) {
NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
output = output_converter(result);
}
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
iterator_T_.store(fragment_T_);
++iterator_T_;
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
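The `visit()` method of the output visitor above does two things: it writes the converted child result into the output fragment only when the thread's column is inside the problem, and it always returns the unconverted result to its parent. A minimal host-side sketch of that pattern follows. `guarded_convert_store` is a hypothetical helper, and `static_cast` stands in for `NumericArrayConverter`, whose rounding behavior may differ.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of CUTLASS): convert and store the child's
// result only for in-bounds columns, then pass the result through unchanged.
template <typename Out, typename In, std::size_t N>
std::array<In, N> guarded_convert_store(
    std::array<In, N> const &result,   // value produced by the child visitor
    int thread_column,                 // column this access starts at
    int problem_columns,               // number of columns in the problem
    std::array<Out, N> &fragment_slot  // slot of the output fragment
) {
  bool column_guard = thread_column < problem_columns;
  if (column_guard) {
    for (std::size_t i = 0; i < N; ++i) {
      // static_cast is a stand-in for the element-wise numeric converter.
      fragment_slot[i] = static_cast<Out>(result[i]);
    }
  }
  return result;
}
```

Because the slot is left untouched for out-of-range columns, the fragment keeps whatever `begin_step()` put there, which in the source is the cleared (zero) value.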


@ -0,0 +1,226 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with Unary operation
*/
#pragma once
#include "cutlass/cutlass.h"
#include "unary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute C = UnaryOp(ElementCompute(Visitor))
/// Return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename Visitor_, ///< Child node
template<typename T, int N> typename UnaryOp_
>
class VisitorOpUnary{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using Visitor = Visitor_;
/// Fragment type returned from Visitor.visit
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisit = typename VisitAccessTypeVisitor::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Unary Op
using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
typename UnaryOp::Arguments unary_arg;
typename Visitor::Arguments visitor_arg; ///< Argument type for visitor
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():unary_arg() { }
CUTLASS_HOST_DEVICE
Arguments(
typename UnaryOp::Arguments unary_arg,
typename Visitor::Arguments visitor_arg
):
unary_arg(unary_arg),
visitor_arg(visitor_arg)
{ }
};
/// Parameter structure
struct Params {
typename UnaryOp::Params unary_param;
typename Visitor::Params visitor_param; ///< Parameter type for visitor
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():unary_param() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
unary_param(args.unary_arg),
visitor_param(args.visitor_arg)
{ }
};
private:
//
// Data members
//
UnaryOp unary_op;
Visitor visitor_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpUnary(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
unary_op(params.unary_param),
visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
visitor_op.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
if (unary_op.guard()) visitor_op.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
if (unary_op.guard()) visitor_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
if (unary_op.guard()) visitor_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from the child visitor
VisitAccessTypeVisitor result;
if (unary_op.guard()){
result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
result.clear();
}
/// Type conversion
NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;
return unary_op(source_converter(result));
}
CUTLASS_DEVICE
void end_row(int row_idx) {
if (unary_op.guard()) visitor_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
if (unary_op.guard()) visitor_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
if (unary_op.guard()) visitor_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
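The unary visitor above consults `unary_op.guard()` before every call into its child and feeds a cleared fragment to the unary op when the guard is false. The sketch below models that control flow on the host. `CountingChild`, `ScaleOp`, and `UnaryNode` are hypothetical names, and the guard condition `alpha != 0` is an assumption chosen for illustration, since the definition of `guard()` lives in `unary_ops.h` and is not shown here.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Vec4 = std::array<float, 4>;

// Hypothetical child node that records how often it is visited.
struct CountingChild {
  int calls = 0;
  Vec4 visit(int /*frag_idx*/) {
    ++calls;
    return Vec4{1.0f, -2.0f, 3.0f, -4.0f};
  }
};

// Hypothetical unary op: scale by alpha. guard() reports whether the
// child's result can affect the output (assumed semantics).
struct ScaleOp {
  float alpha;
  bool guard() const { return alpha != 0.0f; }
  Vec4 operator()(Vec4 const &x) const {
    Vec4 y{};
    for (std::size_t i = 0; i < y.size(); ++i) y[i] = alpha * x[i];
    return y;
  }
};

// Mirrors the visit() control flow of the unary visitor.
template <typename Child, typename UnaryOp>
struct UnaryNode {
  Child child;
  UnaryOp op;

  Vec4 visit(int frag_idx) {
    Vec4 result{};  // zero-filled, like result.clear() in the source
    if (op.guard()) {
      result = child.visit(frag_idx);
    }
    return op(result);
  }
};
```

With `ScaleOp{0.0f}` the child is never visited and the node returns zeros; with `ScaleOp{2.0f}` the child runs once per visit and its values are doubled.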


@ -0,0 +1,480 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor type used for partial computation of a layernorm operation
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
+ lightweight full reduction kernel (ApplyFinalReduction)
+ GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ThreadblockShape_,
int ThreadCount,
typename OutputTileIterator_,
typename AccumulatorTile_,
typename ElementAccumulator_,
typename ElementVariance_,
typename ElementMean_,
typename ElementLayernormCompute_,
typename ElementwiseFunctor_,
bool IsShiftedVariance_ = false
>
class EpilogueVisitorLayerNorm {
public:
using ElementVariance = ElementVariance_;
using ElementMean = ElementMean_;
using ElementLayernormCompute = ElementLayernormCompute_;
using AccumulatorTile = AccumulatorTile_;
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;
static int const kThreads = OutputTileIterator::ThreadMap::kThreads;
static bool const kIsShiftedVariance = IsShiftedVariance_;
using ElementOutput = typename OutputTileIterator::Element;
static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;
/// Array type used in Shift-K Layernorm
static int const kRowAccessCount = kIterations * kRowIterations;
using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;
// Conducts manual transpose externally (already supported) for column major
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static int const kThreadsInColumn = kThreads / kThreadsPerRow;
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
ElementVariance *ptr_Variance;
ElementMean *ptr_Mean;
ElementOutput *ptr_Shifted_K;
MatrixCoord extent;
//
// Methods
//
Arguments():
ptr_Variance(nullptr),
ptr_Mean(nullptr),
ptr_Shifted_K(nullptr)
{
}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
ElementVariance *ptr_Variance,
ElementMean *ptr_Mean_,
ElementOutput *ptr_Shifted_K_ = nullptr,
MatrixCoord extent = MatrixCoord(0, 0)
):
elementwise(elementwise_),
ptr_Variance(ptr_Variance),
ptr_Mean(ptr_Mean_),
ptr_Shifted_K(ptr_Shifted_K_),
extent(extent)
{
}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
ElementVariance *ptr_Variance;
ElementMean *ptr_Mean;
ElementOutput *ptr_Shifted_K;
MatrixCoord extent;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_Variance(nullptr),
ptr_Mean(nullptr),
ptr_Shifted_K(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
elementwise(args.elementwise),
ptr_Variance(args.ptr_Variance),
ptr_Mean(args.ptr_Mean),
ptr_Shifted_K(args.ptr_Shifted_K),
extent(args.extent)
{
}
};
/// Shared storage
struct SharedStorage {
};
private:
Params const & params_;
SharedStorage & shared_storage_;
MatrixCoord extent_;
ElementwiseFunctor elementwise_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator alpha_;
ElementAccumulator beta_;
ConvertedShiftFragment shift_k_frag_;
ElementLayernormCompute accum_sum_square_;
ElementLayernormCompute accum_sum_element_;
int thread_idx_;
MatrixCoord thread_offset_;
gemm::GemmCoord threadblock_tile_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorLayerNorm(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord threadblock_offset,
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator ///< Tile iterator for source
):
params_(params),
shared_storage_(shared_storage),
elementwise_(params.elementwise),
extent_(params.extent),
iterator_C_(source_iterator),
iterator_D_(destination_iterator),
threadblock_tile_offset_(threadblock_tile_offset),
thread_idx_(thread_idx)
{
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
// If shift-K feature is enabled, we load shift-k fragment
// at the very beginning of an epilogue
if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
shift_k_frag_.clear();
int thread_offset_row_base = iterator_D_.thread_start_row();
CUTLASS_PRAGMA_UNROLL
for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
CUTLASS_PRAGMA_UNROLL
for (int rid = 0; rid < kRowIterations; ++rid) {
int row_step_offset = rid * kDeltaRow;
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
bool is_load = (row_offset < extent_.row());
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
}
}
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
fragment_C_.clear();
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
/// set the accumulator to 0
accum_sum_element_ = ElementLayernormCompute(0);
accum_sum_square_ = ElementLayernormCompute(0);
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const &accum) {
using Mul = cutlass::multiplies<ElementLayernormCompute>;
using Minus = cutlass::minus<ElementLayernormCompute>;
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
Minus minus;
Mul mul;
Exp exponential;
LayernormFragment result;
thread_offset_ =
iterator_D_.thread_start() +
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
bool column_guard = (thread_offset_.column() < extent_.column());
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
result = source_converter(elementwise_(accum));
}else{
result = source_converter(elementwise_(accum, source_vector));
}
ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());
// Fragment is cleared for non-reachable columns so no need to check against column guard
ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);
// Square sum is different. With shift-k, non-reachable columns must be skipped;
// otherwise each of them would incorrectly add an extra k^2 to the square sum.
ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);
if (column_guard) {
accum_sum_square_tmp = (kIsShiftedVariance) ? \
square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
square_sum_accumulator_(result);
}
accum_sum_element_tmp *= inv_scalar;
accum_sum_square_tmp *= inv_scalar;
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
CUTLASS_PRAGMA_UNROLL
for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
}
accum_sum_element_ += accum_sum_element_tmp;
accum_sum_square_ += accum_sum_square_tmp;
// Convert to the output
NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;
ConvertVarianceOutput convert_variance_output;
ConvertMeanOutput convert_mean_output;
bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();
ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;
arch::global_store<ElementVariance, sizeof(ElementVariance)>(
convert_variance_output(accum_sum_square_),
(void *)curr_ptr_sum_square,
is_write_thread);
arch::global_store<ElementMean, sizeof(ElementMean)>(
convert_mean_output(accum_sum_element_),
(void *)curr_ptr_element_sum,
is_write_thread);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
}
private:
CUTLASS_DEVICE
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
ConvertShiftK convert_shift_k;
ElementOutput shift_k_val;
// Computes the address to load shift_k element
ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset;
// Conditionally loads from global memory
arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
// Converts data type to return
ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);
return converted_shift_k_val;
}
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
auto accum_ = accum[i];
sum_ += accum_ * accum_;
}
return sum_;
}
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
auto accum_ = accum[i] - shift_k_val;
sum_ += accum_ * accum_;
}
return sum_;
}
CUTLASS_DEVICE
ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
sum_ += accum[i];
}
return sum_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
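The layernorm visitor above accumulates, per row, a scaled element sum and a scaled (optionally shifted) square sum, and combines per-thread partials with an xor butterfly over `__shfl_xor_sync`. The following host-side sketch models those steps with plain vectors. All names are hypothetical, the lane count is assumed to be a power of two, and `variance_from` applies the standard identity Var(x) = E[(x - k)^2] - (E[x] - k)^2, which is an assumption about how a later reduction would consume these values rather than code taken from this file.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-row statistics as stored by end_row(): mean and mean of (x - k)^2.
struct RowStats {
  float mean;
  float mean_shifted_square;
};

// Host model of the xor butterfly; lanes.size() must be a power of two.
// After the loop every lane holds the sum over all lanes.
inline void butterfly_sum(std::vector<float> &lanes) {
  for (std::size_t offset = lanes.size() / 2; offset > 0; offset >>= 1) {
    std::vector<float> next(lanes.size());
    for (std::size_t lane = 0; lane < lanes.size(); ++lane) {
      next[lane] = lanes[lane] + lanes[lane ^ offset];
    }
    lanes = next;
  }
}

// Each lane reduces its own columns, then the lanes are combined.
inline RowStats row_stats(std::vector<float> const &row,
                          std::size_t lanes_per_row,
                          float shift_k) {
  float inv_n = 1.0f / static_cast<float>(row.size());
  std::vector<float> sum(lanes_per_row, 0.0f);
  std::vector<float> square(lanes_per_row, 0.0f);
  for (std::size_t c = 0; c < row.size(); ++c) {
    std::size_t lane = c % lanes_per_row;
    float shifted = row[c] - shift_k;
    sum[lane] += row[c] * inv_n;
    square[lane] += shifted * shifted * inv_n;
  }
  butterfly_sum(sum);
  butterfly_sum(square);
  return RowStats{sum[0], square[0]};
}

// Recover the variance from the shifted statistics.
inline float variance_from(RowStats const &stats, float shift_k) {
  float shifted_mean = stats.mean - shift_k;
  return stats.mean_shifted_square - shifted_mean * shifted_mean;
}
```

Choosing a shift close to the row mean keeps both terms of the subtraction small, which is the usual motivation for a shifted variance; the result is the same for any shift in exact arithmetic.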


@ -0,0 +1,77 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind gemm related enum types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/gemm.h"
#include "host.h"
namespace py = pybind11;
void bind_gemm(py::module &m) {
//
// Enumerate types
// cutlass/gemm/gemm.h
py::enum_<cutlass::gemm::GemmUniversalMode>(m, "Mode")
.value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial")
.value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel")
.value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM")
.value("Array", cutlass::gemm::GemmUniversalMode::kArray)
.value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid);
/// GemmCoord is a structure that specifies a location within the coordinate space of a GEMM problem
py::class_<cutlass::gemm::GemmCoord>(m, "GemmCoord")
.def(py::init<int, int, int>())
.def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m))
.def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n))
.def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k))
// get tensor coords
.def("mk",
[](const cutlass::gemm::GemmCoord & problem_size) {
return cutlass::MatrixCoord(problem_size.mk());
})
.def("kn",
[](const cutlass::gemm::GemmCoord & problem_size) {
return cutlass::MatrixCoord(problem_size.kn());
})
.def("mn",
[](const cutlass::gemm::GemmCoord & problem_size) {
return cutlass::MatrixCoord(problem_size.mn());
});
py::module_ host_submodule = m.def_submodule("host");
bind_gemm_host_helper(host_submodule);
}
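
The `mk`, `kn`, and `mn` bindings above project a GEMM problem size onto the extents of its operands. The following is a minimal standalone sketch of that projection. `GemmCoordSketch` and `MatrixCoordSketch` are hypothetical stand-ins introduced for illustration, not the CUTLASS types, and the sketch assumes the usual convention that A is M x K, B is K x N, and C/D are M x N.

```cpp
#include <cassert>

// Hypothetical stand-in for cutlass::MatrixCoord (row, column).
struct MatrixCoordSketch {
  int row;
  int column;
};

// Hypothetical stand-in for cutlass::gemm::GemmCoord (M, N, K).
struct GemmCoordSketch {
  int m;
  int n;
  int k;

  // Extent of operand A, assumed to be M x K.
  MatrixCoordSketch mk() const { return {m, k}; }
  // Extent of operand B, assumed to be K x N.
  MatrixCoordSketch kn() const { return {k, n}; }
  // Extent of operands C and D, assumed to be M x N.
  MatrixCoordSketch mn() const { return {m, n}; }
};
```

For a 128 x 256 x 64 problem, `mk()` yields (128, 64), `kn()` yields (64, 256), and `mn()` yields (128, 256).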

@ -0,0 +1,638 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Universal GEMM kernel whose epilogue is driven by an epilogue visitor
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///< Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///< Epilogue
typename ThreadblockSwizzle_ ///< Threadblock swizzling function
>
struct GemmUniversalwithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
//
// Structures
//
/// Argument structure
struct Arguments : UniversalArgumentsBase {
//
// Data members
//
typename EpilogueVisitor::Arguments epilogue_visitor;
void const * ptr_A;
void const * ptr_B;
void const * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
typename LayoutC::Stride stride_c;
typename LayoutC::Stride stride_d;
typename LayoutA::Stride::LongIndex lda;
typename LayoutB::Stride::LongIndex ldb;
typename LayoutC::Stride::LongIndex ldc;
typename LayoutC::Stride::LongIndex ldd;
int const * ptr_gather_A_indices;
int const * ptr_gather_B_indices;
int const * ptr_scatter_D_indices;
//
// Methods
//
Arguments():
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueVisitor::Arguments epilogue_visitor,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride stride_a,
typename LayoutB::Stride stride_b,
typename LayoutC::Stride stride_c,
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::Arguments::Arguments() - problem_size: " << problem_size);
}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueVisitor::Arguments epilogue_visitor,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride::LongIndex lda,
typename LayoutB::Stride::LongIndex ldb,
typename LayoutC::Stride::LongIndex ldc,
typename LayoutC::Stride::LongIndex ldd,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::Arguments::Arguments() - problem_size: " << problem_size);
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.stride_a, args.stride_b);
std::swap(args.batch_stride_A, args.batch_stride_B);
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
return args;
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC> {
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
typename EpilogueVisitor::Params epilogue_visitor;
void * ptr_A;
void * ptr_B;
void * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;
int *semaphore;
//
// Methods
//
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
int device_sms,
int sm_occupancy
):
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
epilogue_visitor(args.epilogue_visitor),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_D(args.ptr_D),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)) {
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
epilogue_visitor = args.epilogue_visitor;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::Params::update()");
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmUniversalwithEpilogueVisitor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");
static int const kAlignmentA = (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<LayoutB,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutB,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmUniversalwithEpilogueVisitor op;
op(params, shared_storage);
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
__syncthreads();
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A,
params.ptr_gather_A_indices);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B,
params.ptr_gather_B_indices);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct threadblock-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute the number of threadblock-level iterations along the K dimension
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Epilogue
//
// EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Assumes an identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
//
// Fetch pointers based on mode.
//
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// Tile iterator loading from source tensor.
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.visitor,
threadblock_offset,
threadblock_tile_offset,
thread_idx,
params.problem_size.mn()
);
if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
//
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
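
The K-range logic in `operator()` for the `kGemm` and `kGemmSplitKParallel` modes can be read in isolation. The sketch below restates it as a standalone host function: every slice except the last covers `gemm_k_size` elements of K, the last slice runs to the end of the problem, and the iteration count is the slice length divided by the tile depth, rounded up. `KRange` and `split_k_range` are hypothetical names introduced for illustration. How `gemm_k_size` is chosen is not shown in this file, so it is taken as an input here.

```cpp
#include <cassert>

// Hypothetical helper type for illustration; not part of CUTLASS.
struct KRange {
  int begin;       // corresponds to offset_k in the kernel
  int end;         // corresponds to problem_size_k in the kernel
  int iterations;  // corresponds to gemm_k_iterations in the kernel
};

// Restates the kGemm / kGemmSplitKParallel branch of operator().
//   problem_k   : full K extent of the problem
//   gemm_k_size : K elements assigned to each slice (assumed precomputed on the host)
//   k_slices    : number of slices along K (grid_tiled_shape.k() in the kernel)
//   slice_idx   : index of this slice (threadblock_tile_offset.k() in the kernel)
//   tile_k      : threadblock tile depth (Mma::Shape::kK in the kernel)
inline KRange split_k_range(int problem_k, int gemm_k_size, int k_slices,
                            int slice_idx, int tile_k) {
  int begin = slice_idx * gemm_k_size;
  int end = problem_k;
  // Every slice but the last stops at the start of the next slice.
  if (slice_idx + 1 < k_slices) {
    end = (slice_idx + 1) * gemm_k_size;
  }
  // Ceiling division by the tile depth.
  int iterations = (end - begin + tile_k - 1) / tile_k;
  return {begin, end, iterations};
}
```

With K = 1000, `gemm_k_size` = 256, four slices, and a tile depth of 32, slice 1 covers [256, 512) and the last slice covers the shorter remainder [768, 1000).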

@ -0,0 +1,47 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind gemm host helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
void bind_gemm_host_helper(py::module &m) {
m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>);
m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>);
}

@ -0,0 +1,47 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind CUTLASS layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "tensor.h"
#include "matrix.h"
namespace py = pybind11;
void bind_layout(py::module &m) {
bind_tensor_layout(m);
bind_matrix_layout(m);
}

@ -0,0 +1,87 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind Matrix layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/layout/matrix.h"
namespace py = pybind11;
void bind_matrix_layout(py::module &m) {
//
// Matrix layouts
// cutlass/layout/matrix.h
//
py::class_<cutlass::layout::RowMajor>(m, "RowMajor", R"pbdoc(
Mapping function for row-major matrices.
)pbdoc")
.def_static("packed", &cutlass::layout::RowMajor::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::RowMajor & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
py::class_<cutlass::layout::ColumnMajor>(m, "ColumnMajor", R"pbdoc(
Mapping function for column-major matrices.
)pbdoc")
.def_static("packed", &cutlass::layout::ColumnMajor::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc" )
.def("stride", [](const cutlass::layout::ColumnMajor & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
py::class_<cutlass::layout::RowMajorInterleaved<32>>(m, "RowMajorInterleaved32",
R"pbdoc(Mapping function for interleaved matrices. The matrix is structured
as a row-major arrangement of fixed-size columns of 32 elements.)pbdoc")
.def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::RowMajorInterleaved<32> & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
py::class_<cutlass::layout::ColumnMajorInterleaved<32>>(m, "ColumnMajorInterleaved32",
R"pbdoc(Mapping function for interleaved matrices. The matrix is structured
as a column-major arrangement of fixed-size rows of 32 elements.)pbdoc")
.def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
}
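
The `packed` and `stride` bindings above are easier to follow with the underlying mapping written out. The sketch below shows, for the two non-interleaved layouts only, how a tightly packed layout derives its leading stride from the matrix extent and how that stride maps a (row, column) coordinate to a linear offset. `Extent`, `RowMajorSketch`, and `ColumnMajorSketch` are hypothetical stand-ins introduced for illustration, not the CUTLASS classes.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the matrix extent passed to packed().
struct Extent {
  int rows;
  int columns;
};

// Row-major: elements of a row are contiguous.
struct RowMajorSketch {
  std::int64_t stride;  // elements between the starts of consecutive rows

  // A tightly packed row-major matrix has a stride equal to its column count.
  static RowMajorSketch packed(Extent e) { return {e.columns}; }

  std::int64_t offset(int row, int column) const {
    return row * stride + column;
  }
};

// Column-major: elements of a column are contiguous.
struct ColumnMajorSketch {
  std::int64_t stride;  // elements between the starts of consecutive columns

  // A tightly packed column-major matrix has a stride equal to its row count.
  static ColumnMajorSketch packed(Extent e) { return {e.rows}; }

  std::int64_t offset(int row, int column) const {
    return column * stride + row;
  }
};
```

For a 4 x 8 matrix, the packed row-major stride is 8 and the packed column-major stride is 4, so the same coordinate (2, 3) lands at different linear offsets in the two layouts.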

@ -0,0 +1,74 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind Tensor layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
void bind_tensor_layout(py::module &m) {
//
// Tensor layouts
// cutlass/include/cutlass/layout/tensor.h
//
/// Mapping function for 4-D NHWC tensors.
py::class_<cutlass::layout::TensorNHWC>(m, "TensorNHWC",
R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc")
.def_static("packed", &cutlass::layout::TensorNHWC::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed NHWC tensor)pbdoc")
.def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride),
R"pbdoc(Returns the stride of the layout)pbdoc");
/// Mapping function for 4-D NC/xHWx tensors.
py::class_<cutlass::layout::TensorNCxHWx<32>>(m, "TensorNC32HW32",
R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc")
.def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
.def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride),
R"pbdoc(Returns the stride of the layout)pbdoc");
/// Mapping function for 4-D CxRSKx tensors.
py::class_<cutlass::layout::TensorCxRSKx<32>>(m, "TensorC32RSK32",
R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc")
.def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc")
.def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride),
R"pbdoc(Returns the stride of the layout)pbdoc");
}
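
As a companion to the `TensorNHWC` binding above, the sketch below shows how a tightly packed NHWC layout could derive its three strides and map an (n, h, w, c) coordinate to a linear offset, with the channel dimension contiguous. `Nhwc` and `TensorNhwcSketch` are hypothetical stand-ins introduced for illustration, and the stride ordering (W, H, N) is an assumption of this sketch rather than something stated in the binding code.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a 4-D tensor extent.
struct Nhwc {
  int n;
  int h;
  int w;
  int c;
};

struct TensorNhwcSketch {
  // Strides, in elements, for the W, H, and N dimensions.
  // The C dimension is contiguous, so its stride is implicitly 1.
  std::array<std::int64_t, 3> stride;

  // Tightly packed: each stride is the product of all faster-varying extents.
  static TensorNhwcSketch packed(Nhwc e) {
    std::int64_t c = e.c;
    return {{c, c * e.w, c * e.w * e.h}};
  }

  std::int64_t offset(int n, int h, int w, int c) const {
    return c + w * stride[0] + h * stride[1] + n * stride[2];
  }
};
```

For an extent of N=2, H=4, W=4, C=8, the packed strides are (8, 32, 128).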

@ -0,0 +1,170 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind threadblock swizzling to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
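#include <cxxabi.h>
#include <cuda_runtime.h>
#include <cstdlib>
#include <memory>
#include <string>
#include <typeinfo>
namespace py = pybind11;
/// Demangles a C++ type name; falls back to the mangled name if demangling fails.
inline std::string demangle(const char* mangled_name) {
int status = 0;
// __cxa_demangle returns a malloc-allocated buffer, so release it with std::free.
std::unique_ptr<char, decltype(&std::free)> ptr(
abi::__cxa_demangle(mangled_name, nullptr, nullptr, &status), &std::free);
return (status == 0 && ptr) ? std::string(ptr.get()) : std::string(mangled_name);
}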
template<typename T>
void bind_identity_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.conv.Conv2dProblemSize`
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv3dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC)
:type problem_size: cutlass::conv::Conv3dProblemSize
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_swizzle(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_swizzle_streamk(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_dgrad_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.conv.Conv2dProblemSize`
)pbdoc")
.def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) {
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
}, py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
void bind_threadblock_swizzle(py::module &m) {
py::class_<dim3>(m, "dim3",
R"pbdoc(CUDA dim3: three unsigned integer components x, y, and z)pbdoc")
.def(py::init<int, int, int>(),
py::arg("x"), py::arg("y"), py::arg("z"))
.def_readwrite("x", &dim3::x, R"pbdoc(The x component (read/write))pbdoc")
.def_readwrite("y", &dim3::y, R"pbdoc(The y component (read/write))pbdoc")
.def_readwrite("z", &dim3::z, R"pbdoc(The z component (read/write))pbdoc");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>>(m, "IdentitySwizzle1");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>>(m, "IdentitySwizzle2");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>>(m, "IdentitySwizzle4");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>>(m, "IdentitySwizzle8");
bind_swizzle<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle>(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc");
bind_swizzle<cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc");
bind_swizzle_streamk<cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>(m, "ThreadblockSwizzleStreamK", R"pbdoc(Threadblock swizzling function using Stream K feature)pbdoc");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>>(m, "StridedDgradIdentitySwizzle1");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>>(m, "StridedDgradIdentitySwizzle4");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle>(m, "StridedDgradHorizontalSwizzle");
}
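The `get_tiled_shape` bindings above are documented as returning the problem shape "in units of logical tiles". The following is a minimal sketch of that idea in plain C++, assuming ceiling-division semantics for the M and N extents and that the K extent of the result carries the split-K slice count. `Coord3`, `ceil_div`, and `tiled_shape` are hypothetical names used only for illustration; they are not CUTLASS types or functions.

```cpp
#include <cassert>

// Hypothetical stand-in for a (m, n, k) coordinate; not cutlass::gemm::GemmCoord.
struct Coord3 {
  int m;
  int n;
  int k;
};

// Ceiling division for non-negative a and positive b.
inline int ceil_div(int a, int b) {
  assert(a >= 0 && b > 0);
  return (a + b - 1) / b;
}

// Number of tiles needed to cover the problem in M and N. The K extent of the
// result holds the split-K slice count (an assumption of this sketch).
inline Coord3 tiled_shape(Coord3 problem, Coord3 tile, int split_k_slices) {
  return Coord3{ceil_div(problem.m, tile.m), ceil_div(problem.n, tile.n), split_k_slices};
}
```

For example, `tiled_shape(Coord3{4096, 4000, 512}, Coord3{128, 128, 32}, 1)` covers the 4000-wide N extent with 32 tiles, the last of which is only partially filled.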

@ -0,0 +1,78 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Tensor Coord to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/tensor_coord.h"
namespace py = pybind11;
void bind_tensor_coord(py::module &m) {
//
// Tensor Coords
// cutlass/include/cutlass/tensor_coord.h
//
/// Defines a canonical 4D coordinate used by tensor operations.
py::class_<cutlass::Tensor4DCoord>(m, "Tensor4DCoord",
R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
.def(py::init<int, int, int, int>(),
py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
.def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
py::arg("dim"),
R"pbdoc(Gets the index of a given Coord element)pbdoc")
.def("size", [](const cutlass::Tensor4DCoord & coord) {
return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);},
R"pbdoc(Returns the product n * h * w * c of the coordinate's extents)pbdoc");
py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
.def("at", py::overload_cast<int>(&cutlass::Coord<3>::at),
py::arg("dim"),
R"pbdoc(Gets the index of a given Coord element)pbdoc");
// Matrix Size
py::class_<cutlass::MatrixCoord>(m, "MatrixCoord",
R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc")
.def(py::init<int, int>(),
py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc")
.def("row", py::overload_cast<>(&cutlass::MatrixCoord::row),
R"pbdoc(Returns the row of the coordinate)pbdoc")
.def("column", py::overload_cast<>(&cutlass::MatrixCoord::column),
R"pbdoc(Returns the column of the coordinate)pbdoc");
}
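The `size` helper bound on `Tensor4DCoord` multiplies the four extents as `int` values. As a hedged illustration of the same computation, the sketch below widens to 64 bits before multiplying, which avoids overflow for large tensors. `Coord4` and `element_count` are hypothetical names, not part of CUTLASS.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a 4-D (n, h, w, c) coordinate; not cutlass::Tensor4DCoord.
struct Coord4 {
  int n;
  int h;
  int w;
  int c;
};

// Element count as a 64-bit product so that large extents do not overflow int.
inline std::int64_t element_count(const Coord4& e) {
  assert(e.n >= 0 && e.h >= 0 && e.w >= 0 && e.c >= 0);
  return std::int64_t(e.n) * e.h * e.w * e.c;
}
```

With extents `(64, 1024, 1024, 64)` the product is 2^32, which no longer fits in a 32-bit `int` but is represented exactly by the 64-bit result.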

@ -0,0 +1,102 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind TensorRef and View to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "types.h"
template<typename T, typename L, typename TF>
void bind_tensor_ref_view(py::module &m, std::string name) {
py::class_<cutlass::TensorRef<T, L>>(m, ("TensorRef" + name).c_str())
.def(py::init([](int64_t address, const L& layout_) {
// Reinterpret the integer address supplied from Python as a typed pointer.
return cutlass::TensorRef<T, L>(reinterpret_cast<T*>(address), layout_);
}))
.def("data", [](cutlass::TensorRef<T, L>& tensor_ref) {
T* ptr = tensor_ref.data();
return int64_t(ptr);
})
.def("layout", py::overload_cast<>(&cutlass::TensorRef<T, L>::layout));
// `data` is unused at runtime; its type TF selects the overload from Python.
m.def("get_tensor_ref", [](int64_t address, TF /*data*/, const L& layout_) {
return cutlass::TensorRef<T, L>(reinterpret_cast<T*>(address), layout_);
});
py::class_<cutlass::TensorView<T, L>>(m, ("TensorView" + name).c_str())
.def(py::init<const cutlass::TensorRef<T, L>&, const typename L::TensorCoord &>());
}
void bind_tensor_refs_and_views(py::module &m) {
/// float
bind_tensor_ref_view<float, cutlass::layout::RowMajor, cutlass::float32>(m, "F32RowMajor");
bind_tensor_ref_view<float, cutlass::layout::ColumnMajor, cutlass::float32>(m, "F32ColumnMajor");
bind_tensor_ref_view<float, cutlass::layout::TensorNHWC, cutlass::float32>(m, "F32NHWC");
/// double
bind_tensor_ref_view<double, cutlass::layout::RowMajor, cutlass::float64>(m, "F64RowMajor");
bind_tensor_ref_view<double, cutlass::layout::ColumnMajor, cutlass::float64>(m, "F64ColumnMajor");
bind_tensor_ref_view<double, cutlass::layout::TensorNHWC, cutlass::float64>(m, "F64NHWC");
// half_t
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t>(m, "F16RowMajor");
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t>(m, "F16ColumnMajor");
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::half_t>(m, "F16NHWC");
// bfloat16
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t>(m, "BF16RowMajor");
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t>(m, "BF16ColumnMajor");
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::TensorNHWC, cutlass::bfloat16_t>(m, "BF16NHWC");
// int8_t
bind_tensor_ref_view<int8_t, cutlass::layout::RowMajorInterleaved<32>, cutlass::int8>(m, "S8RowMajorInterleaved32");
bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajorInterleaved<32>, cutlass::int8>(m, "S8ColumnMajorInterleaved32");
bind_tensor_ref_view<int8_t, cutlass::layout::RowMajor, cutlass::int8>(m, "S8RowMajor");
bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajor, cutlass::int8>(m, "S8ColumnMajor");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorNHWC, cutlass::int8>(m, "S8NHWC");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorNCxHWx<32>, cutlass::int8>(m, "S8NC32HW32");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::int8>(m, "S8C32RSK32");
// int32_t
bind_tensor_ref_view<int32_t, cutlass::layout::RowMajor, cutlass::int32>(m, "S32RowMajor");
bind_tensor_ref_view<int32_t, cutlass::layout::ColumnMajor, cutlass::int32>(m, "S32ColumnMajor");
bind_tensor_ref_view<int32_t, cutlass::layout::TensorNHWC, cutlass::int32>(m, "S32NHWC");
}
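The `TensorRef` bindings above accept a raw device or host address as an integer and hand it back through `data()`. The sketch below shows that round trip in isolation, assuming a simple row-major layout with a leading dimension. `RowMajorRef`, `ref_from_address`, and `address_of` are hypothetical names for illustration only; this is not `cutlass::TensorRef`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical row-major reference: a raw pointer plus a leading dimension.
template <typename T>
struct RowMajorRef {
  T* ptr;
  std::int64_t ldm;  // number of elements between consecutive rows

  T& at(std::int64_t row, std::int64_t col) const { return ptr[row * ldm + col]; }
};

// Rebuild a typed reference from an integer address, as a Python caller would pass it.
template <typename T>
RowMajorRef<T> ref_from_address(std::intptr_t address, std::int64_t ldm) {
  assert(address != 0 && ldm > 0);
  return RowMajorRef<T>{reinterpret_cast<T*>(address), ldm};
}

// Export the address back as an integer.
template <typename T>
std::intptr_t address_of(const RowMajorRef<T>& ref) {
  return reinterpret_cast<std::intptr_t>(ref.ptr);
}
```

The reference does not own the memory: the caller must keep the underlying buffer alive for as long as the reference is in use, which is also true of any address passed across a Python binding.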

@ -0,0 +1,146 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind CUTLASS types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include <cstdint>
#include "cutlass/half.h"
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"
namespace py = pybind11;
namespace cutlass {
/// 8-bit signed integer
struct alignas(1) int8 {
int8_t storage;
explicit int8(int x) {
storage = int8_t(x);
}
explicit int8(float x) {
storage = int8_t(x);
}
int8_t c_value(){return storage;}
};
/// 32-bit signed integer
struct alignas(4) int32 {
int storage;
explicit int32(int x) {
storage = x;
}
explicit int32(float x) {
storage = int(x);
}
int c_value(){return storage;}
};
/// IEEE single-precision floating-point type
struct alignas(4) float32 {
float storage;
explicit float32(float x) {
storage = x;
}
explicit float32(int x) {
storage = float(x);
}
float c_value(){return storage;}
};
/// IEEE double-precision floating-point type
struct alignas(8) float64 {
double storage;
explicit float64(float x) {
storage = double(x);
}
explicit float64(int x) {
storage = double(x);
}
double c_value(){return storage;}
};
}
void bind_cutlass_types(py::module &m) {
// s8
py::class_<cutlass::int8>(m, "int8")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::int8::storage)
.def("value", &cutlass::int8::c_value);
// s32
py::class_<cutlass::int32>(m, "int32")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::int32::storage)
.def("value", &cutlass::int32::c_value);
// f16
py::class_<cutlass::half_t>(m, "float16")
.def(py::init<float>())
.def(py::init<double>())
.def(py::init<int>())
.def(py::init<unsigned>())
.def_readwrite("storage", &cutlass::half_t::storage)
.def("value", [](const cutlass::half_t& value) {return value;});
// bf16
py::class_<cutlass::bfloat16_t>(m, "bfloat16")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::bfloat16_t::storage)
.def("value", [](const cutlass::bfloat16_t& value) {return value;});
// f32
py::class_<cutlass::float32>(m, "float32")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::float32::storage)
.def("value", &cutlass::float32::c_value);
// tf32
py::class_<cutlass::tfloat32_t>(m, "tfloat32")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::tfloat32_t::storage)
.def("value", [](const cutlass::tfloat32_t& value) {return value;});
// f64
py::class_<cutlass::float64>(m, "float64")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::float64::storage)
.def("value", &cutlass::float64::c_value);
}
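The wrapper structs in this file pair a `storage` member with an `alignas` specifier and converting constructors. Two details are easy to get wrong in this pattern: an `alignas` value should not be weaker than the member's natural alignment, and casting an out-of-range `float` directly to `int8_t` is undefined behavior. The sketch below illustrates both points under those assumptions. `F64Box`, `S8Box`, and `saturate_to_s8` are hypothetical names and are not the types bound above.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical wrapper for illustration; alignment is taken from the member type.
struct alignas(alignof(double)) F64Box {
  double storage;
  explicit F64Box(double x) : storage(x) {}
  double value() const { return storage; }
};

// Saturating float-to-int8 conversion: clamp first, then truncate toward zero.
inline std::int8_t saturate_to_s8(float x) {
  assert(x == x);  // precondition: x is not NaN
  if (x < -128.0f) return -128;
  if (x > 127.0f) return 127;
  return static_cast<std::int8_t>(x);
}

// Hypothetical 8-bit wrapper that saturates instead of relying on a raw cast.
struct S8Box {
  std::int8_t storage;
  explicit S8Box(float x) : storage(saturate_to_s8(x)) {}
  int value() const { return storage; }
};

// Compile-time checks that the wrappers keep the expected size and alignment.
static_assert(alignof(F64Box) >= alignof(double), "wrapper must not weaken alignment");
static_assert(sizeof(S8Box) == 1, "wrapper must not add padding");
```

Whether saturation is the desired behavior depends on the caller; the point of the sketch is only that the conversion policy is explicit rather than left to an out-of-range cast.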

@ -0,0 +1,32 @@
#include <cutlass/complex.h>
namespace cutlass {
/// ENUM class for datatypes
enum class DataType {
kB1, kU2, kU4, kU8,
kU16, kU32, kU64, kS2,
kS4, kS8, kS16, kS32,
kS64, kF16, kBF16, kF32,
kTF32, kF64, kCF16, kCBF16,
kCF32, kCTF32, kCF64, kCS2,
kCS4, kCS8, kCS16, kCS32,
kCS64, kCU2, kCU4, kCU8,
kCU16, kCU32, kCU64, kInvalid
};
/// ENUM class for LayoutTypes
enum class LayoutType {
kColumnMajor, kRowMajor,
kColumnMajorInterleaved2, kRowMajorInterleaved2,
kColumnMajorInterleaved32, kRowMajorInterleaved32,
kColumnMajorInterleaved64, kRowMajorInterleaved64,
kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC,
kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32,
kTensorC64RSK64
};
/// ENUM class for opcode class
} // namespace cutlass

@ -0,0 +1,54 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind convolution problems to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "unit/conv/device/conv2d_problems.h"
#include "cutlass/conv/conv2d_problem_size.h"
namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(std::vector<cutlass::conv::Conv2dProblemSize>);
void bind_conv_problem_size_test(py::module &m) {
py::bind_vector<std::vector<cutlass::conv::Conv2dProblemSize>>(m, "Conv2dProblemVector")
.def("size", &std::vector<cutlass::conv::Conv2dProblemSize>::size);
// Get Conv2d problem sizes
py::class_<test::conv::device::TestbedConv2dProblemSizes>(m, "TestbedConv2dProblemSizes")
.def(py::init<int>())
.def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes);
}

@ -0,0 +1,49 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind convolution related types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "conv_problems.h"
#include "host.h"
namespace py = pybind11;
void bind_convolution_test(py::module &m) {
// Conv problem sizes
bind_conv_problem_size_test(m);
py::module_ host_submodule = m.def_submodule("host");
bind_conv_host_references(host_submodule);
}

@ -0,0 +1,180 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Convolution host test helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "unit/conv/device/cache_testbed_output.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/host/tensor_compare.h"
namespace py = pybind11;
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host(py::module &m) {
m.def("conv2d",
&cutlass::reference::host::Conv2d<
Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);
m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host_sat(py::module &m) {
m.def("conv2d",
&cutlass::reference::host::Conv2d<
Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);
m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nhwc(py::module &m) {
bind_conv2d_host<
Ta, cutlass::layout::TensorNHWC,
Tb, cutlass::layout::TensorNHWC,
Tc, cutlass::layout::TensorNHWC,
Tacc, Te>(m);
}
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nc32hw32(py::module &m) {
bind_conv2d_host_sat<
Ta, cutlass::layout::TensorNCxHWx<32>,
Tb, cutlass::layout::TensorCxRSKx<32>,
Tc, cutlass::layout::TensorNCxHWx<32>,
Tacc, Te>(m);
}
template<typename T, typename Layout>
void bind_tensor_equals(py::module &m) {
m.def("equals", py::overload_cast<
const cutlass::TensorView<T, Layout>&, const cutlass::TensorView<T, Layout>&>(
&cutlass::reference::host::TensorEquals<T, Layout>
));
}
#define BIND_TENSOR_HASH(Element, Layout) { \
m.def("TensorHash", &test::conv::device::TensorHash<Element, Layout>, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \
}
void bind_conv_host_references(py::module &m) {
//
// Conv2d reference on host
// tools/util/include/cutlass/util/reference/host/convolution.h
/// double
bind_conv2d_host_nhwc<double, double, double, double, double>(m);
/// float
bind_conv2d_host_nhwc<float, float, float, float, float>(m);
/// half
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, float>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, float>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, float>(m);
/// bfloat16
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t>(m);
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, cutlass::bfloat16_t>(m);
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);
/// s8
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);
//
// Compare whether two tensors are equal
//
/// double
bind_tensor_equals<double, cutlass::layout::TensorNHWC>(m);
/// float
bind_tensor_equals<float, cutlass::layout::TensorNHWC>(m);
/// half
bind_tensor_equals<cutlass::half_t, cutlass::layout::TensorNHWC>(m);
/// bfloat16
bind_tensor_equals<cutlass::bfloat16_t, cutlass::layout::TensorNHWC>(m);
/// s32
bind_tensor_equals<int32_t, cutlass::layout::TensorNHWC>(m);
bind_tensor_equals<int32_t, cutlass::layout::TensorNCxHWx<32>>(m);
/// s8
bind_tensor_equals<int8_t, cutlass::layout::TensorNHWC>(m);
bind_tensor_equals<int8_t, cutlass::layout::TensorNCxHWx<32>>(m);
/// Cache
py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
.def(py::init<>())
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>());
py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
.def(py::init<>())
.def(py::init<uint32_t>())
.def_readwrite("D", &test::conv::device::CachedTestResult::D);
py::class_<test::conv::device::CachedTestResultListing>(m, "CachedTestResultListing")
.def(py::init<const std::string &>())
.def("find", &test::conv::device::CachedTestResultListing::find)
.def("append", &test::conv::device::CachedTestResultListing::append)
.def("write", &test::conv::device::CachedTestResultListing::write);
py::class_<test::conv::device::CRC32>(m, "CRC32")
.def(py::init<>());
BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC)
BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>);
}

@ -0,0 +1,45 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm test to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "host.h"
namespace py = pybind11;
void bind_gemm_test(py::module &m) {
py::module_ host_submodule = m.def_submodule("host");
bind_gemm_host_reference(host_submodule);
}

@ -0,0 +1,431 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm test host functions to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/functional.h"
namespace py = pybind11;
template<
typename ElementA, typename LayoutA,
typename ElementB, typename LayoutB,
typename ElementC, typename LayoutC,
typename AccumulatorType, typename ComputeType,
typename InnerProductOp>
void bind_host_gemm_saturate(py::module &m) {
m.def("gemm_saturate", py::overload_cast<
cutlass::gemm::GemmCoord, ComputeType,
cutlass::TensorRef<ElementA, LayoutA>,
cutlass::TensorRef<ElementB, LayoutB>,
ComputeType,
cutlass::TensorRef<ElementC, LayoutC>,
cutlass::TensorRef<ElementC, LayoutC>,
AccumulatorType>(
&cutlass::reference::host::compute_gemm<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ComputeType,
AccumulatorType,
InnerProductOp,
cutlass::NumericConverterClamp<ElementC, AccumulatorType>>
));
}
template<
typename ElementA, typename LayoutA,
typename ElementB, typename LayoutB,
typename ElementC, typename LayoutC,
typename AccumulatorType, typename ComputeType,
typename InnerProductOp>
void bind_host_gemm(py::module &m) {
m.def("gemm", py::overload_cast<
cutlass::gemm::GemmCoord, ComputeType,
cutlass::TensorRef<ElementA, LayoutA>,
cutlass::TensorRef<ElementB, LayoutB>,
ComputeType,
cutlass::TensorRef<ElementC, LayoutC>,
cutlass::TensorRef<ElementC, LayoutC>,
AccumulatorType>(
&cutlass::reference::host::compute_gemm<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ComputeType,
AccumulatorType,
InnerProductOp,
cutlass::NumericConverter<ElementC, AccumulatorType>>
));
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add(py::module &m) {
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate(py::module &m) {
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_interleaved(py::module &m) {
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) {
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
#define BIND_TENSOR_EQUAL(Element, Layout) { \
m.def("equals", py::overload_cast< \
const cutlass::TensorView<Element, Layout>&, const cutlass::TensorView<Element, Layout>&>( \
&cutlass::reference::host::TensorEquals<Element, Layout>)); \
}
void bind_gemm_host_reference(py::module &m) {
/// double
bind_host_gemm_multiply_add<double, double, double, double, double>(m);
/// float
bind_host_gemm_multiply_add<float, float, float, float, float>(m);
/// half_t
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, float, float>(m);
/// bfloat16
bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);
/// s8
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
// float
BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor);
// double
BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor);
// half_t
BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor);
// bfloat16
BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor);
// int32_t
BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor);
// int8_t
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor);
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>);
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>);
}

@ -0,0 +1,33 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass.emit.pytorch import pytorch

@ -0,0 +1,182 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Common utilities for emitting CUTLASS kernels
"""
import cutlass
# Strings used for printing information about the generation of emitted scripts
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
"""
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
"""
_CUTLASS_KERNEL_ARGS_2x = """
typename DeviceKernel::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K}, // problem size
1,
{alpha, beta},
A, B, C, D,
0, 0, 0, 0, // batch strides
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
};
"""
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
typename DeviceKernel::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K}, // problem size
1,
{alpha, beta},
A, B, C, D,
0, 0, 0, 0, // batch strides
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
-1 // avail_sms
};
"""
_CUTLASS_KERNEL_RUN_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status ${name}_kernel_run(int M, int N, int K,
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta) {
${args}
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.initialize(arguments,
workspace.get(),
nullptr); // CUDA stream
if (status != cutlass::Status::kSuccess) {
return status;
}
status = gemm_op();
return status;
}
"""
_CUTLASS_KERNEL_RUN_GEMM_3x = """
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
using StrideD = typename DeviceKernel::GemmKernel::StrideD;
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status ${name}_kernel_run(
int M, int N, int K, int L,
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
typename DeviceKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, L}, // problem size
A, // ptrA
make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
B, // ptrB
make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
{
C, // ptrC
make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
D, // ptrD
make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
{alpha, beta},
},
hw_info
};
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.run(arguments,
workspace.get(),
nullptr); // CUDA stream
return status;
}
"""
_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
int threadblock_count = DeviceKernel::sufficient();
cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
ElementCompute alpha, ElementCompute beta) {
typename DeviceKernel::Arguments arguments {
problem_sizes,
problem_count,
threadblock_count,
{alpha, beta},
A, B, C, D,
lda, ldb, ldc, ldd
};
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.initialize(arguments,
workspace.get(),
nullptr); // CUDA stream
if (status != cutlass::Status::kSuccess) {
return status;
}
status = gemm_op();
return status;
}
"""
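
The string templates above carry `${name}` and `${args}` placeholders that are filled in before the source is written out. The following is a minimal sketch of that substitution step, assuming Python's standard `string.Template` as a stand-in for the `SubstituteTemplate` helper the emitter imports (its exact behavior is not shown in this change). `KERNEL_RUN_SKETCH`, `ARGS_SKETCH`, and `fill_template` are hypothetical names used only for illustration.

```python
from string import Template

# Hypothetical, shortened stand-in for a kernel-run template. Only the
# ${name} and ${args} placeholders mirror the templates defined above.
KERNEL_RUN_SKETCH = """
cutlass::Status ${name}_kernel_run(int M, int N, int K) {
  ${args}
  return cutlass::Status::kSuccess;
}
"""

# Hypothetical argument snippet, spliced in through the ${args} placeholder.
ARGS_SKETCH = "typename DeviceKernel::Arguments arguments { {M, N, K} };"


def fill_template(template: str, values: dict) -> str:
    """Replace every ${key} in `template` with values[key].

    Template.substitute raises KeyError if a placeholder has no value,
    which surfaces a missing field before any C++ is compiled.
    """
    return Template(template).substitute(values)


source = fill_template(
    KERNEL_RUN_SKETCH,
    {"name": "cutlass_gemm", "args": ARGS_SKETCH},
)
print(source)
```

Plain braces such as `{M, N, K}` pass through untouched because only `$`-prefixed names are treated as placeholders, which is convenient when the template body is C++.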

@ -0,0 +1,639 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for generating source for building a PyTorch CUDA extension that uses a CUTLASS kernel.
If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.

Example usage with JIT compilation:

.. highlight:: python
.. code-block:: python

    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
    op = plan.construct()
    mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)

    # Generate inputs for the GEMM
    A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]

    # Run the module
    D = mod.run(A, B, C)

Example usage without JIT compilation:

.. highlight:: python
.. code-block:: python

    plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
    op = plan.construct()
    cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')

After this call, the directory ``output`` contains ``setup.py``,
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.

The module can later be used in Python via:

.. highlight:: python
.. code-block:: python

    import torch
    import cutlass_gemm

    # Generate inputs for the GEMM
    A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]

    # Run the module
    D = cutlass_gemm.run(A, B, C)
"""
import logging
import os
import cutlass_bindings
from cutlass import CUTLASS_PATH, logger, swizzle
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.emit import common
torch_available = CheckPackages().check_torch()
if torch_available:
import torch
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
${includes}
${declaration}
${impl}
"""
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);
// C++ interface
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
return ${name}_kernel(A, B, C, alpha, beta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);
// C++ interface
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
return ${name}_kernel(A, B, C, alpha, beta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""
_PYTORCH_GEMM_INCLUDES = {
ApiVersion.v2x: """
#include "cutlass/gemm/device/gemm_universal.h"
""",
ApiVersion.v3x: """
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/packed_stride.hpp"
""",
}
_PYTORCH_GROUPED_GEMM_INCLUDES = """
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
"""
_CUTLASS_TYPE_TO_TORCH_TYPE = {
cutlass_bindings.float16: "torch::kF16",
cutlass_bindings.float32: "torch::kF32",
cutlass_bindings.float64: "torch::kF64",
cutlass_bindings.int8: "torch::kI8",
cutlass_bindings.int32: "torch::kI32",
}
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_GEMM_2x
+ """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
int M = A.size(0);
int N = B.size(1);
int K = A.size(1);
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
cutlass::Status status = ${name}_kernel_run(M, N, K,
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
ptrC,
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
ElementCompute(alpha), ElementCompute(beta));
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
)
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
common._CUTLASS_KERNEL_RUN_GEMM_3x
+ """
bool hw_info_queried = false;
cutlass::KernelHardwareInfo hw_info;
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
int M = A.size(0);
int N = B.size(1);
int K = A.size(1);
int L = 1;
// Query hardware info if we haven't already
if (!hw_info_queried) {
hw_info.device_id = 0;
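hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
// Record that the query has been performed so that it is not repeated on later calls
hw_info_queried = true;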
}
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
cutlass::Status status = ${name}_kernel_run(M, N, K, L,
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
ptrC,
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
ElementCompute(alpha), ElementCompute(beta),
hw_info);
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
)
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
+ """
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
size_t num = A.size();
// To avoid performing many small cudaMallocs and host-to-device copies,
// we serialize the grouped GEMM arguments on the host, allocate one
// large chunk of device memory, and perform a single cudaMemcpy to
// copy the host data to the device. Allocation overheads could be
// avoided by using a memory pool.
// Calculate the total size of the data to be copied from host to device
size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
sizeof(DeviceKernel::ElementA*) +
sizeof(DeviceKernel::ElementB*) +
sizeof(DeviceKernel::ElementC*) +
sizeof(DeviceKernel::ElementC*) +
sizeof(int64_t) +
sizeof(int64_t) +
sizeof(int64_t);
total_size *= num;
// num * sizeof(cutlass::gemm::GemmCoord) may not be a multiple of
// sizeof(DeviceKernel::ElementA*) (which is 8 bytes on a 64-bit system).
// To ensure that we don't end up having misaligned loads in the kernel,
// we pad the buffer so that the pointer lists start on a multiple of 8 bytes.
//
// Note that, even on a 32-bit system (for which sizeof(X*) will not equal
// sizeof(int64_t)), only padding between the list of GemmCoords and the
// list of ptr_As is sufficient because the set of four equal-length lists of pointers
// (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
// start on a multiple of 8.
int64_t padding = 8 - (total_size % 8);
total_size += padding;
uint8_t* host_data = new uint8_t[total_size];
cutlass::DeviceAllocation<uint8_t> device_data(total_size);
uint8_t* start = host_data;
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);
// Apply the padding after the list of GemmCoords
start += num * sizeof(cutlass::gemm::GemmCoord) + padding;
int64_t ptr_A_offset = start - host_data;
DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
start += num * sizeof(DeviceKernel::ElementA*);
int64_t ptr_B_offset = start - host_data;
DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
start += num * sizeof(DeviceKernel::ElementB*);
int64_t ptr_C_offset = start - host_data;
DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
start += num * sizeof(DeviceKernel::ElementC*);
int64_t ptr_D_offset = start - host_data;
DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
start += num * sizeof(DeviceKernel::ElementC*);
int64_t lda_offset = start - host_data;
int64_t* lda_host = reinterpret_cast<int64_t*>(start);
start += num * sizeof(int64_t);
int64_t ldb_offset = start - host_data;
int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
start += num * sizeof(int64_t);
int64_t ldc_offset = start - host_data;
int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
start += num * sizeof(int64_t);
std::vector<at::Tensor> D(num);
bool need_C = (C != at::nullopt) && (beta != 0.f);
for (size_t i = 0; i < num; ++i) {
int M = A[i].size(0);
int N = B[i].size(1);
int K = A[i].size(1);
*(problem_sizes_host + i) = {M, N, K};
*(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
*(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());
if (need_C) {
*(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
}
else {
*(ptr_C_host + i) = nullptr;
}
D[i] = B[i].new_empty({M, N}, ${torch_type_C});
*(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());
*(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
*(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
*(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
}
device_data.copy_from_host(host_data);
cutlass::Status status = ${name}_kernel_run(
num,
reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
ElementCompute(alpha), ElementCompute(beta));
delete[] host_data;
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
)
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='${name}',
ext_modules=[
CUDAExtension('${name}', [
'${name}.cpp',
'${name}_kernel.cu',
],
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
extra_compile_args=['-std=c++17']
),
],
cmdclass={
'build_ext': BuildExtension
})
"""
def _generate_setup(name: str, sourcedir: str):
"""
Generates a setup.py file for the extension
:param name: name of the module to generate
:type name: str
:param sourcedir: directory to which generated source files should be written
:type sourcedir: str
"""
setup_py_file = os.path.join(sourcedir, "setup.py")
setup_source = SubstituteTemplate(
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH}
)
with open(setup_py_file, "w") as outfile:
outfile.write(setup_source)
class _ArchListSetter:
"""
Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
environment variable when building a PyTorch CUDA module.
``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch
CUDA module should be compiled.
For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
compilation of the module.
This utility wraps the building of a PyTorch CUDA module with a setting of this environment
variable according to the current compute capability being targeted.
Example usage:
.. highlight:: python
.. code-block:: python
# Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
with _ArchListSetter(80):
# Perform JIT compilation and loading of the module
mod = torch.utils.cpp_extension.load(...)
:param cc: compute capability
:type cc: int
"""
_TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"
def __init__(self, cc: int):
self.cc_str = ".".join(list(str(cc)))
def __enter__(self):
"""
Saves the old value of TORCH_CUDA_ARCH_LIST and resets it to the new value based on ``cc``
"""
self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str
return self
def __exit__(self, exc_type, exc_val, traceback):
"""
Restores the old value of TORCH_CUDA_ARCH_LIST
"""
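if self.old_arch_list is None:
# The variable was unset on entry; assigning None to os.environ would raise a TypeError
os.environ.pop(_ArchListSetter._TORCH_CUDA_ARCH_LIST, None)
else:
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list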
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
"""
JIT compiles and loads a PyTorch CUDA extension.
:param name: name of the module to generate
:type name: str
:param cc: compute capability of the device the module should target
:type cc: int
:param cpp_file: path to file containing extension's C++ interface
:type cpp_file: str
:param cuda_file: path to file containing extension's CUDA interface
:type cuda_file: str
:return: loaded PyTorch module
"""
from torch.utils.cpp_extension import load
extra_cuda_cflags = ["-std=c++17"]
if cc == 90:
# PyTorch does not currently add the sm_90a target when compute capability
# 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.
extra_cuda_cflags.append("-gencode=arch=compute_90a,code=sm_90a")
with _ArchListSetter(cc):
jitmodule = load(
name,
[cpp_file, cuda_file],
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=[
os.path.join(CUTLASS_PATH, "include"),
os.path.join(CUTLASS_PATH, "tools/util/include"),
],
verbose=(logger.level == logging.DEBUG)
)
return jitmodule
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
"""
Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
compiled, loaded, and returned.
:param op: operation to emit in the module
:param name: name of the module to generate
:type name: str
:param cc: compute capability of the device the module should target
:type cc: int
:param jit: whether the module should be just-in-time compiled
:type jit: bool
:param sourcedir: directory to which generated source files should be written
:type sourcedir: str
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
"""
if sourcedir != "" and not os.path.isdir(sourcedir):
os.makedirs(sourcedir)
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
extra_kw = {}
if op.api == ApiVersion.v3x:
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
else:
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
if isinstance(op.swizzling_functor, swizzle.ThreadblockSwizzleStreamK):
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
else:
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
cuda_source = SubstituteTemplate(
_PYTORCH_CUDA_TEMPLATE,
{
"includes": _PYTORCH_GEMM_INCLUDES[op.api],
"declaration": op.rt_module.emit(),
"procedural_name": op.procedural_name(),
"impl": cuda_impl,
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
},
)
with open(cuda_file, "w") as outfile:
outfile.write(cuda_source)
cpp_file = os.path.join(sourcedir, name + ".cpp")
cpp_source = SubstituteTemplate(
_PYTORCH_GEMM_CPP_TEMPLATE,
{"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
)
with open(cpp_file, "w") as outfile:
outfile.write(cpp_source)
_generate_setup(name, sourcedir)
if jit:
return _jit(name, cc, cpp_file, cuda_file)
return None
def _pytorch_grouped_gemm(
op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
):
"""
Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
compiled, loaded, and returned.
:param op: operation to emit in the module
:param name: name of the module to generate
:type name: str
:param cc: compute capability of the device the module should target
:type cc: int
:param jit: whether the module should be just-in-time compiled
:type jit: bool
:param sourcedir: directory to which generated source files should be written
:type sourcedir: str
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
"""
if op.api != ApiVersion.v2x:
raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")
if sourcedir != "" and not os.path.isdir(sourcedir):
os.makedirs(sourcedir)
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
cuda_source = SubstituteTemplate(
_PYTORCH_CUDA_TEMPLATE,
{
"includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
"declaration": op.rt_module.emit(),
"procedural_name": op.procedural_name(),
"impl": cuda_impl,
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
},
)
with open(cuda_file, "w") as outfile:
outfile.write(cuda_source)
cpp_file = os.path.join(sourcedir, name + ".cpp")
cpp_source = SubstituteTemplate(
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE,
{"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"},
)
with open(cpp_file, "w") as outfile:
outfile.write(cpp_source)
_generate_setup(name, sourcedir)
if jit:
return _jit(name, cc, cpp_file, cuda_file)
return None
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
"""
Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
compiled, loaded, and returned.
The result of this method is files within ``sourcedir`` that can be used for building
a PyTorch module.
:param op: operation to emit in the module
:param name: name of the module to generate
:type name: str
:param cc: compute capability of the device the module should target
:type cc: int
:param jit: whether the module should be just-in-time compiled
:type jit: bool
:param sourcedir: directory to which generated source files should be written
:type sourcedir: str
:return: loaded PyTorch module (if ``jit=True``) or None
"""
device_op = op.device_op()
if isinstance(op, GemmOperationUniversal):
return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
elif isinstance(op, GemmOperationGrouped):
return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
else:
raise Exception(
f"Operation type {type(op)} is not currently supported for PyTorch emission."
)
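The temporary override of `TORCH_CUDA_ARCH_LIST` that `_ArchListSetter` performs around the JIT build can be sketched in isolation. This is a minimal sketch and not the module's implementation: the helper names `cc_to_arch_string` and `temporary_env` are hypothetical, and the sketch assumes that a variable that was unset before the block should be unset again afterwards.

```python
import os
from contextlib import contextmanager


def cc_to_arch_string(cc: int) -> str:
    """Convert a two-digit compute capability such as 80 into "8.0" (hypothetical helper)."""
    major, minor = divmod(cc, 10)
    return f"{major}.{minor}"


@contextmanager
def temporary_env(name: str, value: str):
    """Set an environment variable for the duration of a ``with`` block (hypothetical helper)."""
    old_value = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old_value is None:
            # The variable did not exist before; remove it instead of assigning None
            os.environ.pop(name, None)
        else:
            os.environ[name] = old_value


# Inside the block, a build step would see TORCH_CUDA_ARCH_LIST="8.0"
with temporary_env("TORCH_CUDA_ARCH_LIST", cc_to_arch_string(80)):
    inside = os.environ["TORCH_CUDA_ARCH_LIST"]
```

Restoring inside `finally` means the previous value comes back even if the build raises, which matters when several modules targeting different compute capabilities are built in one process.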

107
python/cutlass/epilogue.py Normal file
View File

@ -0,0 +1,107 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Registry of elementwise epilogues
Elementwise epilogues can be added to many CUTLASS kernels in the CUTLASS Python interface via
code like the following for GEMM:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
"""
from cutlass.backend import epilogue
gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
leaky_relu = epilogue.leaky_relu
relu = epilogue.relu
sigmoid = epilogue.sigmoid
silu = epilogue.silu
tanh = epilogue.tanh
_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]
def get_activations() -> list:
"""
Returns a list of available activation functions
:return: list of available activation functions
:rtype: list
"""
return _activations
def get_activation_epilogue(
activation,
element_output,
elements_per_access,
element_accumulator,
element_compute,
):
"""
Return an epilogue corresponding to the activation function, data types, and alignment
used in the kernel
:param activation: elementwise activation function to use
:param element_output: data type of the output
:param elements_per_access: alignment of operand C of the kernel
:type elements_per_access: int
:param element_accumulator: data type of the accumulated output C
:param element_compute: data type in which compute operations should be performed
:return: epilogue functor
"""
if activation not in _activations:
raise Exception(
f"Unsupported activation type {activation}. Available activations are: {_activations}"
)
if activation == identity:
return epilogue.LinearCombination(
element_output, elements_per_access, element_accumulator, element_compute
)
else:
return epilogue.LinearCombinationGeneric(
activation(element_compute),
element_output,
elements_per_access,
element_accumulator,
element_compute,
)
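As a rough illustration of what an elementwise epilogue computes, the sketch below applies an activation to the linear combination `alpha * accumulator + beta * source` for a single scalar. It is written in plain Python under stated assumptions and does not use the CUTLASS API: the function `apply_epilogue` and the list `ACTIVATIONS` are hypothetical stand-ins for the registry and epilogue functors above, and only a few of the registered activations are reproduced.

```python
import math


def identity(x: float) -> float:
    return x


def relu(x: float) -> float:
    return max(x, 0.0)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    return x * sigmoid(x)


# Hypothetical stand-in for the registry of supported activations
ACTIVATIONS = [identity, relu, sigmoid, silu]


def apply_epilogue(activation, accumulator: float, source: float,
                   alpha: float = 1.0, beta: float = 0.0) -> float:
    """Compute activation(alpha * accumulator + beta * source) for one element."""
    if activation not in ACTIVATIONS:
        names = [fn.__name__ for fn in ACTIVATIONS]
        raise ValueError(f"Unsupported activation {activation}. Available: {names}")
    return activation(alpha * accumulator + beta * source)


# ReLU clamps the negative linear combination -2.0 + 0.5 * 1.0 = -1.5 to zero
result = apply_epilogue(relu, accumulator=-2.0, source=1.0, alpha=1.0, beta=0.5)
```

With `identity`, the expression reduces to the plain linear combination, which mirrors the special case in `get_activation_epilogue` that returns `LinearCombination` rather than the generic functor.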

View File

@ -0,0 +1,445 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Classes containing valid operations for a given compute capability and data types.
"""
import logging
from cuda import __version__
# Strip any additional information from the CUDA version
_cuda_version = __version__.split("rc")[0]
# Imports from CUTLASS profiler generator and manifest scripts
import generator as prof_generator
import manifest as prof_manifest
import cutlass
from cutlass.utils.check import valid_stage_count
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op, has_binding_type
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]
class KernelsForDataType:
"""
Container class for keeping track of kernels that correspond to a particular combination
of data types for operands A, B, and accumulator
"""
def __init__(self, datatype_comb: tuple, layout_comb: tuple):
self.datatype_comb = datatype_comb
self.layout_comb = layout_comb
# Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
# constraint for the data type combination
self.kernels_by_alignment = {}
def add(self, operation):
"""
Add an operation to the list of supported kernels
"""
alignment = operation.A.alignment
if alignment not in self.kernels_by_alignment:
self.kernels_by_alignment[alignment] = []
self.kernels_by_alignment[alignment].append(operation)
@property
def alignments(self):
"""
Returns an unsorted list of alignments supported by this data type combination
:return: unsorted list of alignments supported by this data type combination
:rtype: list
"""
return list(self.kernels_by_alignment.keys())
@property
def all_operations(self):
"""
Returns a list of all operations supported by this data type combination
:return: list of all operations supported by this data type combination
:rtype: list
"""
ops = []
for _, alignment_ops in self.kernels_by_alignment.items():
ops.extend(alignment_ops)
return ops
def operations(self, alignment: int):
"""
Returns operations satisfying the alignment constraint indicated by `alignment`
:param alignment: alignment constraint of operations to return
:type alignment: int
:return: list of operations
:rtype: list
"""
if alignment not in self.kernels_by_alignment:
raise Exception(
f"No operations of alignment {alignment} found for data type and layout "
f"combination {self.datatype_comb} {self.layout_comb}"
)
return self.kernels_by_alignment[alignment]
def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int:
"""
Returns the most preferable alignment for a given shape and layout
:param shape: extent of each dimension of the tensor
:type shape: tuple
:param layout: layout of the tensor
:type layout: cutlass.LayoutType
:return: maximum alignment supported by the data type combination and tensor size
:rtype: int
"""
# Determine the leading dimension of the shape (the extent of the contiguous dimension)
if layout == cutlass.LayoutType.RowMajor:
ld = shape[1]
elif layout == cutlass.LayoutType.ColumnMajor:
ld = shape[0]
else:
raise Exception(f"Unexpected or unsupported layout {layout}")
for alignment in sorted(list(self.kernels_by_alignment.keys()), reverse=True):
if ld % alignment == 0:
return alignment
# Default to alignment of 1 if no others match
return 1
def sort(self):
"""
Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
"""
key = lambda op: (
op.tile_description.threadblock_shape[0]
* op.tile_description.threadblock_shape[1]
* op.tile_description.threadblock_shape[2]
)
for alignment in self.kernels_by_alignment.keys():
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
class ArchOptions:
"""
Structure for keeping track of kernels available on a given compute capability
:param target_cc: compute capability of the device on which kernels will be run
:type target_cc: int
:param kernel_cc: compute capability of the kernels to generate
:type kernel_cc: int
:param operation_kind: type of operation to register
:type operation_kind: cutlass.OperationKind
:param gemm_kinds: types of GEMM operations that can be included
:type gemm_kinds: list
:param allowed_math_operations: types of primitive math operations allowed
:type allowed_math_operations: list
"""
def __init__(
self,
target_cc: int,
kernel_cc: int,
operation_kind: cutlass.OperationKind,
gemm_kinds: list,
allowed_math_operations: list = [
cutlass.MathOperation.multiply_add,
cutlass.MathOperation.multiply_add_saturate,
]
):
self.cc = kernel_cc
# Dictionary with following structure:
# Key: OpcodeClass
# Value: Dictionary with the following structure:
# Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType)),
# representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
# Value: KernelsForDataType
self.operations_by_opclass = {}
self.op_class = None
self.allowed_math_operations = allowed_math_operations
# Identify the method within CUTLASS generator script that generates kernel
# descriptions for the target CC
generate_function_name = "GenerateSM" + str(kernel_cc)
if not hasattr(prof_generator, generate_function_name):
cutlass.logger.warning(f"No generator found for architecture {kernel_cc}")
return
generate_function = getattr(prof_generator, generate_function_name)
# Initialize a default manifest and populate it with valid kernel descriptions
# for the target CC
args = [
"--kernels=all",
f"--log-level={logging.getLevelName(cutlass.logger.level)}"
]
manifest_args = prof_generator.define_parser().parse_args(args)
manifest = prof_manifest.Manifest(manifest_args)
generate_function(manifest, _cuda_version)
if operation_kind not in manifest.operations:
# No kernels were generated for this architecture. This could be because the installed
# CUDA toolkit version is too old to support operations for this CC
cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
return
# Iterate through the available operations for this operation kind and
# find available opclasses and data types
for name, op_list in manifest.operations[operation_kind].items():
for op in op_list:
if op.gemm_kind not in gemm_kinds:
continue
mi = op.tile_description.math_instruction
if mi.math_operation not in self.allowed_math_operations:
continue
datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
# Skip any data types that do not currently have conversions via cutlass_bindings
if False in [has_binding_type(elt) for elt in datatype_comb]:
continue
# Prune operations that don't fit in shared memory
td = td_from_profiler_op(op)
if not valid_stage_count(target_cc, td)[0]:
continue
if mi.opcode_class not in self.operations_by_opclass:
self.operations_by_opclass[mi.opcode_class] = {}
layout_comb = (op.A.layout, op.B.layout)
# Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
if datatype_comb == (cutlass.DataType.tf32, cutlass.DataType.tf32, cutlass.DataType.f32):
# TF32 kernels only supported on SM80 and beyond
if self.cc < 80:
continue
elif self.cc == 90:
if (op.A.element != cutlass.DataType.f32
or op.B.element != cutlass.DataType.f32
or op.C.element != cutlass.DataType.f32):
continue
datatype_comb = (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32)
opclass_dict = self.operations_by_opclass[mi.opcode_class]
key = (datatype_comb, layout_comb)
if key not in opclass_dict:
opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
opclass_dict[key].add(op)
# Set the default opclass to TensorOp, if available. Otherwise default to SIMT
if cutlass.OpcodeClass.TensorOp in self.operations_by_opclass:
self.op_class = cutlass.OpcodeClass.TensorOp
else:
self.op_class = cutlass.OpcodeClass.Simt
# The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
# Here, we generate additional versions via a generic TileDescription.
if cutlass.OpcodeClass.Simt not in self.operations_by_opclass:
self.operations_by_opclass[cutlass.OpcodeClass.Simt] = {}
types = [
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8),
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
]
layouts = [
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor),
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor),
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor),
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor),
]
alignment = 1
epilogue_functor = cutlass.EpilogueFunctor.LinearCombination
swizzling_functor = cutlass.SwizzlingFunctor.Identity8
for type_comb in types:
for layout_comb in layouts:
comb = (type_comb, layout_comb)
if comb in self.operations_by_opclass[cutlass.OpcodeClass.Simt]:
continue
A = cutlass.TensorDescription(type_comb[0], layout_comb[0], alignment)
B = cutlass.TensorDescription(type_comb[1], layout_comb[1], alignment)
C = cutlass.TensorDescription(type_comb[2], cutlass.LayoutType.ColumnMajor, alignment)
math_inst = cutlass.MathInstruction(
[1, 1, 1],
type_comb[0],
type_comb[1],
type_comb[2],
cutlass.OpcodeClass.Simt,
cutlass.MathOperation.multiply_add
)
td = cutlass.TileDescription(
[128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)
# Prune operations that don't fit in shared memory
if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]:
continue
new_operation = prof_manifest.GemmOperation(
cutlass.GemmKind.Universal, td.minimum_compute_capability,
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
new_kernels = KernelsForDataType(type_comb, layout_comb)
new_kernels.add(new_operation)
self.operations_by_opclass[cutlass.OpcodeClass.Simt][comb] = new_kernels
# Sort all operations
for oc in self.operations_by_opclass.keys():
for comb in self.operations_by_opclass[oc].keys():
self.operations_by_opclass[oc][comb].sort()
def opclass_supports_combination(
self, op_class: cutlass.OpcodeClass, datatype_comb: tuple, layout_comb: tuple
) -> bool:
"""
Returns whether the provided operation class supports the provided data type and layout combination
:param op_class: operation class to consider
:type op_class: cutlass.OpcodeClass
:param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
:type datatype_comb: tuple[cutlass.DataType]
:param layout_comb: tuple of layouts for (layout_A, layout_B)
:type layout_comb: tuple[cutlass.LayoutType]
:return: whether the provided operation class supports the provided data type and layout combination
:rtype: bool
"""
if op_class not in self.operations_by_opclass:
raise Exception(f"Unexpected or unsupported operation class {op_class}")
return (datatype_comb, layout_comb) in self.operations_by_opclass[op_class]
def supporting_opclasses(
self,
element_a: cutlass.DataType,
element_b: cutlass.DataType,
element_accumulator: cutlass.DataType,
layout_a: cutlass.LayoutType,
layout_b: cutlass.LayoutType,
) -> set:
"""
Returns a set of operation classes that support the provided data type combination
:param element_a: data type of operand A
:type element_a: cutlass.DataType
:param element_b: data type of operand B
:type element_b: cutlass.DataType
:param element_accumulator: data type of accumulator
:type element_accumulator: cutlass.DataType
:param layout_a: layout of operand A
:type layout_a: cutlass.LayoutType
:param layout_b: layout of operand B
:type layout_b: cutlass.LayoutType
:return: set of operation classes that support the provided data type combination
:rtype: set
"""
supporting_op_classes = set()
datatype_comb = (element_a, element_b, element_accumulator)
layout_comb = (layout_a, layout_b)
for op_class in self.operations_by_opclass.keys():
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
supporting_op_classes.add(op_class)
return supporting_op_classes
def operations(
self,
op_class: cutlass.OpcodeClass,
element_a: cutlass.DataType,
element_b: cutlass.DataType,
element_accumulator: cutlass.DataType,
layout_a: cutlass.LayoutType,
layout_b: cutlass.LayoutType,
) -> KernelsForDataType:
"""
Returns the kernels of the provided operation class that support the provided data type and layout combination
:param op_class: operation class to consider
:type op_class: cutlass.OpcodeClass
:param element_a: data type of operand A
:type element_a: cutlass.DataType
:param element_b: data type of operand B
:type element_b: cutlass.DataType
:param element_accumulator: data type of accumulator
:type element_accumulator: cutlass.DataType
:param layout_a: layout of operand A
:type layout_a: cutlass.LayoutType
:param layout_b: layout of operand B
:type layout_b: cutlass.LayoutType
:return: container of kernels by alignment supported by the provided combination of parameters
:rtype: KernelsForDataType
"""
datatype_comb = (element_a, element_b, element_accumulator)
layout_comb = (layout_a, layout_b)
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
raise Exception(
f"Data type layout combination {datatype_comb}, {layout_comb} "
f"is not supported by opcode class {op_class} on CC {self.cc}."
)
return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]
class OptionRegistry:
"""
Container of all architecture-specific options
:param target_cc: compute capability of the device on which operations will be run
:type target_cc: int
"""
def __init__(self, target_cc: int):
self.registry = {}
gemm_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x]
# Construct options for each CC
for kernel_cc in _generator_ccs:
self.registry[kernel_cc] = ArchOptions(target_cc, kernel_cc, cutlass.OperationKind.Gemm, gemm_kinds)
def options_for_cc(self, cc: int) -> ArchOptions:
return self.registry.get(cc, None)
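The lookup scheme used by `ArchOptions` above can be illustrated in isolation. The following is a minimal sketch, not the library's implementation: `OpcodeClass` is a stand-in enum, the string data types, layouts, and kernel names are hypothetical placeholders, and `default_opclass` is an illustrative helper that mirrors the "prefer TensorOp, otherwise fall back to SIMT" rule.

```python
from enum import Enum, auto


class OpcodeClass(Enum):
    """Stand-in for cutlass.OpcodeClass (assumption: only two classes matter here)."""
    Simt = auto()
    TensorOp = auto()


# Hypothetical registry: opcode class -> {(datatype_comb, layout_comb): kernels}.
# datatype_comb is (element_a, element_b, element_accumulator);
# layout_comb is (layout_a, layout_b).
operations_by_opclass = {
    OpcodeClass.TensorOp: {
        (("f16", "f16", "f32"), ("RowMajor", "ColumnMajor")): ["tensorop_kernel"],
    },
    OpcodeClass.Simt: {
        (("f16", "f16", "f32"), ("RowMajor", "ColumnMajor")): ["simt_kernel"],
        (("f64", "f64", "f64"), ("RowMajor", "RowMajor")): ["simt_kernel_f64"],
    },
}


def supporting_opclasses(datatype_comb, layout_comb):
    """Return the set of opcode classes that have kernels for this combination."""
    key = (datatype_comb, layout_comb)
    return {oc for oc, table in operations_by_opclass.items() if key in table}


def default_opclass(datatype_comb, layout_comb):
    """Prefer TensorOp when available, otherwise fall back to SIMT."""
    candidates = supporting_opclasses(datatype_comb, layout_comb)
    if OpcodeClass.TensorOp in candidates:
        return OpcodeClass.TensorOp
    if OpcodeClass.Simt in candidates:
        return OpcodeClass.Simt
    raise ValueError(
        f"No kernel configuration found for {datatype_comb} x {layout_comb}"
    )
```

Keying the inner dictionaries on the `(datatype_comb, layout_comb)` tuple makes the support check a single dictionary membership test per opcode class.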


@ -0,0 +1,35 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass.op.gemm import Gemm
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
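The `Gemm` class exported here resolves the data type and layout of each operand from up to three sources: a tensor, a per-operand argument (e.g., `element_A`), or a generic argument (e.g., `element`). A minimal sketch of that precedence follows. The `resolve` helper, its parameter names, and the string values are hypothetical and used only for illustration; the real constructor works on CUTLASS data type and layout objects.

```python
def resolve(name, from_tensor=None, specific=None, generic=None):
    """Resolve one property (data type or layout) of operand `name`.

    Precedence: value inferred from a tensor, then the per-operand value,
    then the generic value. Supplying both a tensor and a per-operand
    value is rejected as ambiguous.
    """
    if from_tensor is not None and specific is not None:
        raise ValueError(
            f"Must not specify both a per-operand value and tensor {name}"
        )
    if from_tensor is not None:
        return from_tensor
    if specific is not None:
        return specific
    if generic is not None:
        return generic
    raise ValueError(
        f"Must specify a per-operand value, tensor {name}, or a generic value"
    )


# Only operand A overrides the generic element type.
per_operand = {"A": "f16", "B": None, "C": None, "D": None}
elements = {
    name: resolve(name, specific=value, generic="f32")
    for name, value in per_operand.items()
}
```

In this sketch, operand A resolves to `"f16"` while B, C, and D inherit the generic `"f32"`.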

696
python/cutlass/op/gemm.py Normal file

@ -0,0 +1,696 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running GEMMs.
The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS GEMMs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Gemm(A, B, C, D)
plan.run()
One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:
.. highlight:: python
.. code-block:: python
# The following is shorthand for:
# cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32,
# layout=cutlass.LayoutType.RowMajor)
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
A0 = torch.rand((128, 256), device='cuda')
B0 = torch.rand((256, 64), device='cuda')
C0 = torch.zeros((128, 64), device='cuda')
D0 = torch.zeros((128, 64), device='cuda')
plan.run(A0, B0, C0, D0)
A1 = torch.rand((32, 128), device='cuda')
B1 = torch.rand((128, 256), device='cuda')
C1 = torch.zeros((32, 256), device='cuda')
D1 = torch.zeros((32, 256), device='cuda')
plan.run(A1, B1, C1, D1)
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
kernel from its execution:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.compile()
# Do other work...
plan.run(A0, B0, C0, D0)
# Do other work...
plan.run(A1, B1, C1, D1)
Elementwise activation functions are easily fused to the GEMM via the interface:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
args = plan.run(A0, B0, C0, D0, sync=False)
# Do other work...
args.sync()
"""
import cutlass_bindings
import cutlass
from cutlass import epilogue, swizzle
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.utils import check, datatypes
class Gemm(OperationBase):
"""
Constructs a ``Gemm`` object.
The data types and layouts of operands A, B, and C, along with the data type of output D
and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
these are not to be changed after a ``Gemm`` has been constructed.
The constructor has optional parameters for flexibly setting these parameters. The following
constructors are equivalent:
.. highlight:: python
.. code-block:: python
# Use F32 for A, B, C, D, and accumulation. All operands are row major.
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
# for operands to the same values.
Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
# have the same data type and layout as those passed in here).
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
Gemm(A=A, B=B, C=C, D=D)
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
# the same as that for C, at present)
Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,
element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
:param A: tensor representing data type and layout of operand A
:param B: tensor representing data type and layout of operand B
:param C: tensor representing data type and layout of operand C
:param D: tensor representing data type and layout of operand D
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:param layout: generic layout type to be used for operands A, B, C, and D
:type layout: cutlass.LayoutType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:param layout_A: layout of operand A
:type layout_A: cutlass.LayoutType
:param layout_B: layout of operand B
:type layout_B: cutlass.LayoutType
:param layout_C: layout of operand C
:type layout_C: cutlass.LayoutType
:param layout_D: layout of operand D
:type layout_D: cutlass.LayoutType
"""
def __init__(
self, A=None, B=None, C=None, D=None,
alpha=1.0, beta=0.0, element_accumulator=None,
element=None, layout=None,
element_A=None, element_B=None, element_C=None, element_D=None,
layout_A=None, layout_B=None, layout_C=None,
cc: int = None, kernel_cc: int = None
):
super().__init__(cc=cc, kernel_cc=kernel_cc)
self.name = "gemm"
self.compiled = False
elements = []
layouts = []
# Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
# ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
[layout_A, layout_B, layout_C, layout_C],
[A, B, C, D],
["A", "B", "C", "D"]):
if elt is not None and tens is not None:
raise Exception(f'Must not specify both element_{name} and tensor {name}')
if lay is not None and tens is not None:
raise Exception(f'Must not specify both layout_{name} and tensor {name}')
if elt is None and tens is None and element is None:
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
if lay is None and tens is None and layout is None:
raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
elt_to_set = None
lay_to_set = None
if tens is not None:
elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
else:
elt_to_set = elt if elt is not None else element
lay_to_set = lay if lay is not None else layout
elements.append(datatypes.library_type(elt_to_set))
layouts.append(datatypes.library_layout(lay_to_set))
self._element_a, self._element_b, self._element_c, self._element_d = elements
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
if element_accumulator is None:
self._element_accumulator = self._element_c
else:
self._element_accumulator = datatypes.library_type(element_accumulator)
self.A = A
self.B = B
self.C = C
self.D = D
self.alpha = alpha
self.beta = beta
self.epilogue_functor = None
self.op_class = None
self._reset_operations()
self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1
def _reset_operations(self, reset_epilogue: bool = True):
# Set the default op class
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
layout_comb = (self._layout_a, self._layout_b)
self.possible_op_classes = self.options.supporting_opclasses(
self._element_a, self._element_b, self._element_accumulator,
self._layout_a, self._layout_b)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.TensorOp
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.Simt
else:
raise Exception(f'No kernel configuration found for supported data type and layout '
f'combination {datatype_comb}x{layout_comb}')
if reset_epilogue:
self._reset_epilogue_functor_activation(epilogue.identity)
def _reset_epilogue_functor_activation(self, activation):
if self.epilogue_functor is None:
if self.op_class == cutlass.OpcodeClass.Simt:
elements_per_access = 1
else:
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
else:
elements_per_access = self.epilogue_functor.epilogue_vector_length
if not self.specified_kernel_cc:
if self.current_cc == 90 and activation != epilogue.identity:
# CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity):
# SM80 fallback kernels are currently used. Since an identity activation is requested,
# we can switch back to using SM90 kernels.
self._reset_options(90)
self._reset_operations(reset_epilogue=False)
else:
if self.current_cc == 90 and activation != epilogue.identity:
raise Exception("Epilogues with elementwise fusion are not currently supported "
"in the Python interface for 3.x kernels. To use 2.x kernels "
"with fused elementwise epilogues, do not set the `kernel_cc` "
"parameter when constructing the Gemm object.")
self.epilogue_functor = epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
elements_per_access,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
)
def _reset_epilogue_functor_alignment(self, alignment):
if self.epilogue_functor is None or not hasattr(self.epilogue_functor, 'activation_functor'):
activation = epilogue.identity
else:
activation = type(self.epilogue_functor.activation_functor)
self.epilogue_functor = epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
alignment,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
)
@property
def activation(self):
"""
Returns the type of the current activation function used
"""
return type(self.epilogue_functor.activation_functor)
@activation.setter
def activation(self, act):
"""
Sets the type of the activation function to use
"""
self._reset_epilogue_functor_activation(act)
@property
def opclass(self) -> cutlass.OpcodeClass:
"""
Returns the opcode class currently in use by the GEMM
:return: opcode class currently in use
:rtype: cutlass.OpcodeClass
"""
return self.op_class
@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
"""
Sets the opcode class to use in the GEMM. If the opcode class is not supported under
the given compute capability and element/layout combinations of the GEMM, an exception is raised.
"""
if oc in self.possible_op_classes:
self.op_class = oc
else:
raise Exception(
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
f'layout combination ({self._layout_a}, {self._layout_b}).')
# Changing the op class changes the elements per access in the epilogue. Reset this.
if self.op_class == cutlass.OpcodeClass.Simt:
elements_per_access = 1
else:
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
if self.epilogue_functor is not None:
self._reset_epilogue_functor_alignment(elements_per_access)
# Changing the op class also changes the possible operations available. Reset these.
self.possible_operations = self.options.operations(
self.op_class, self._element_a, self._element_b,
self._element_accumulator, self._layout_a, self._layout_b)
@property
def swizzling_functor(self):
"""
Returns the type of the swizzling functor currently being used by the GEMM
:return: swizzling functor type
"""
return self._swizzling_functor
@swizzling_functor.setter
def swizzling_functor(self, swizzling_functor):
"""
Sets the swizzling functor to the type specified by `swizzling_functor`
"""
if swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
if self.op_class == cutlass.OpcodeClass.Simt:
raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
if self.current_cc == 90:
raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90')
self._swizzling_functor = swizzling_functor
def _valid_tile_description(self, td: TileDescription) -> tuple:
"""
Checks whether the provided tile description is valid for the given compute capability. At present,
this checks the following:
- Does the tile description use a number of stages supported by the compute capability in question?
- Does the tile size requested fit within shared memory?
- Are the requested cluster dimensions within the valid range for the architecture (e.g.,
non-unit cluster dimensions are not valid on pre-SM90 architectures)?
- Is the kernel schedule being used supported on the architecture in question?
:param td: tile description to validate
:type td: cutlass.backend.TileDescription
:return: tuple in which the first element is a bool indicating that the tile description is valid
and the second element is a string providing an optional error message.
:rtype: tuple
"""
# Check stage count based on the CC to which we are compiling (self.cc), rather
# than the CC from which we find kernels (self.current_cc)
valid, msg = check.valid_stage_count(self.cc, td)
if not valid:
return (valid, msg)
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
if not valid:
return (valid, msg)
valid, msg = check.valid_kernel_schedule(self.current_cc, td.kernel_schedule)
return valid, msg
def tile_descriptions(self) -> list:
"""
Returns a list of valid tile descriptions for the operations
:returns: list of valid tile descriptions for the operations
:rtype: list
"""
return [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
def construct(
self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
"""
Constructs a ``cutlass.backend.GemmOperationUniversal`` based on the input parameters and current
kernel specification of the ``Gemm`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:param alignment_C: alignment of operand C
:type alignment_C: int
:return: operation that was constructed
:rtype: cutlass.backend.GemmOperationUniversal
"""
alignment_pref_A = min(128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments))
alignment_pref_B = min(128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments))
alignment_pref_C = min(128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments))
alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C)
self._reset_epilogue_functor_alignment(alignment_C)
tensor_A = TensorDescription(
datatypes.binding_type(self._element_a),
datatypes.binding_layout(self._layout_a),
alignment_A
)
tensor_B = TensorDescription(
datatypes.binding_type(self._element_b),
datatypes.binding_layout(self._layout_b),
alignment_B
)
tensor_C = TensorDescription(
datatypes.binding_type(self._element_c),
datatypes.binding_layout(self._layout_c),
alignment_C
)
if tile_description is None:
op = self.possible_operations.operations(alignment_A)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
valid, err_str = self._valid_tile_description(tile_description)
if not valid:
raise Exception(f"Invalid tile description. {err_str}")
self.tile_description = tile_description
operation = GemmOperationUniversal(
arch=self.current_cc,
tile_description=tile_description,
A=tensor_A, B=tensor_B, C=tensor_C,
epilogue_functor=self.epilogue_functor,
swizzling_functor=self._swizzling_functor,
)
return operation
def compile(self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
print_module: bool = False) -> cutlass.backend.GemmOperationUniversal:
"""
Emits and compiles the kernel currently specified. If ``tile_description`` and any
of the ``alignment`` parameters are set, the kernel will be chosen using this
tile description and alignments. Otherwise, a default tile description and alignment
will be used.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:param alignment_C: alignment of operand C
:type alignment_C: int
:param print_module: whether to print the emitted C++ code
:type print_module: bool
:return: operation that was compiled
:rtype: cutlass.backend.GemmOperationUniversal
"""
self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
if print_module:
print(self.operation.rt_module.emit())
compiler.add_module([self.operation,])
return self.operation
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
"""
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
is raised if it does not.
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
:param ref_type: data type for the tensor that this object was initialized to
:param ref_layout: layout for the tensor that this object was initialized to
:param name: identifier of the tensor to verify. Used in raising exceptions
:type name: str
"""
dtype, layout = datatypes.get_datatype_and_layout(tensor)
if dtype != ref_type or layout != ref_layout:
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
f'does not match the expected type and '
f'layout of ({ref_type}, {ref_layout}).')
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
"""
Verifies the following properties:
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
2) If ``tensor`` is not ``None``, its datatype and layout must match the current versions
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_tensor: numpy/cupy/torch array/tensor object
:param ref_dtype: data type for the tensor that this object was initialized to
:param ref_layout: layout for the tensor that this object was initialized to
:param name: identifier of the tensor to verify. Used in raising exceptions
:type name: str
:return: valid tensor object to use
:rtype: numpy/cupy/torch array/tensor object
"""
if tensor is None:
if ref_tensor is None:
raise Exception(f"Tensor {name} must be set.")
return ref_tensor
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
return tensor
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
"""
Verifies the following properties:
1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)
2) If ``scalar`` is not ``None``, its datatype must match the current version
set by the plan (i.e., those in ``ref_dtype``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
:param scalar: object representing a scalar passed in to verify, or ``None`` if no scalar was passed in
:type scalar: numpy/cupy/torch scalar
:param ref_scalar: object representing a scalar passed in on construction of this object, or ``None`` if no scalar was passed in
:type ref_scalar: numpy/cupy/torch scalar
:param ref_dtype: data type for the scalar that this object was initialized to
:param name: identifier of the scalar to verify. Used in raising exceptions
:type name: str
:return: valid scalar to use
:rtype: numpy/cupy/torch scalar
"""
if scalar is None:
if ref_scalar is None:
raise Exception(f"Scalar {name} must be set.")
return ref_scalar
dtype = datatypes.library_type(scalar.dtype)
if dtype != ref_dtype:
raise Exception(
f"Scalar {name} with type {dtype} does not match expected type {ref_dtype}."
)
return scalar
def run(self, A=None, B=None, C=None, D=None,
alpha=None, beta=None, batch_count: int = 1,
sync: bool = True, print_module: bool = False) -> GemmArguments:
"""
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
parameters provided in this call, or from those
passed in on the construction of this object -- one of the two must be specified.
By default, this call returns only once the kernel has completed. To launch the kernel
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
caller to synchronize the results of the kernel before attempting to access outputs
by calling ``sync()`` on the arguments returned from this call.
:param A: tensor representing data type and layout of operand A
:param B: tensor representing data type and layout of operand B
:param C: tensor representing data type and layout of operand C
:param D: tensor representing data type and layout of operand D
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param batch_count: number of GEMMs in the batch
:type batch_count: int
:param sync: whether the call should wait for the kernel to complete before returning
:type sync: bool
:param print_module: whether to print the emitted C++ code
:type print_module: bool
:return: arguments passed in to the kernel
:rtype: cutlass.backend.GemmArguments
"""
if batch_count < 1:
raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
alignment_C=alignment_c, print_module=print_module)
problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1])
if batch_count == 1:
mode = cutlass_bindings.gemm.Mode.Gemm
kwargs = {'split_k_slices': 1}
else:
mode = cutlass_bindings.gemm.Mode.Batched
kwargs = {'batch': batch_count}
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,
A=A, B=B, C=C, D=D,
output_op=self.operation.epilogue_type(alpha, beta),
gemm_mode=mode,
**kwargs
)
self.operation.run(arguments)
if sync:
arguments.sync()
return arguments
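
The operand resolution and launch-mode selection performed by ``run`` can be sketched in plain Python, without CUTLASS installed. This is an illustrative sketch rather than the library's code: `resolve_operand`, `gemm_problem_size`, and `select_mode` are hypothetical helper names, shapes are plain tuples standing in for tensors, the modes are plain strings standing in for `cutlass_bindings.gemm.Mode` values, and the inner-dimension check is an added assumption that the original `run` does not perform.

```python
def resolve_operand(passed, stored, name):
    """Prefer the operand passed to run(); fall back to the one stored at construction."""
    if passed is not None:
        return passed
    if stored is None:
        raise ValueError(f"Tensor {name} must be set.")
    return stored


def gemm_problem_size(a_shape, b_shape):
    """Derive (M, N, K) from rank-2 shapes, where A is M x K and B is K x N."""
    m, k = a_shape
    k_b, n = b_shape
    if k != k_b:
        # Added check (assumption): the original run() does not compare these.
        raise ValueError(f"Inner dimensions differ: {k} vs {k_b}")
    return (m, n, k)


def select_mode(batch_count):
    """Mirror the mode/kwargs choice between a single GEMM and a batched GEMM."""
    if batch_count < 1:
        raise ValueError(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")
    if batch_count == 1:
        return "Gemm", {"split_k_slices": 1}
    return "Batched", {"batch": batch_count}


a_shape = resolve_operand(None, (128, 64), "A")   # falls back to the stored operand
b_shape = resolve_operand((64, 256), None, "B")   # uses the operand passed in
problem_size = gemm_problem_size(a_shape, b_shape)
mode, kwargs = select_mode(4)
print(problem_size, mode, kwargs)
```

The fallback order matters: an operand passed to the call always wins over one bound at construction, so a plan can be reused across different tensors of the same type and layout.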


@ -0,0 +1,270 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running grouped GEMMs.
The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS grouped GEMMs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# A0, A1, B0, B1, C0, C1, D0, and D1 are torch/numpy/cupy tensor objects
plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""
import cutlass_bindings
from cutlass.backend.gemm_operation import (
GemmGroupedArguments,
GemmOperationGrouped,
)
from cutlass.backend.library import (
DataTypeSize,
SchedulerMode,
TensorDescription,
TileDescription,
)
from cutlass.op.gemm import Gemm
from cutlass.utils import check, datatypes
class GroupedGemm(Gemm):
"""
Constructs a ``GroupedGemm`` object.
The data types and layouts of operands A, B, and C, along with the data type of output D
and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
these are not to be changed after a ``GroupedGemm`` has been constructed.
The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
for ``Gemm`` for examples of these.
:param cc: compute capability of device to generate kernels for
:type cc: int
:param A: tensor representing data type and layout of operands A
:param B: tensor representing data type and layout of operands B
:param C: tensor representing data type and layout of operands C
:param D: tensor representing data type and layout of operands D
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:param layout: generic layout type to be used for operands A, B, C, and D
:type layout: cutlass.LayoutType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:param layout_A: layout of operand A
:type layout_A: cutlass.LayoutType
:param layout_B: layout of operand B
:type layout_B: cutlass.LayoutType
:param layout_C: layout of operand C
:type layout_C: cutlass.LayoutType
:param layout_D: layout of operand D
:type layout_D: cutlass.LayoutType
"""
def __init__(
self, A=None, B=None, C=None, D=None,
alpha=1.0, beta=0.0, element_accumulator=None,
element=None, layout=None,
element_A=None, element_B=None, element_C=None, element_D=None,
layout_A=None, layout_B=None, layout_C=None,
cc: int = None,
):
super().__init__(
A=A, B=B, C=C, D=D,
alpha=alpha, beta=beta,
element_accumulator=element_accumulator,
element=element, layout=layout,
element_A=element_A, element_B=element_B,
element_C=element_C, element_D=element_D,
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
cc=cc
)
# Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
if self.current_cc == 90:
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
self.name = "grouped_gemm"
@Gemm.swizzling_functor.setter
def swizzling_functor(self, swizzling_functor):
"""
Sets the swizzling functor to the type specified by `swizzling_functor`
"""
raise Exception('Grouped GEMM does not currently support different swizzling functors')
def construct(self, tile_description: TileDescription = None,
alignment_A: int = None,
alignment_B: int = None,
alignment_C: int = None) -> GemmOperationGrouped:
"""
Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current
kernel specification of the ``GroupedGemm`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
:type alignment_B: int
:param alignment_C: alignment of operand C
:type alignment_C: int
:return: operation that was constructed
:rtype: cutlass.backend.GemmOperationGrouped
"""
alignment_preference = max(self.possible_operations.alignments)
alignment_A = check.alignment_or_default(alignment_A, alignment_preference)
alignment_B = check.alignment_or_default(alignment_B, alignment_preference)
alignment_C = check.alignment_or_default(alignment_C, alignment_preference)
self._reset_epilogue_functor_alignment(alignment_C)
tensor_A = TensorDescription(
datatypes.binding_type(self._element_a),
datatypes.binding_layout(self._layout_a),
alignment_A
)
tensor_B = TensorDescription(
datatypes.binding_type(self._element_b),
datatypes.binding_layout(self._layout_b),
alignment_B
)
tensor_C = TensorDescription(
datatypes.binding_type(self._element_c),
datatypes.binding_layout(self._layout_c),
alignment_C
)
if tile_description is None:
op = self.possible_operations.operations(alignment_A)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
valid, err_str = self._valid_tile_description(tile_description)
if not valid:
raise Exception(f"Invalid tile description. {err_str}")
self.tile_description = tile_description
operation = GemmOperationGrouped(
arch=self.current_cc,
tile_description=tile_description,
A=tensor_A, B=tensor_B, C=tensor_C,
epilogue_functor=self.epilogue_functor,
swizzling_functor=self._swizzling_functor,
precompute_mode=SchedulerMode.Device)
return operation
def run(self, A, B, C, D,
alpha=None, beta=None, sync: bool = True,
print_module: bool = False) -> GemmGroupedArguments:
"""
Runs the kernel currently specified.
By default, this call returns only once the kernel has completed. To launch the kernel
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
caller to synchronize the results of the kernel before attempting to access outputs
by calling ``sync()`` on the arguments returned from this call.
:param A: list of tensors representing data type and layout of operand A
:type A: list
:param B: list of tensors representing data type and layout of operand B
:type B: list
:param C: list of tensors representing data type and layout of operand C
:type C: list
:param D: list of tensors representing data type and layout of operand D
:type D: list
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param sync: whether the call should wait for the kernel to complete before returning
:type sync: bool
:param print_module: whether to print the emitted C++ code
:type print_module: bool
:return: arguments passed in to the kernel
:rtype: cutlass.backend.GemmGroupedArguments
"""
if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
raise Exception("Lengths of A, B, C, and D lists must be equal")
problem_sizes = []
As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
for i in range(len(A)):
As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
problem_sizes.append(cutlass_bindings.gemm.GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a) for A in As))
alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b) for B in Bs))
alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c) for C in Cs))
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
alignment_C=alignment_c, print_module=print_module)
arguments = GemmGroupedArguments(
operation=self.operation,
problem_sizes=problem_sizes,
A=As, B=Bs, C=Cs, D=Ds,
output_op=self.operation.epilogue_type(alpha, beta)
)
self.operation.run(arguments)
if sync:
arguments.sync()
return arguments
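
The grouped ``run`` above compiles one kernel for the whole group, so it takes the minimum alignment found across all problems and records one (M, N, K) size per problem. A minimal sketch of that reduction follows. The internals of `find_alignment` are not shown in this diff, so `largest_alignment` is a hypothetical stand-in that assumes the alignment is the largest candidate dividing the contiguous dimension (the last dimension for row-major, the first for column-major); the candidate list is likewise an assumption.

```python
def largest_alignment(shape, row_major=True, candidates=(8, 4, 2, 1)):
    """Hypothetical stand-in for find_alignment: largest candidate dividing the contiguous extent."""
    contiguous = shape[-1] if row_major else shape[0]
    for alignment in sorted(candidates, reverse=True):
        if contiguous % alignment == 0:
            return alignment
    raise ValueError(f"No candidate alignment divides extent {contiguous}")


def group_alignment(shapes, row_major=True):
    """One kernel serves every problem, so use the smallest alignment in the group."""
    return min(largest_alignment(shape, row_major) for shape in shapes)


def group_problem_sizes(a_shapes, b_shapes):
    """Build one (M, N, K) tuple per problem from the A (M x K) and B (K x N) shapes."""
    if len(a_shapes) != len(b_shapes):
        raise ValueError("Lengths of A and B lists must be equal")
    return [(a[0], b[1], a[1]) for a, b in zip(a_shapes, b_shapes)]


a_shapes = [(128, 64), (100, 36)]
b_shapes = [(64, 256), (36, 50)]
sizes = group_problem_sizes(a_shapes, b_shapes)
align_a = group_alignment(a_shapes)
align_b = group_alignment(b_shapes)
print(sizes, align_a, align_b)
```

Taking the minimum is the conservative choice: a kernel compiled for a smaller alignment remains valid for better-aligned problems, while the reverse does not hold.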

116
python/cutlass/op/op.py Normal file

@ -0,0 +1,116 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
from bisect import bisect_left
from cutlass import option_registry
from cutlass.backend.utils.device import device_cc
from cutlass.epilogue import get_activations
from cutlass.library_defaults import _generator_ccs
from cutlass.swizzle import get_swizzling_functors
class OperationBase:
"""
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
def __init__(self, cc: int = None, kernel_cc: int = None):
"""
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
"""
self.cc = cc if cc is not None else device_cc()
self.specified_kernel_cc = kernel_cc is not None
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
self.tile_description = None
self.options = option_registry.options_for_cc(self.current_cc)
if self.options is None:
raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
def _find_closest_cc(self, cc: int) -> int:
"""
Returns the closest CC in _generator_ccs less than or equal to `cc`
:param cc: compute capability to query
:type cc: int
:returns: closest CC in _generator_ccs less than or equal to `cc`
:rtype: int
"""
if cc in _generator_ccs:
return cc
# Find closest CC lower than this CC
idx = bisect_left(_generator_ccs, cc)
if idx == 0:
raise Exception(f'No valid CC to fall back to for {cc}')
return _generator_ccs[idx-1]
def activations(self) -> list:
"""
Returns possible activation functions that can be used
:return: list of activation functions that can be used
:rtype: list
"""
return get_activations()
def swizzling_functors(self) -> list:
"""
Returns possible swizzling functions that can be used
:return: list of swizzling functions that can be used
:rtype: list
"""
return get_swizzling_functors()
def _reset_options(self, cc: int):
"""
Resets the kernel options based on cc
:param cc: compute capability to reset to
:type cc: int
"""
if cc != self.current_cc:
if cc not in _generator_ccs:
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
self.current_cc = cc
self.options = option_registry.options_for_cc(self.current_cc)
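
The fallback in ``_find_closest_cc`` relies on `bisect_left` over a list sorted in ascending order. The sketch below reproduces that lookup as a free function. The contents of `_generator_ccs` are not shown in this diff, so `GENERATOR_CCS` is an assumed example list, and `find_closest_cc` is a standalone rewrite for illustration only.

```python
from bisect import bisect_left

# Assumed example values; the real list lives in cutlass.library_defaults.
# The list must be sorted in ascending order for bisect_left to be meaningful.
GENERATOR_CCS = [50, 60, 61, 70, 75, 80, 90]


def find_closest_cc(cc, generator_ccs=GENERATOR_CCS):
    """Return the largest entry in generator_ccs that is less than or equal to cc."""
    if cc in generator_ccs:
        return cc
    # bisect_left gives the index at which cc would be inserted;
    # the entry just before that index is the closest lower CC.
    idx = bisect_left(generator_ccs, cc)
    if idx == 0:
        raise ValueError(f"No valid CC to fall back to for {cc}")
    return generator_ccs[idx - 1]


print(find_closest_cc(86))
print(find_closest_cc(90))
```

With the assumed list, a device reporting 86 falls back to 80, while 90 is returned unchanged because it is itself a listed CC.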

66
python/cutlass/swizzle.py Normal file

@ -0,0 +1,66 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Registry of swizzling functions
"""
import cutlass_bindings
IdentitySwizzle1 = cutlass_bindings.IdentitySwizzle1
IdentitySwizzle2 = cutlass_bindings.IdentitySwizzle2
IdentitySwizzle4 = cutlass_bindings.IdentitySwizzle4
IdentitySwizzle8 = cutlass_bindings.IdentitySwizzle8
HorizontalSwizzle = cutlass_bindings.HorizontalSwizzle
BatchedIdentitySwizzle = cutlass_bindings.BatchedIdentitySwizzle
ThreadblockSwizzleStreamK = cutlass_bindings.ThreadblockSwizzleStreamK
StridedDgradIdentitySwizzle1 = cutlass_bindings.StridedDgradIdentitySwizzle1
StridedDgradIdentitySwizzle4 = cutlass_bindings.StridedDgradIdentitySwizzle4
StridedDgradHorizontalSwizzle = cutlass_bindings.StridedDgradHorizontalSwizzle
_swizzling_functors = [
IdentitySwizzle1,
IdentitySwizzle2,
IdentitySwizzle4,
IdentitySwizzle8,
HorizontalSwizzle,
BatchedIdentitySwizzle,
ThreadblockSwizzleStreamK,
StridedDgradIdentitySwizzle1,
StridedDgradIdentitySwizzle4,
StridedDgradHorizontalSwizzle,
]
def get_swizzling_functors():
return _swizzling_functors


@ -0,0 +1,40 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass.utils.check import (
alignment_or_default,
calculate_smem_usage,
calculate_smem_usage_per_stage,
valid_cluster_shape,
valid_kernel_schedule,
valid_stage_count,
)


@ -0,0 +1,192 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for checking constraints on kernels and calculating kernel attributes
"""
import ctypes
import cutlass_bindings
import cutlass
from cutlass.backend.library import DataTypeSize, TileDescription
def calculate_smem_usage_per_stage(tile_description, operation_kind):
"""
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
:return: number of bytes of shared memory consumed by a single stage
:rtype: int
"""
m, n, k = tile_description.threadblock_shape
if operation_kind == cutlass.OperationKind.Gemm:
stage_barrier_bytes = 32
return (
(DataTypeSize[tile_description.math_instruction.element_a] * m * k // 8)
+ (DataTypeSize[tile_description.math_instruction.element_b] * k * n // 8)
+ stage_barrier_bytes
)
else:
raise Exception(f"No available shared memory calculation for operation kind {operation_kind}")
def calculate_smem_usage(operation):
"""
Returns the amount of shared memory in bytes consumed by a kernel.
:return: number of bytes of shared memory consumed by the operation
:rtype: int
"""
_per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind)
return _per_stage * operation.tile_description.stages
def valid_stage_count(cc: int, td: TileDescription) -> tuple:
"""
Checks whether a device with `cc` supports the number of stages within `td`, both
based on raw limits on the number of stages and based on shared memory capacity
:param cc: compute capability of device in question
:type cc: int
:param td: tile description to check
:type td: TileDescription
:return: tuple with the first element indicating whether the provided tile description is
valid for the provided device and the second element being an error message
:rtype: tuple
"""
if cc == 90 and (td.stages is None or td.stages == 0):
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
# determines the stage count to use. Thus, all settings are valid in these scenarios.
return (True, "")
if td.stages is None or td.stages <= 0:
return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
if cc < 80 and td.stages != 2:
return (False, f"Tile description has stage count of {td.stages}, "
f"but only 2 stages are supported on SM{cc}.")
smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
smem_arch = cutlass.SharedMemPerCC[cc] << 10
if (smem_per_stage * td.stages) > smem_arch:
return (False,
"Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
f"Details: configuration uses {smem_per_stage} bytes of shared memory per stage, and "
f"{td.stages} stages for a total of {smem_per_stage * td.stages} bytes.\n"
f"The maximum amount of shared memory that can be used per block on CC {cc} is {smem_arch} bytes.")
return (True, "")
def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
"""
Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.
:param cc: compute capability of device in question
:type cc: int
:param cluster_shape: dimensions of thread block cluster shape to check
:type cluster_shape: list
:return: tuple with the first element indicating whether the provided cluster shape is
valid for the provided device and the second element being an error message
:rtype: tuple
"""
if cc < 90:
if cluster_shape != [1, 1, 1]:
return (False,
f"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of "
f"{cluster_shape} for SM{cc}.")
else:
return (True, "")
if len(cluster_shape) != 3:
return (False,
f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)})")
if cluster_shape[2] != 1:
return (False,
"CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
f"Received cluster shape of {cluster_shape}.")
# The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster
# as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters).
# Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions,
# so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total.
blocks_in_2d = cluster_shape[0] * cluster_shape[1]
if blocks_in_2d > 8:
return (False,
f"Thread block clusters with more than 8 thread blocks are currently unsupported on SM{cc}. "
f"Received cluster shape {cluster_shape}, which has {blocks_in_2d} thread blocks.")
return (True, "")
def valid_kernel_schedule(cc: int, kernel_schedule: cutlass.KernelScheduleType) -> tuple:
"""
Checks whether a device with ``cc`` supports ``kernel_schedule``.
:param cc: compute capability of device in question
:type cc: int
:param kernel_schedule: kernel schedule type
:type kernel_schedule: cutlass.KernelScheduleType
:return: tuple with the first element indicating whether the provided kernel schedule is
valid for the provided device and the second element being an error message
:rtype: tuple
"""
if kernel_schedule != cutlass.KernelScheduleType.ScheduleAuto and cc < 90:
return (False, "Non-default kernel schedules are only supported on SM90 and beyond")
return (True, "")
def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
"""
Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
that `alignment_provided` does not exceed `default_alignment`.
:param alignment_provided: alignment preference specified. Can be None.
:type alignment_provided: int
:param default_alignment: alignment to use if `alignment_provided` is None
:type default_alignment: int
:return: alignment to use
:rtype: int
"""
if alignment_provided is not None:
if alignment_provided > default_alignment:
raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
return alignment_provided
return default_alignment
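
The shared memory arithmetic in ``calculate_smem_usage_per_stage`` and ``valid_stage_count`` can be worked through by hand. The sketch below restates it with plain integers. `smem_bytes_per_stage` and `max_stages` are hypothetical names, element sizes are given in bits as in `DataTypeSize`, and the 163 KiB capacity is an assumed example value passed in explicitly rather than read from `cutlass.SharedMemPerCC`.

```python
def smem_bytes_per_stage(m, n, k, bits_a, bits_b, stage_barrier_bytes=32):
    """Bytes for one stage: an M x K tile of A, a K x N tile of B, plus barrier storage."""
    bytes_a = bits_a * m * k // 8
    bytes_b = bits_b * k * n // 8
    return bytes_a + bytes_b + stage_barrier_bytes


def max_stages(m, n, k, bits_a, bits_b, smem_capacity_kib):
    """Largest stage count whose total shared memory fits within the capacity."""
    capacity_bytes = smem_capacity_kib << 10  # KiB to bytes, as in valid_stage_count
    return capacity_bytes // smem_bytes_per_stage(m, n, k, bits_a, bits_b)


# A 128x128x32 threadblock tile with 16-bit A and B operands.
per_stage = smem_bytes_per_stage(128, 128, 32, bits_a=16, bits_b=16)
total_three_stages = per_stage * 3
# 163 KiB is an assumed capacity used only for this example.
limit = max_stages(128, 128, 32, bits_a=16, bits_b=16, smem_capacity_kib=163)
print(per_stage, total_three_stages, limit)
```

For this tile, each stage takes 8192 bytes for A, 8192 bytes for B, and 32 bytes for the barrier, or 16416 bytes in total. Three stages therefore use 49248 bytes, and at most 10 stages fit under the assumed 166912-byte capacity.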


@ -0,0 +1,339 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""

import cutlass_bindings

import cutlass
from cutlass.backend.library import (
    DataTypeSize,
    MathInstruction,
    MathOperation,
    ShortLayoutTypeNames,
    TileDescription,
)

try:
    import numpy as np

    numpy_available = True
    _library_to_numpy_dict = {
        cutlass.DataType.f16: np.float16,
        cutlass.DataType.f32: np.float32,
        cutlass.DataType.f64: np.float64,
        cutlass.DataType.s8: np.int8,
        cutlass.DataType.s32: np.int32,
    }
except ImportError:
    numpy_available = False
    _library_to_numpy_dict = {}


def numpy_library_type(inp) -> cutlass.DataType:
    if numpy_available:
        if inp == np.float16:
            return cutlass.DataType.f16
        elif inp == np.float32:
            return cutlass.DataType.f32
        elif inp == np.float64:
            return cutlass.DataType.f64
        elif inp == np.int8:
            return cutlass.DataType.s8
        elif inp == np.int32:
            return cutlass.DataType.s32
    return None


def numpy_type(inp):
    return _library_to_numpy_dict.get(inp, None)


try:
    import cupy as cp

    cupy_available = True
    _library_to_cupy_dict = {
        cutlass.DataType.f16: cp.float16,
        cutlass.DataType.f32: cp.float32,
        cutlass.DataType.f64: cp.float64,
        cutlass.DataType.s8: cp.int8,
        cutlass.DataType.s32: cp.int32,
    }
except ImportError:
    cupy_available = False
    _library_to_cupy_dict = {}


def cupy_library_type(inp) -> cutlass.DataType:
    if cupy_available:
        if inp == cp.float16:
            return cutlass.DataType.f16
        elif inp == cp.float32:
            return cutlass.DataType.f32
        elif inp == cp.float64:
            return cutlass.DataType.f64
    return None


def cupy_type(inp):
    return _library_to_cupy_dict.get(inp, None)


try:
    import torch

    torch_available = True
    _torch_to_library_dict = {
        torch.half: cutlass.DataType.f16,
        torch.float16: cutlass.DataType.f16,
        torch.float: cutlass.DataType.f32,
        torch.float32: cutlass.DataType.f32,
        torch.double: cutlass.DataType.f64,
        torch.float64: cutlass.DataType.f64,
    }

    # One entry per library type: repeating a key in a dict literal keeps only the
    # last value, so the canonical torch dtype names are listed here.
    _library_to_torch_dict = {
        cutlass.DataType.f16: torch.float16,
        cutlass.DataType.f32: torch.float32,
        cutlass.DataType.f64: torch.float64,
    }
except ImportError:
    torch_available = False
    _torch_to_library_dict = {}
    _library_to_torch_dict = {}


def torch_library_type(inp) -> cutlass.DataType:
    return _torch_to_library_dict.get(inp, None)


def torch_type(inp):
    return _library_to_torch_dict.get(inp, None)


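try:
    import bfloat16

    bfloat16_available = True
except ImportError:
    bfloat16_available = False


def bfloat16_library_type(inp) -> cutlass.DataType:
    if bfloat16_available:
        if inp == bfloat16.bfloat16:
            return cutlass.DataType.bf16
    return None


def bfloat16_type(inp):
    # No return annotation here: an annotation of ``bfloat16.bfloat16`` is evaluated when the
    # function is defined and raises a NameError if the optional package failed to import.
    if bfloat16_available:
        if inp == cutlass.DataType.bf16:
            return bfloat16.bfloat16
    return None

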
# Mapping from library data type to Python-bound CUTLASS data type
library_to_binding_dict = {
    cutlass.DataType.s8: cutlass_bindings.int8,
    cutlass.DataType.s32: cutlass_bindings.int32,
    cutlass.DataType.f16: cutlass_bindings.float16,
    cutlass.DataType.bf16: cutlass_bindings.bfloat16,
    cutlass.DataType.f32: cutlass_bindings.float32,
    cutlass.DataType.f64: cutlass_bindings.float64,
    cutlass.DataType.tf32: cutlass_bindings.tfloat32,
}

# Mapping from Python-bound CUTLASS data type to library data type
binding_to_library = {
    cutlass_bindings.int8: cutlass.DataType.s8,
    cutlass_bindings.int32: cutlass.DataType.s32,
    cutlass_bindings.float16: cutlass.DataType.f16,
    cutlass_bindings.bfloat16: cutlass.DataType.bf16,
    cutlass_bindings.float32: cutlass.DataType.f32,
    cutlass_bindings.float64: cutlass.DataType.f64,
    cutlass_bindings.tfloat32: cutlass.DataType.tf32,
}


def binding_library_type(inp):
    if inp in binding_to_library:
        return binding_to_library[inp]
    return None


def has_binding_type(inp: cutlass.DataType):
    return inp in library_to_binding_dict


def library_to_binding(inp: cutlass.DataType):
    if not has_binding_type(inp):
        raise Exception(f"No available conversion from library type {inp} to Python-bound CUTLASS type")
    return library_to_binding_dict[inp]


def library_type(inp):
    if inp in cutlass.DataTypeSize.keys():
        return inp

    # Try each frontend-specific converter in turn and use the first that recognizes the type
    for cvt_fn in [
        bfloat16_library_type,
        cupy_library_type,
        numpy_library_type,
        torch_library_type,
        binding_library_type,
    ]:
        out = cvt_fn(inp)
        if out is not None:
            return out

    raise Exception(f"No available conversion from type {inp} to a library type.")


def library_layout(layout):
    if layout in cutlass.LayoutTag.keys():
        return layout

    # Convert Python-bound CUTLASS layout to profiler library layout
    if layout == cutlass_bindings.RowMajor:
        return cutlass.LayoutType.RowMajor
    elif layout == cutlass_bindings.ColumnMajor:
        return cutlass.LayoutType.ColumnMajor
    else:
        raise Exception(f"No conversion available for layout {layout} to library layout.")


def binding_type(inp):
    if inp in DataTypeSize.keys():
        return inp

    libtype = library_type(inp)
    return library_to_binding(libtype)


def binding_layout(layout):
    if layout in ShortLayoutTypeNames.keys():
        return layout
    elif layout == cutlass.LayoutType.RowMajor:
        return cutlass_bindings.RowMajor
    elif layout == cutlass.LayoutType.ColumnMajor:
        return cutlass_bindings.ColumnMajor
    else:
        raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.")


def _tensor_from_numpy(np_tensor):
    dtype = library_type(np_tensor.dtype)
    if np_tensor.flags.c_contiguous:
        layout = cutlass.LayoutType.RowMajor
    elif np_tensor.flags.f_contiguous:
        layout = cutlass.LayoutType.ColumnMajor
    else:
        # Without this branch, ``layout`` would be unbound for non-contiguous tensors
        raise Exception("Unable to infer layout: tensor is neither C-contiguous nor F-contiguous.")
    return (dtype, layout)


def _tensor_from_torch(pt_tensor):
    dtype = library_type(pt_tensor.dtype)
    return (dtype, cutlass.LayoutType.RowMajor)


def get_datatype_and_layout(tensor):
    if (numpy_available and isinstance(tensor, np.ndarray)) or (
        cupy_available and isinstance(tensor, cp.ndarray)
    ):
        return _tensor_from_numpy(tensor)
    elif torch_available and isinstance(tensor, torch.Tensor):
        return _tensor_from_torch(tensor)
    else:
        raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")


def binding_opclass(opclass: cutlass.OpcodeClass):
    if opclass == cutlass.OpcodeClass.TensorOp:
        return cutlass_bindings.OpClass.TensorOp
    elif opclass == cutlass.OpcodeClass.Simt:
        return cutlass_bindings.OpClass.Simt
    else:
        raise Exception(f"Unable to convert opcode class of type {opclass} to Python-bound CUTLASS opcode class.")


# Backend math operations keyed by enum value, used to translate between the two enums
_math_operation_value_map = {x.value: x for x in MathOperation}


def backend_math_operation(math_op: cutlass.MathOperation):
    if math_op.value not in _math_operation_value_map:
        raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
    return _math_operation_value_map[math_op.value]


def construct_backend_td(td: cutlass.TileDescription,
                         kernel_schedule: cutlass.KernelScheduleType) -> TileDescription:
    mi = td.math_instruction
    backend_mi = MathInstruction(
        mi.instruction_shape,
        binding_type(mi.element_a),
        binding_type(mi.element_b),
        binding_type(mi.element_accumulator),
        binding_opclass(mi.opcode_class),
        backend_math_operation(mi.math_operation)
    )
    return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
                           backend_mi, td.cluster_shape, kernel_schedule)


def td_from_profiler_op(op) -> TileDescription:
    """
    Converts the profiler's TileDescription in ``op`` into the backend TileDescription

    :param op: profiler Operation

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    # Operations without a kernel schedule attribute fall back to None
    schedule = getattr(op, 'kernel_schedule', None)
    return construct_backend_td(op.tile_description, schedule)


def td_from_profiler_td(td: cutlass.TileDescription) -> TileDescription:
    """
    Converts the profiler's TileDescription into the backend TileDescription

    :param td: profiler TileDescription
    :type td: cutlass.TileDescription

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    return construct_backend_td(td, kernel_schedule=None)

@ -0,0 +1,40 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
FROM nvcr.io/nvidia/pytorch:22.11-py3
RUN chmod ugo+rwx /home
RUN pip uninstall -y rmm
RUN pip install rmm-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH
ENV CUDA_INSTALL_PATH=/usr/local/cuda

@ -0,0 +1,38 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
FROM nvcr.io/nvidia/pytorch:23.01-py3
RUN chmod ugo+rwx /home
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH
ENV CUDA_INSTALL_PATH=/usr/local/cuda

20
python/docs_src/Makefile Normal file

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
python/docs_src/make.bat Normal file

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

Binary file not shown. Size: 1.5 KiB

Binary file not shown. Size: 49 KiB

Binary file not shown. Size: 48 KiB

@ -0,0 +1,94 @@
{% extends "!layout.html" %}
{% block sidebartitle %} {{ super() }}
<style>
/* Sidebar header (and topbar for mobile) */
.wy-side-nav-search, .wy-nav-top {
background: #76b900;
}
.wy-menu > p > span.caption-text {
color: #76b900;
}
.wy-menu-vertical p {
height: 32px;
line-height: 32px;
padding: 0 1.618em;
margin: 12px 0 0;
display: block;
font-weight: 700;
text-transform: uppercase;
font-size: 85%;
white-space: nowrap;
}
.wy-side-nav-search a:link, .wy-nav-top a:link {
color: #fff;
}
.wy-side-nav-search a:visited, .wy-nav-top a:visited {
color: #fff;
}
.wy-side-nav-search a:hover, .wy-nav-top a:hover {
color: #fff;
}
.wy-menu-vertical a:link, .wy-menu-vertical a:visited {
color: #d9d9d9
}
.wy-menu-vertical a:active {
background-color: #76b900
}
.wy-side-nav-search>div.version {
color: rgba(0, 0, 0, 0.3)
}
.wy-nav-content {
max-width: 1000px;
}
/* override table width restrictions */
.wy-table-responsive table td, .wy-table-responsive table th {
/* !important prevents the common CSS stylesheets from
overriding this as on RTD they are loaded after this stylesheet */
white-space: normal !important;
}
.wy-table-responsive {
overflow: visible !important;
}
</style>
{% endblock %}
{% block footer %} {{ super() }}
<style>
a:link, a:visited {
color: #76b900;
}
a:hover {
color: #8c0;
}
html.writer-html4 .rst-content dl:not(.docutils)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple)>dt {
background: rgba(118, 185, 0, 0.1);
color: rgba(59,93,0,1);
border-top: solid 3px rgba(59,93,0,1);
}
html.writer-html4 .rst-content dl:not(.docutils) .property, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .property {
text-transform: capitalize;
display: inline-block;
padding-right: 8px;
}
</style>
{%- if nvidia_analytics_id %}
<script type="text/javascript">_satellite.pageBottom();</script>
{%- endif %}
{% endblock %}

@ -0,0 +1,100 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath('../../media/docs'))
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = 'CUTLASS Python interface'
copyright = '2023, NVIDIA'
author = 'NVIDIA'
release = '3.1.0'
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'myst_parser',
'nbsphinx',
'nbsphinx_link',
'sphinx_copybutton',
'sphinx.ext.autodoc',
'sphinx.ext.autosectionlabel',
'sphinx.ext.autosummary',
'sphinx.ext.coverage',
'sphinx.ext.extlinks',
'sphinx.ext.ifconfig',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx_inline_tabs',
]
source_suffix = {
'.rst': 'restructuredtext',
'.md': 'markdown',
}
autodoc_typehints = 'description'
pygments_style = "sphinx"
pygments_dark_style = "monokai"
templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# Ignore errors when converting notebooks
nbsphinx_allow_errors = True
language = 'en'
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_static_path = ['_static']
html_title = "CUTLASS Python"
html_baseurl = 'docs'
html_theme = 'furo'
html_theme_options = {
"light_logo": "cutlass-logo-small.png",
"dark_logo": "cutlass-logo-small.png",
"light_css_variables": {
"color-brand-primary": "#76B900",
"color-brand-content": "#76B900",
},
"dark_css_variables": {
"color-brand-primary": "#76B900",
"color-brand-content": "#76B900",
},
"footer_icons": [
{
"name": "GitHub",
"url": "https://github.com/NVIDIA/cutlass",
"html": """
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
""",
"class": "",
},
],
}

@ -0,0 +1,9 @@
# Contributing

Thank you for your interest in contributing to the CUTLASS Python interface. Contributions fall into one of two categories:

1. You want to report a bug, feature request, or documentation issue
    - File an [issue](https://github.com/NVIDIA/cutlass/issues/new/choose) describing what you encountered or what you want to see changed.
    - The CUTLASS team will evaluate and triage issues, scheduling them for a release. If you believe the issue needs priority attention, comment on the issue to notify the team.
2. You want to implement a feature or bug-fix
    - We welcome contributions from the community. We recommend that you contribute via a [pull request](https://github.com/NVIDIA/cutlass/pulls). If you have questions about CUTLASS, consider asking a question via the [Discussions](https://github.com/NVIDIA/cutlass/discussions) tab. Please be sure to search through both existing issues and discussions to see whether your question has already been answered.

@ -0,0 +1,18 @@
Emitters
========

Common
------

.. automodule:: cutlass.emit.common
   :members:
   :undoc-members:
   :show-inheritance:

PyTorch
-------

.. automodule:: cutlass.emit.pytorch
   :members:
   :undoc-members:
   :show-inheritance:

@ -0,0 +1,26 @@
Operations
==========

GEMM
----

.. automodule:: cutlass.op.gemm
   :members:
   :undoc-members:
   :show-inheritance:

Grouped GEMM
------------

.. automodule:: cutlass.op.gemm_grouped
   :members:
   :undoc-members:
   :show-inheritance:

Operation
---------

.. automodule:: cutlass.op.op
   :members:
   :undoc-members:
   :show-inheritance:

@ -0,0 +1,36 @@
CUTLASS
=======

Subpackages
-----------

.. toctree::
   :maxdepth: 1

   cutlass.emit
   cutlass.op
   cutlass.utils

Epilogue
--------

.. automodule:: cutlass.epilogue
   :members:
   :undoc-members:
   :show-inheritance:

Library Defaults
----------------

.. automodule:: cutlass.library_defaults
   :members:
   :undoc-members:
   :show-inheritance:

Swizzle
----------

.. automodule:: cutlass.swizzle
   :members:
   :undoc-members:
   :show-inheritance:

@ -0,0 +1,18 @@
Utilities
=========

Checks
------

.. automodule:: cutlass.utils.check
   :members:
   :undoc-members:
   :show-inheritance:

Data Types
----------

.. automodule:: cutlass.utils.datatypes
   :members:
   :undoc-members:
   :show-inheritance:

@ -0,0 +1,9 @@
Examples
==================

.. toctree::
   :maxdepth: 5

   Basic GEMM <externals/00_basic_gemm.nblink>
   Epilogue <externals/01_epilogue.nblink>
   PyTorch Extension <externals/02_pytorch_extension_grouped_gemm.nblink>

@ -0,0 +1,3 @@
{
"path": "./../../../../examples/python/00_basic_gemm.ipynb"
}

@ -0,0 +1,3 @@
{
"path": "./../../../../examples/python/01_epilogue.ipynb"
}

@ -0,0 +1,3 @@
{
"path": "./../../../../examples/python/02_pytorch_extension_grouped_gemm.ipynb"
}

@ -0,0 +1,55 @@
.. CUTLASS Python interface documentation master file, created by
   sphinx-quickstart on Mon Feb 13 17:57:39 2023.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

.. include:: ../../README.md
   :start-line: 1
   :parser: markdown

.. toctree::
   :hidden:

   Home <self>

.. toctree::
   :hidden:
   :caption: Getting Started:

   install.md
   Getting Started <externals/00_basic_gemm.nblink>
   contribute.md

.. toctree::
   :hidden:
   :caption: Python Documentation:

   modules.rst

.. toctree::
   :hidden:
   :caption: Examples and Tutorials:

   examples.rst

.. toctree::
   :hidden:
   :caption: Advanced:

.. toctree::
   :hidden:
   :caption: FAQ:

.. toctree::
   :hidden:
   :caption: Reference:

   GitHub <https://github.com/NVIDIA/cutlass>

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

@ -0,0 +1,37 @@
# Installation

## Installing from source

Installing from source requires the latest CUDA Toolkit release whose major.minor version matches that of the installed CUDA Python (`cuda-python`) package.

Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables:
* `CUTLASS_PATH`: the path to the cloned CUTLASS repository
* `CUDA_INSTALL_PATH`: the path to the installation of CUDA

If these environment variables are not set, the installation process will infer them to be the following:
* `CUTLASS_PATH`: one directory level above the current directory (i.e., `$(pwd)/..`)
* `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`)
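
To preview the values that would be used on a given machine, the inference rules above can be sketched in a few lines of Python. This is only an illustration, not the installer's own code: the helper names `infer_cutlass_path` and `infer_cuda_install_path` are hypothetical, and the sketch treats an empty environment variable as unset.

```python
import os
import shutil
from typing import Optional


def infer_cutlass_path(current_dir: str) -> str:
    """Default CUTLASS_PATH: one directory level above ``current_dir``."""
    return os.path.normpath(os.path.join(current_dir, ".."))


def infer_cuda_install_path(nvcc_path: Optional[str]) -> Optional[str]:
    """Default CUDA_INSTALL_PATH: whatever precedes ``/bin/nvcc`` in the path to nvcc."""
    if nvcc_path is None:
        return None  # nvcc was not found on PATH
    return nvcc_path.split("/bin/nvcc")[0]


if __name__ == "__main__":
    # An explicitly set environment variable takes precedence over the inferred default.
    # Assumption of this sketch: an empty string counts as "not set".
    cutlass_path = os.getenv("CUTLASS_PATH") or infer_cutlass_path(os.getcwd())
    cuda_path = os.getenv("CUDA_INSTALL_PATH") or infer_cuda_install_path(shutil.which("nvcc"))
    print(f"CUTLASS_PATH={cutlass_path}")
    print(f"CUDA_INSTALL_PATH={cuda_path}")
```

Run from the `python` directory of the repository, the script prints the two paths without modifying the environment.
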
**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`.
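
One way to check this before installing is to compare the major.minor version of `cuda-python` against the release reported by `nvcc --version`. The sketch below is a hedged illustration: the helper names are hypothetical, and it assumes the `nvcc` output contains a field of the form `release X.Y`.

```python
import re
from typing import Optional, Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '12.0.1'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def nvcc_release(nvcc_version_output: str) -> Optional[Tuple[int, int]]:
    """Parse the 'release X.Y' field from the text printed by `nvcc --version`."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_version_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def versions_match(cuda_python_version: str, nvcc_version_output: str) -> bool:
    """True when cuda-python and nvcc agree on the major.minor CUDA version."""
    release = nvcc_release(nvcc_version_output)
    return release is not None and release == major_minor(cuda_python_version)
```

In practice, the first argument could come from `importlib.metadata.version("cuda-python")` and the second from the standard output of `$CUDA_INSTALL_PATH/bin/nvcc --version`.
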
### Installing a developer-mode package
The CUTLASS Python interface can currently be installed via:
```bash
python setup.py develop --user
```
This will allow changes to the Python interface source to be reflected when using the Python interface.
We plan to add support for installing via `python setup.py install` in a future release.
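
A quick sanity check after a developer-mode install is to ask the interpreter where it would load the package from; the reported file should sit inside the cloned source tree. The following is a minimal sketch using only the standard library, and the helper name `module_origin` is hypothetical.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return the file a top-level module would be imported from, or None if it is not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # Prints None when the package is not visible to this interpreter.
    print(module_origin("cutlass"))
```
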
## Docker
To ensure that you have all of the necessary Python modules for running the examples using the
CUTLASS Python interface, we recommend using one of the Docker images located in the docker directory,
which are provided for CUDA [11.8](../../../python/docker/Dockerfile-cuda11.8-pytorch)
and [12.0](../../../python/docker/Dockerfile-cuda12.0-pytorch).

For example, to build and launch a container that uses CUDA 12.0 via an NGC PyTorch container, run:
```bash
docker build -t cutlass-cuda12.0:latest -f docker/Dockerfile-cuda12.0-pytorch .
docker run --gpus all -it --rm cutlass-cuda12.0:latest
```
The CUTLASS Python interface has been tested with CUDA 11.8 and CUDA 12.0 on Python 3.8.10 and 3.9.7.

@ -0,0 +1,7 @@
CUTLASS Python API
==================

.. toctree::
   :maxdepth: 5

   cutlass

106
python/setup.py Normal file

@ -0,0 +1,106 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import os
from setuptools import setup


def _cutlass_path_from_dir() -> str:
    # Default to the directory one level above the one containing this file
    cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../')
    if not os.path.isdir(cutlass_path):
        raise Exception(f'Environment variable "CUTLASS_PATH" is not defined, and default path of {cutlass_path} does not exist.')
    return cutlass_path


def _cuda_install_path_from_nvcc() -> str:
    import subprocess
    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
    result = subprocess.run(['which', 'nvcc'], capture_output=True)
    if result.returncode != 0:
        raise Exception('Unable to find nvcc via `which` utility.')

    cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0]
    if not os.path.isdir(cuda_install_path):
        raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, and default path of {cuda_install_path} does not exist.')

    return cuda_install_path


cutlass_path = (
    os.getenv('CUTLASS_PATH')
    if os.getenv('CUTLASS_PATH') is not None
    else _cutlass_path_from_dir()
)

cuda_install_path = (
    os.getenv('CUDA_INSTALL_PATH')
    if os.getenv('CUDA_INSTALL_PATH') is not None
    else _cuda_install_path_from_nvcc()
)

ext_modules = []

# The C++ extension is only built when pybind11 is available at setup time
try:
    from pybind11.setup_helpers import Pybind11Extension, build_ext
    include_dirs = [
        cutlass_path + '/include',
        cuda_install_path + '/include',
        cutlass_path + '/tools/util/include',
        cutlass_path + '/test',
    ]

    ext_modules = [
        Pybind11Extension('cutlass_bindings',
                          ['cutlass/cpp/cutlass_bindings.cpp'],
                          include_dirs=include_dirs,
                          extra_compile_args=['-fpermissive', '-w', '-std=c++17', '-DCUTLASS_PYTHON_HOST_CC=1'])
    ]
except ImportError:
    pass


setup(
    name='cutlass',
    version='3.1.0',
    description='CUTLASS Pythonic Interface',
    package_dir={'': '.'},
    packages=['cutlass', 'cutlass.emit', 'cutlass.op', 'cutlass.utils', 'cutlass.backend', 'cutlass.backend.utils'],
    setup_requires=['pybind11'],
    install_requires=[
        'bfloat16',
        'cuda-python>=11.8.0',
        'pybind11',
        'scikit-build',
        'treelib'
    ],
    ext_modules=ext_modules,
)