CUTLASS 3.2.1 (#1113)

* Updates for 3.2.1 release.

* Minor fix in gemm op profiler for raster order.

* Add scheduler mapping for raster order in the kernels.
ANIKET SHIVAM
2023-09-26 14:24:26 -07:00
committed by GitHub
parent e0aaa3c3b3
commit 90d3b0fb18
428 changed files with 22253 additions and 21762 deletions

View File

@ -1,6 +1,12 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS Python Interface
# Python packages associated with CUTLASS
This directory contains Python packages that are associated with CUTLASS:
* `cutlass`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python
* `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels
## CUTLASS Python Interface
The CUTLASS Python interface enables one to compile and run CUTLASS operations from within Python.
```python
@ -15,7 +21,7 @@ plan.run(A, B, C, D)
**NOTE:** The CUTLASS Python interface is currently an experimental release. The API may change in the future.
We welcome feedback from the community.
## Overview
### Overview
The CUTLASS Python interface aims to provide an easy-to-use interface for using CUTLASS via Python. Toward this goal,
the CUTLASS Python interface attempts to:
@ -25,7 +31,7 @@ the CUTLASS Python interface attempts to:
* Reduce the occurrence of C++ compile-time errors in favor of descriptive Python exceptions
* Make it easy to export CUTLASS kernels to framework extensions (e.g., PyTorch CUDA extensions)
### Non-goals
#### Non-goals
The CUTLASS Python interface is not intended to:
**Select optimal kernel configurations.**
@ -43,7 +49,7 @@ one of the CUTLASS emitters for automatically creating a framework extension for
The CUTLASS Python interface intends to enable one to use CUTLASS via Python. It can be used by frameworks for JIT compiling
Python to CUDA kernels, but does not set out to be such a framework.
### Comparison to PyCUTLASS
#### Comparison to PyCUTLASS
The CUTLASS Python interface builds atop CUTLASS's [PyCUTLASS](https://github.com/NVIDIA/cutlass/tree/v3.0.0/tools/library/scripts/pycutlass) library. PyCUTLASS enables
one to declare, compile, and run GEMMs, convolutions, and grouped GEMM operators with nearly the same configuration
space as CUTLASS's C++ interface. While this flexibility enables one to achieve similar levels of functionality
@ -53,43 +59,14 @@ to operators -- similar to what one must do in specifying template parameters to
In contrast, the CUTLASS Python interface aims to provide a higher-level API for declaring, emitting, and compiling
kernels that does not require exhaustively defining template parameters.
#### Transitioning from PyCUTLASS
At present, existing PyCUTLASS functionality remains available via the CUTLASS Python interface. One can
continue to use PyCUTLASS by replacing references to the PyCUTLASS `cutlass` module with `cutlass_bindings`
and the PyCUTLASS `pycutlass` module with `cutlass.backend`.
For example, the following code using PyCUTLASS:
```python
import pycutlass
import cutlass
math_inst = pycutlass.MathInstruction(
[1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32,
cutlass.OpClass.Simt, pycutlass.MathOperation.multiply_add
)
```
can work with the Python interface via:
```python
import cutlass.backend as pycutlass
import cutlass_bindings
math_inst = pycutlass.MathInstruction(
[1, 1, 1], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32,
cutlass_bindings.OpClass.Simt, pycutlass.MathOperation.multiply_add
)
```
**NOTE:** backwards compatibility of `cutlass.backend` with `pycutlass` will not be maintained moving forward.
## Current functionality
### Current functionality
The CUTLASS Python interface currently supports the following operations:
* GEMMs
* GEMMs with fused elementwise epilogues (e.g., ReLU) (for pre-SM90 kernels)
* Stream K swizzling (for pre-SM90 kernels)
* Grouped GEMM (for pre-SM90 kernels)
## Getting started
### Getting started
We recommend using the CUTLASS Python interface via one of the Docker images located in the [docker](/python/docker) directory.
```bash
@ -99,7 +76,7 @@ docker run --gpus all -it --rm cutlass-cuda12.1:latest
The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8.10 and 3.9.7.
### Optional environment variables
#### Optional environment variables
Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables:
* `CUTLASS_PATH`: the path to the cloned CUTLASS repository
* `CUDA_INSTALL_PATH`: the path to the installation of CUDA
@ -110,7 +87,7 @@ If these environment variables are not set, the installation process will infer
**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`.
### Installation
#### Installation
The CUTLASS Python interface can currently be installed via:
```bash
python setup.py develop --user
@ -119,7 +96,7 @@ This will allow changes to the Python interface source to be reflected when usin
We plan to add support for installing via `python setup.py install` in a future release.
## Examples
### Examples
Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python).
To launch these notebooks from this directory, run:
@ -127,7 +104,7 @@ To launch these notebooks from this directory, run:
jupyter-lab ../examples/python
```
## Building documentation
### Building documentation
The CUTLASS Python interface uses [Sphinx](https://www.sphinx-doc.org/en/master/) for documentation.
Building the documentation requires additional packages. These can be installed via:
@ -147,6 +124,22 @@ make html
mv _build/* ../docs
```
## CUTLASS library package
[cutlass_library](/python/cutlass_library) contains utilities for enumerating and emitting CUTLASS C++ kernels.
It is used by the CUTLASS CMake system to construct a library of kernels that can be profiled using the CUTLASS profiler.
To install the `cutlass_library` package, run
```bash
python setup_library.py develop --user
```
Alternatively, `cutlass_library` will automatically be installed if you install the CUTLASS Python interface package.
You can also use the [generator.py](/python/cutlass_library/generator.py) script directly without installing the module via:
```bash
python -m cutlass_library.generator
```
# Copyright
Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

View File

@ -34,6 +34,8 @@ import logging
import os
import sys
import cutlass_library
def _cutlass_path_from_dir() -> str:
cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../')
@ -62,19 +64,29 @@ CUTLASS_PATH = os.getenv("CUTLASS_PATH", _cutlass_path_from_dir())
CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc())
CACHE_FILE = "compiled_cache.db"
# Add the path to the CUTLASS profiler generation/manifest scripts to PYTHONPATH
sys.path.insert(0, os.path.join(CUTLASS_PATH, "tools/library/scripts/"))
# Import types/methods from the CUTLASS utility libraries for profiler generation/emission
from library import (
from cutlass_library.library import (
ArchitectureNames,
ComplexTransform,
ComplexTransformTag,
ConvKind,
ConvKindNames,
ConvKindTag,
ConvMode,
DataType,
DataTypeNames,
DataTypeSize,
DataTypeTag,
EpilogueFunctor,
EpilogueScheduleSuffixes,
EpilogueScheduleTag,
EpilogueScheduleType,
GemmKind,
GemmKindNames,
GemmUniversalMode,
IteratorAlgorithm,
IteratorAlgorithmNames,
IteratorAlgorithmTag,
LayoutTag,
LayoutType,
KernelScheduleSuffixes,
@ -82,15 +94,27 @@ from library import (
KernelScheduleType,
MathInstruction,
MathOperation,
MathOperationTag,
OpcodeClass,
OpcodeClassNames,
OpcodeClassTag,
OperationKind,
SharedMemPerCC,
ShortComplexLayoutNames,
ShortDataTypeNames,
ShortLayoutTypeNames,
SplitKMode,
StrideSupport,
StrideSupportNames,
StrideSupportTag,
SwizzlingFunctor,
SwizzlingFunctorTag,
TensorDescription,
TileDescription,
TileSchedulerSuffixes,
TileSchedulerTag,
TileSchedulerType
TileSchedulerType,
get_complex_from_real,
)
this = sys.modules[__name__]
@ -112,7 +136,7 @@ from cutlass.backend.utils.device import device_cc
this.option_registry = OptionRegistry(device_cc())
this.__version__ = '3.2.0'
this.__version__ = '3.2.1'
from cutlass.backend import get_memory_pool
from cutlass.emit.pytorch import pytorch
@ -120,5 +144,6 @@ from cutlass.op.gemm import Gemm
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
from cutlass.backend.evt.ir.tensor import Tensor
get_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32)

View File

@ -1,6 +1,3 @@
# module-wide variables
import os
from cutlass.backend.arguments import *
from cutlass.backend.c_types import *
from cutlass.backend.compiler import ArtifactManager
@ -11,9 +8,7 @@ from cutlass.backend.gemm_operation import *
from cutlass.backend.library import *
from cutlass.backend.memory_manager import PoolMemoryManager
from cutlass.backend.operation import *
from cutlass.backend.parser import *
from cutlass.backend.reduction_operation import *
from cutlass.backend.tensor_ref import *
from cutlass.backend.type_hint import *
from cutlass.backend.utils import *
from cutlass.backend.utils.device import device_cc

View File

@ -30,6 +30,7 @@
#
#################################################################################################
from math import prod
from typing import Union
from cuda import cuda, cudart
@ -67,39 +68,39 @@ class ArgumentBase:
# by default, tensor_C is not bias
self.bias = False
# preprocessing input tensors
if isinstance(A, np.ndarray):
self.host_D = D
self.buffer_A = NumpyFrontend.argument(A, False)
self.buffer_B = NumpyFrontend.argument(B, False)
self.buffer_C = NumpyFrontend.argument(C, False)
self.buffer_D = NumpyFrontend.argument(D, True)
self.ptr_A = self.buffer_A.ptr
self.ptr_B = self.buffer_B.ptr
self.ptr_C = self.buffer_C.ptr
self.ptr_D = self.buffer_D.ptr
# number of elements in C
self.tensor_c_numel = C.size
elif torch_available and isinstance(A, torch.Tensor):
self.ptr_A = TorchFrontend.argument(A)
self.ptr_B = TorchFrontend.argument(B)
self.ptr_C = TorchFrontend.argument(C)
self.ptr_D = TorchFrontend.argument(D)
# number of elements in C
self.tensor_c_numel = C.numel()
elif isinstance(A, cuda.CUdeviceptr):
self.ptr_A = A
self.ptr_B = B
self.ptr_C = C
self.ptr_D = D
# RMM buffers used to track tensor lifetime
self.buffers = {}
# Host tensor to copy the computed result back
self.host_tensors = {}
elif cupy_available and isinstance(A, cp.ndarray):
self.ptr_A = CupyFrontend.argument(A)
self.ptr_B = CupyFrontend.argument(B)
self.ptr_C = CupyFrontend.argument(C)
self.ptr_D = CupyFrontend.argument(D)
# number of elements in C
self.tensor_c_numel = C.size
self.ptr_A = self.tensor_to_ptr(A, "A")
self.ptr_B = self.tensor_to_ptr(B, "B")
self.ptr_C = self.tensor_to_ptr(C, "C")
self.ptr_D = self.tensor_to_ptr(D, "D", True)
if C is not None:
if not isinstance(C, cuda.CUdeviceptr):
self.tensor_c_numel = prod(C.shape)
def tensor_to_ptr(self, tensor, name, is_output=False):
"""
Convert the input tensor to a cuda.CUdeviceptr usable by CUDA Python, and remember it.
For a numpy.ndarray, the host buffer is also retained for later synchronization.
"""
if tensor is None:
return cuda.CUdeviceptr(0)
if isinstance(tensor, np.ndarray):
if is_output:
assert name
self.buffers[name] = NumpyFrontend.argument(tensor, is_output)
if is_output:
self.host_tensors[name] = tensor
return self.buffers[name].ptr
elif torch_available and isinstance(tensor, torch.Tensor):
return TorchFrontend.argument(tensor)
elif isinstance(tensor, cuda.CUdeviceptr):
return tensor
elif cupy_available and isinstance(tensor, cp.ndarray):
return CupyFrontend.argument(tensor)
else:
raise TypeError("Unsupported frontend; only numpy.ndarray, torch.Tensor, cupy.ndarray, and cuda.CUdeviceptr are supported")
@ -109,11 +110,12 @@ class ArgumentBase:
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
if hasattr(self, "host_D"):
for key in self.host_tensors.keys():
host_tensor = self.host_tensors[key]
(err,) = cuda.cuMemcpyDtoH(
self.host_D,
self.ptr_D,
self.host_D.size * self.host_D.itemsize,
host_tensor,
self.buffers[key].ptr,
host_tensor.size * host_tensor.itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))

View File

@ -32,7 +32,6 @@
import ctypes
import cutlass_bindings
from cutlass import (
DataType,
KernelScheduleType
@ -47,9 +46,10 @@ class GemmCoord_(ctypes.Structure):
("k", ctypes.c_int)
]
def __init__(self, gemm_coord) -> None:
for field_name, _ in self._fields_:
setattr(self, field_name, getattr(gemm_coord, field_name)())
def __init__(self, m, n, k) -> None:
self.m = m
self.n = n
self.k = k
class GemmCoordBatched_(ctypes.Structure):
@ -66,9 +66,10 @@ class GemmCoordBatched_(ctypes.Structure):
]
def __init__(self, gemm_coord, batch_count) -> None:
for field_name, _ in self._fields_[:-1]:
setattr(self, field_name, getattr(gemm_coord, field_name)())
setattr(self, "batch_count", batch_count)
self.m = gemm_coord.m
self.n = gemm_coord.n
self.k = gemm_coord.k
self.batch_count = batch_count
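The rewritten initializers above assign plain Python integers directly to the ctypes fields rather than calling pybind accessor methods on a bound object. A minimal, self-contained sketch of that pattern (the class name here is illustrative, not the real CUTLASS type):

```python
import ctypes

class GemmCoordSketch(ctypes.Structure):
    # mirrors the m/n/k layout of the GemmCoord_ struct above,
    # binary-compatible with `struct { int m, n, k; }` in C
    _fields_ = [("m", ctypes.c_int), ("n", ctypes.c_int), ("k", ctypes.c_int)]

    def __init__(self, m, n, k) -> None:
        self.m = m
        self.n = n
        self.k = k

coord = GemmCoordSketch(128, 256, 64)
assert (coord.m, coord.n, coord.k) == (128, 256, 64)
assert ctypes.sizeof(coord) == 3 * ctypes.sizeof(ctypes.c_int)
```

Passing the structure (or a pointer to it) to a kernel-argument buffer then requires no further conversion, which is why the diff drops the pybind-backed accessors.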
class MatrixCoord_(ctypes.Structure):
@ -98,14 +99,6 @@ class StrideBatched_(ctypes.Structure):
]
dtype2ctype = {
cutlass_bindings.float16: ctypes.c_uint16,
cutlass_bindings.float32: ctypes.c_float,
cutlass_bindings.float64: ctypes.c_double,
cutlass_bindings.int32: ctypes.c_int32,
}
class GenericMainloopArguments3x_(ctypes.Structure):
"""
Structure representing the superset of possible mainloop arguments.
@ -196,15 +189,28 @@ def get_mainloop_arguments_3x(
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
if hasattr(epilogue_functor, "visitor"):
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
("arg_C", epilogue_functor.arg_c_type),
("arg_D", epilogue_functor.arg_d_type)
]
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
("ptr_C", ctypes.c_void_p),
("stride_C", StrideBatched_),
("ptr_D", ctypes.c_void_p),
("stride_D", StrideBatched_),
]
def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None:
self.epilogue = output_op
self.arg_C = epilogue_functor.arg_c_type(ptr_c)
self.arg_D = epilogue_functor.arg_d_type(ptr_d)
else:
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _EpilogueOutputOpParams),
("ptr_C", ctypes.c_void_p),
("stride_C", StrideBatched_),
("ptr_D", ctypes.c_void_p),
("stride_D", StrideBatched_),
]
class _HardwareInfo(ctypes.Structure):
_fields_ = [
@ -324,7 +330,7 @@ def get_gemm_grouped_arguments(epilogue_functor):
############################################################################################
class Conv2DProblemSize(ctypes.Structure):
class Conv2DProblemSize_(ctypes.Structure):
_fields_ = [
("N", ctypes.c_int),
("H", ctypes.c_int),
@ -382,11 +388,13 @@ def get_conv2d_arguments(epilogue_functor):
class _Conv2dArguments(ctypes.Structure):
_fields_ = [
("problem_size", Conv2DProblemSize),
("ref_A", TensorRef_),
("ref_B", TensorRef_),
("ref_C", TensorRef_),
("ref_D", TensorRef_),
("conv_kind", ctypes.c_int),
("problem_size", Conv2DProblemSize_),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
("ptr_C", ctypes.c_void_p),
("ptr_D", ctypes.c_void_p),
("tensor_C_numel", ctypes.c_int),
("output_op", _EpilogueOutputOpParams),
("split_k_mode", ctypes.c_int)
]
@ -414,3 +422,189 @@ def get_reduction_params(epilogue_functor):
]
return _ReductionParams, _EpilogueOutputParams
###########################################################################################
# Epilogue Visitor Type Factory
###########################################################################################
class Empty(ctypes.Structure):
_fields_ = []
def __init__(self, *arg) -> None:
pass
class EmptyByte(ctypes.Structure):
_fields_ = [
("byte", ctypes.c_byte)
]
def __init__(self, *arg) -> None:
pass
class EBO:
def __init__(self, index: int, type) -> None:
self.index = index
self.type = type
def __eq__(self, other) -> bool:
if isinstance(other, EBO):
return self.index == other.index and self.type == other.type
return False
def __hash__(self) -> int:
return hash((self.index, self.type))
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self) -> str:
return f"<{self.index}, {self.type}>"
def tuple_factory_(input_tuple, dtype, constants=[0,1]):
"""
The factory function generating cute::Tuple with input tuple
:param input_tuple: the input tuple
:type input_tuple: tuple
:param dtype: the data type for non-constant values
:type dtype: str, "int32_t", "int", "int64_t"
:param constant: the values that will be treated as constants
:type constant: list[int]
:return: ctype structure representing the cute::Tuple
:return: the empty base classes of the tuple
"""
# The empty base classes of the current tuple
empty_bases = []
# The first non empty base class
first_non_empty_base = None
# The ctype fields of the current tuple
ctype_fields = []
for idx, entry in enumerate(input_tuple):
# For nested tuples
if isinstance(entry, tuple):
sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
if ctypes.sizeof(sub_tuple_ctype) == 0:
# The empty tuple base class is also an empty EBO
empty_bases.append(EBO(idx, entry))
else:
if first_non_empty_base is None:
first_non_empty_base = sub_empty_bases
ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
else:
if entry in constants:
empty_bases.append(EBO(idx, entry))
ctype_fields.append((f"entry_{idx}", Empty))
else:
ctype_fields.append((f"entry_{idx}", dtype))
if first_non_empty_base is None:
first_non_empty_base = []
# Determine whether or not to add an additional byte for empty base classes
additional_byte = False
# Special case for constant tuple
if first_non_empty_base is None:
additional_byte = False
else:
for base in first_non_empty_base:
if base in empty_bases:
additional_byte = True
break
if additional_byte:
ctype_fields = [("empty_byte", EmptyByte), ] + ctype_fields
# Create the ctype tuple
class TupleType(ctypes.Structure):
_fields_ = ctype_fields
def __init__(self, args) -> None:
if additional_byte:
fields = self._fields_[1:]
else:
fields = self._fields_
assert len(fields) == len(args)
for field, arg in zip(fields, args):
name = field[0]
field_type = field[1]
setattr(self, name, field_type(arg))
return TupleType, empty_bases
def tuple_factory(input_tuple, dtype: str, constants=[0,1]):
"""
The factory function generating cute::Tuple with input tuple
:param input_tuple: the input tuple
:type input_tuple: tuple
:param dtype: the data type for non-constant values
:type dtype: str, "int32_t", "int", "int64_t"
:param constant: the values that will be treated as constants
:type constant: list[int]
:return: ctype structure representing the cute::Tuple
:return: the empty base classes of the tuple
"""
# Step 1: convert the dtype
if dtype == "int64_t":
dtype = ctypes.c_longlong
elif dtype in ["int", "int32_t"]:
dtype = ctypes.c_int32
else:
raise NotImplementedError(f"Type {dtype} is not supported")
tuple_type, _ = tuple_factory_(input_tuple, dtype, constants)
if ctypes.sizeof(tuple_type) == 0:
return EmptyByte
return tuple_type
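The zero-size check above exists because a `ctypes.Structure` with no fields occupies zero bytes, unlike an empty class in C++ (which always has size at least 1 unless the empty base optimization applies); returning the one-byte `EmptyByte` restores a nonzero size so the layout matches the C++ side. A self-contained illustration of that ctypes behavior, independent of the factory itself:

```python
import ctypes

class Empty(ctypes.Structure):
    _fields_ = []

class EmptyByte(ctypes.Structure):
    _fields_ = [("byte", ctypes.c_byte)]

# ctypes gives an empty structure size 0, so it cannot stand in
# for a C++ empty class on its own
assert ctypes.sizeof(Empty) == 0
assert ctypes.sizeof(EmptyByte) == 1

# a field of type Empty contributes nothing to the enclosing struct,
# mimicking a C++ base elided by the empty base optimization
class Holder(ctypes.Structure):
    _fields_ = [("pad", Empty), ("value", ctypes.c_int32)]

assert ctypes.sizeof(Holder) == ctypes.sizeof(ctypes.c_int32)
```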
def visitor_factory(node_types, node_names):
"""
Creates the argument type of epilogue visitor type
:param node_types: list of argument types under ctypes
:param node_names: list of argument names under str
:return: tuple type in ctypes.Structure
"""
ctypes_field = []
# A struct is used when the number of nodes is <= 4,
# because Sm90VisitorImplBase has specializations for up to 4 nodes
# in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
if len(node_types) <= 4:
for idx, node_type in enumerate(node_types):
if ctypes.sizeof(node_type) == 0:
# Special case for empty struct
# 1 byte placeholder is used for correct alignment
ctypes_field.append((node_names[idx], ctypes.c_byte))
else:
ctypes_field.append((node_names[idx], node_type))
class VisitorType(ctypes.Structure):
_fields_ = ctypes_field
def __init__(self, kwargs) -> None:
for field in self._fields_:
fname, ftype = field
if ftype != ctypes.c_byte:
setattr(self, fname, ftype(kwargs))
# For cases with more than 4 nodes, tuple is used
else:
for idx, node_type in enumerate(node_types):
ctypes_field.append((node_names[idx], node_type))
class VisitorType(ctypes.Structure):
_fields_ = ctypes_field
def __init__(self, kwargs) -> None:
for field in self._fields_:
fname, ftype = field
setattr(self, fname, ftype(kwargs))
return VisitorType
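`visitor_factory` builds a `ctypes.Structure` class at runtime from parallel lists of field types and names. The same effect can be had with the three-argument form of `type()`; a self-contained sketch of the technique (all names here are illustrative, not CUTLASS APIs):

```python
import ctypes

def make_struct(field_types, field_names):
    """Dynamically create a ctypes.Structure with the given fields."""
    return type(
        "DynamicStruct",
        (ctypes.Structure,),
        {"_fields_": list(zip(field_names, field_types))},
    )

Args = make_struct([ctypes.c_int32, ctypes.c_float], ["node_count", "beta"])
a = Args(3, 0.0)  # positional init assigns fields in declaration order
assert a.node_count == 3
assert ctypes.sizeof(Args) == 8  # int32 + float, naturally aligned
```

Defining the class inside the factory (as the diff does with nested `class VisitorType`) is equivalent; each call produces a fresh type closed over the captured field list.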

View File

@ -34,17 +34,16 @@ import ctypes
import json
import os
import sqlite3
import subprocess
import tempfile
from cuda import cuda, nvrtc
import cutlass_bindings
from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH, logger
from cutlass.backend.gemm_operation import GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.device import device_cc
from cutlass.backend.utils.software import SubstituteTemplate
import subprocess
IncludeTemplate = r"""#include "${include}"
"""
@ -157,8 +156,8 @@ class ArtifactManager:
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
self.nvcc()
self.compiled_cache_device = cutlass_bindings.CompileCache()
self.compiled_cache_host = cutlass_bindings.CompileCache()
self.compiled_cache_device = {}
self.compiled_cache_host = {}
def nvrtc(self):
self.backend = "nvrtc"
@ -197,7 +196,7 @@ class ArtifactManager:
raise RuntimeError("Cuda Error: {}".format(err))
err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
self.compiled_cache_device.insert(key, kernel)
self.compiled_cache_device[key] = kernel
compiled_host_fns = {}
host_lib = CDLLBin(host_binary)
@ -222,7 +221,7 @@ class ArtifactManager:
compiled_host_fns[attr] = func
self.compiled_cache_host.insert(key, compiled_host_fns)
self.compiled_cache_host[key] = compiled_host_fns
return True
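The change above replaces the pybind11-backed `cutlass_bindings.CompileCache` (with its `.at`/`.insert` methods) by a plain Python dict, so a cache miss yields `None` via `dict.get` instead of needing a separate hit check. A minimal sketch of the lookup pattern, with a hypothetical key:

```python
# plain dict standing in for the removed cutlass_bindings.CompileCache
compiled_cache_device = {}

key = "gemm_f16_sm80" + "nvcc"  # illustrative: kernel source + backend name

# old style: cache.insert(key, kernel); cache.at(key)
# new style: plain dict assignment and .get
compiled_cache_device[key] = "<compiled kernel handle>"

assert compiled_cache_device.get(key) == "<compiled kernel handle>"
assert compiled_cache_device.get("missing-key") is None  # miss returns None, no KeyError
```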
def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
@ -246,11 +245,10 @@ class ArtifactManager:
)
for incl in includes_host:
if "/device/" not in incl:
source_buffer_host += SubstituteTemplate(
IncludeTemplate,
{"include": incl},
)
source_buffer_host += SubstituteTemplate(
IncludeTemplate,
{"include": incl},
)
# 2. Operations
for operation in operation_list:
@ -382,16 +380,16 @@ class ArtifactManager:
# step 1: get kernel string as key
key = operation.rt_module.emit() + operation.procedural_name() + self.backend
# step 1: check if the operation is in cache
compiled_kernel = self.compiled_cache_device.at(key)
compiled_kernel = self.compiled_cache_device.get(key)
if compiled_kernel is None and not bypass_cache:
hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {}))
if hit:
compiled_kernel = self.compiled_cache_device.at(key)
compiled_kernel = self.compiled_cache_device.get(key)
assert compiled_kernel is not None
if compiled_kernel is not None:
operation.rt_module.kernel = compiled_kernel
compiled_host_fns = self.compiled_cache_host.at(key)
compiled_host_fns = self.compiled_cache_host.get(key)
assert compiled_host_fns is not None
for key in compiled_host_fns.keys():
setattr(operation.rt_module, key, compiled_host_fns[key])
@ -417,7 +415,7 @@ class ArtifactManager:
bytes(str.encode(operation.name()))
)
operation_name.append(operation.name())
self.compiled_cache_device.insert(key, operation.kernel)
self.compiled_cache_device[key] = operation.kernel
# get host functions
compiled_host_fns = {}
op_attr = []
@ -456,7 +454,7 @@ class ArtifactManager:
op_attr.append(suffix)
operation_attr.append(op_attr)
self.compiled_cache_host.insert(key, compiled_host_fns)
self.compiled_cache_host[key] = compiled_host_fns
for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr):
self.insert_operation(

View File

@ -29,19 +29,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
# from typeguard import typechecked
import ctypes
from typing import Union
from cuda import cuda
import cutlass_bindings
import numpy as np
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments
from cutlass.backend.library import (
EmissionType,
from cutlass import (
ConvKindNames,
ConvKindTag,
DataTypeNames,
@ -50,29 +45,40 @@ from cutlass.backend.library import (
IteratorAlgorithmNames,
IteratorAlgorithmTag,
LayoutTag,
LayoutType,
MathOperation,
MathOperationTag,
OpcodeClass,
OpcodeClassNames,
OpcodeClassTag,
OperationKind,
ShortDataTypeNames,
ShortLayoutTypeNames,
SplitKMode,
StrideSupport,
StrideSupportTag,
SwizzlingFunctor,
SwizzlingFunctorTag,
get_complex_from_real,
)
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import dim3_, get_conv2d_arguments
from cutlass.backend.library import (
EmissionType,
TensorDescription,
TileDescription,
get_complex_from_real,
)
from cutlass.backend.memory_manager import device_mem_alloc
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.tensor_ref import TensorRef
from cutlass.backend.utils.datatypes import to_device_ptr
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.shape import GemmCoord
if CheckPackages().check_torch():
import torch
# @typechecked
class Conv2dArguments(ArgumentBase):
"""
Argument wrapper for Conv2d. It encodes problem information and
@ -81,7 +87,7 @@ class Conv2dArguments(ArgumentBase):
:param operation: the Conv2d operation to take the argument
:type operation: :class:`cutlass.backend.Conv2dOperation`
:param problem_size: the Conv2d problem size
:type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize`
:type problem_size: :class:`cutlass.shape.Conv2dProblemSize`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param B: tensor B
@ -90,135 +96,70 @@ class Conv2dArguments(ArgumentBase):
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param D: tensor D
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param split_k_mode: conv2d split K mode, defaults to cutlass_bindings.conv.SplitKMode.Serial
:type split_k_mode: cutlass_bindings.conv.SplitKMode, optional
:param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
:type split_k_mode: cutlass_library.library.SplitKMode, optional
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
"""
def __init__(
self,
operation: "Conv2dOperation",
problem_size: "cutlass_bindings.conv.Conv2dProblemSize",
A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
split_k_mode: "cutlass_bindings.conv.SplitKMode" = cutlass_bindings.conv.SplitKMode.Serial,
**kwargs,
) -> None:
def __init__(self, operation, problem_size, A, B, C, D,
split_k_mode=SplitKMode.Serial, **kwargs, ) -> None:
self.operation = operation
#: convolution kind
self.conv_kind: cutlass_bindings.conv.Operator = operation.conv_kind
self.layout_A: cutlass_bindings.layout = operation.A.layout
self.layout_B: cutlass_bindings.layout = operation.B.layout
self.layout_C: cutlass_bindings.layout = operation.C.layout
self.conv_kind = operation.conv_kind
self.layout_A = operation.A.layout
self.layout_B = operation.B.layout
self.layout_C = operation.C.layout
self.element_A = operation.A.element
self.element_B = operation.B.element
self.element_C = operation.C.element
if self.layout_C == cutlass_bindings.TensorNC32HW32:
B = self.reorder_tensor_B(B, problem_size)
if self.layout_C == LayoutType.TensorNC32HW32:
raise Exception("Layout type TensorNC32HW32 is not currently supported")
super().__init__(A, B, C, D, **kwargs)
# preprocessing output ops
if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
self.split_k_mode = split_k_mode
self.split_k_slices = kwargs["split_k_slices"]
else:
self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial
self.split_k_mode = SplitKMode.Serial
self.split_k_slices = 1
if "output_op" in kwargs.keys() and self.split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel:
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
#: problem_size
self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size
self.problem_size = problem_size
self.problem_size.split_k_slices = self.split_k_slices
if hasattr(self, "tensor_c_numel"):
c_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
self.conv_kind, problem_size)
if self.tensor_c_numel == c_coord.at(3) and self.tensor_c_numel < c_coord.size():
self.bias = True
#
# initialize the argument
#
self.initialize()
# @typechecked
def reorder_tensor_B(self, tensor_B: "np.ndarray",
problem_size: "cutlass_bindings.conv.Conv2dProblemSize"):
"""
Reorder tensor_B for interleaved layout
:param tensor_B: input tensor B
:type tensor_B: numpy.ndarray
:param problem_size: Conv2d problem size
:type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize`
:return: reordered tensor B
:rtype: numpy.ndarray
"""
reordered_tensor_B = np.empty_like(tensor_B)
tensor_ref_B = self.get_tensor_ref(
tensor_B, self.element_B, self.layout_B, problem_size, "b")
reordered_tensor_ref_B = self.get_tensor_ref(
reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b")
cutlass_bindings.conv.host.reorder_convK(
reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size)
return reordered_tensor_B
def get_tensor_ref(
self, tensor, dtype, tensor_layout, problem_size, operand):
if operand == "a":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
self.conv_kind, problem_size)
elif operand == "b":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
self.conv_kind, problem_size)
elif operand in ["c", "d"]:
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
self.conv_kind, problem_size)
else:
raise ValueError("unknown operand: " + operand)
# Zero stride trick
if operand == "c" and self.bias:
tensor_coord = cutlass_bindings.Tensor4DCoord(0, 0, 0, 0)
layout = tensor_layout.packed(tensor_coord)
return TensorRef(tensor, dtype, layout).tensor_ref
def get_arguments(self, semaphore):
ref_A = TensorRef_(self.get_tensor_ref(
self.ptr_A, self.element_A, self.layout_A, self.problem_size, "a"))
ref_B = TensorRef_(self.get_tensor_ref(
self.ptr_B, self.element_B, self.layout_B, self.problem_size, "b"))
ref_C = TensorRef_(self.get_tensor_ref(
self.ptr_C, self.element_C, self.layout_C, self.problem_size, "c"))
ref_D = TensorRef_(self.get_tensor_ref(
self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))
def get_arguments(self):
tc_numel = -1
if hasattr(self, "tensor_c_numel"):
tc_numel = self.tensor_c_numel
self.c_arguments = self.operation.argument_type(
Conv2DProblemSize(self.problem_size),
ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode)
self.semaphore = semaphore
int(self.conv_kind),
self.problem_size.ctype,
int(to_device_ptr(self.ptr_A)),
int(to_device_ptr(self.ptr_B)),
int(to_device_ptr(self.ptr_C)),
int(to_device_ptr(self.ptr_D)),
tc_numel,
self.output_op,
int(self.split_k_mode)
)
def initialize(self):
# Get launch configuration
self.launch_config = self.operation.rt_module.plan(self)
# Allocate and initialize device workspace
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
self.get_arguments()
# Allocate and initialize device workspace
device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments)
if device_workspace_size > 0:
self.workspace_buffer = device_mem_alloc(device_workspace_size)
workspace_ptr = self.workspace_buffer.ptr
@@ -227,19 +168,16 @@ class Conv2dArguments(ArgumentBase):
else:
workspace_ptr = None
# Get kernel params as a bytearray
        semaphore = 0
        if (workspace_ptr is not None
                and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel):
            self.ptr_D = workspace_ptr
        elif (workspace_ptr is not None
                and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial):
            semaphore = workspace_ptr
        self.get_arguments(semaphore)
        params_ = self.operation.rt_module.get_args(
            ctypes.byref(self.c_arguments), ctypes.c_void_p(int(self.semaphore)))
        self.semaphore = 0
        if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel:
            self.ptr_D = workspace_ptr
            # Reset arguments now that ptr_D has been updated
            self.get_arguments()
        elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial:
            self.semaphore = workspace_ptr
        params_ = self.operation.rt_module.get_args(
            self.c_arguments, ctypes.c_void_p(int(self.semaphore)))
self.host_workspace = bytearray(params_.contents)
self.device_workspace = None
@@ -251,7 +189,6 @@ class Conv2dArguments(ArgumentBase):
return super().sync()
# @typechecked
class Conv2dRT(ExecutableOperation):
"""
Conv2dRT manages the CUTLASS runtime components
@@ -287,24 +224,104 @@ extern "C" {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}
using ElementA = typename ${operation_name}_base::ElementA;
using ElementB = typename ${operation_name}_base::ElementB;
using ElementC = typename ${operation_name}_base::ElementC;
using LayoutA = typename ${operation_name}_base::LayoutA;
using LayoutB = typename ${operation_name}_base::LayoutB;
using LayoutC = typename ${operation_name}_base::LayoutC;
using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp;
struct ${operation_name}_TemporaryArgs {
int conv_kind;
cutlass::conv::Conv2dProblemSize problem_size;
ElementA* ptr_A;
ElementB* ptr_B;
ElementC* ptr_C;
ElementC* ptr_D;
int tensor_c_numel;
typename EpilogueOutputOp::Params epilogue_params;
int split_k_mode;
};
typename ${operation_name}${operation_suffix}::Arguments
construct_arguments(${operation_name}_TemporaryArgs args) {
cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
// C is interpreted as bias
tc_C = {0, 0, 0, 0};
}
  cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
  cutlass::TensorRef<ElementB, LayoutB> tref_B(args.ptr_B, LayoutB::packed(tc_B));
  cutlass::TensorRef<ElementC, LayoutC> tref_C(args.ptr_C, LayoutC::packed(tc_C));
  cutlass::TensorRef<ElementC, LayoutC> tref_D(args.ptr_D, LayoutC::packed(tc_D));
return {
args.problem_size,
tref_A,
tref_B,
tref_C,
tref_D,
args.epilogue_params,
static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
};
}
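The bias check in `construct_arguments` treats C as a bias when its element count equals the channel extent K but is smaller than the full N*P*Q*K output size. A hedged Python restatement of that predicate (names here are illustrative, not part of the library):

```python
# Illustrative restatement of the C-as-bias predicate above (not CUTLASS API).
def c_is_bias(tensor_c_numel, tc_c):
    n, p, q, k = tc_c
    size_c = n * p * q * k
    # A negative numel means "unknown"; a K-length buffer smaller than
    # the full output tensor is interpreted as a per-channel bias.
    return tensor_c_numel >= 0 and tensor_c_numel == k and tensor_c_numel < size_c

assert c_is_bias(64, (8, 14, 14, 64))                    # K-length bias vector
assert not c_is_bias(8 * 14 * 14 * 64, (8, 14, 14, 64))  # full C tensor
assert not c_is_bias(-1, (8, 14, 14, 64))                # numel not provided
```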
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Arguments* arguments, int *semaphore=nullptr){
char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) {
auto arguments = construct_arguments(args);
typename ${operation_name}${operation_suffix}::Params* params;
params = new ${operation_name}${operation_suffix}::Params(*arguments, semaphore);
params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore);
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];
output[i] = bytes[i];
return output;
}
dim3 ${operation_name}_get_grid_shape(
int conv_kind,
cutlass::conv::Conv2dProblemSize problem_size,
cutlass::gemm::GemmCoord tile_size,
int split_k_slices
) {
using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle;
auto tiled_shape = Swizzle::get_tiled_shape(
static_cast<cutlass::conv::Operator>(conv_kind),
problem_size,
tile_size,
split_k_slices);
return Swizzle::get_grid_shape(tiled_shape);
}
size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) {
auto arguments = construct_arguments(args);
  // Temporarily define a device-level Conv2d so that we can call get_workspace_size
using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>;
return DeviceConv::get_workspace_size(arguments);
}
}
"""
def __init__(self, operation: "Conv2dOperation"):
super().__init__(operation)
self.extra_funcs = {
"get_grid_shape": dim3_,
"get_workspace_size": ctypes.c_uint64
}
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
self.conv_kind = operation.conv_kind
@@ -313,47 +330,27 @@ extern "C" {
self.emitter = EmitConv2dInstance("_type")
self.threads: int = operation.tile_description.num_threads
self.threads = operation.tile_description.num_threads
self.swizzle_functor = operation.swizzling_functor
def emit(self):
return self.emitter.emit(self.operation)
def get_device_workspace_size(self, arguments: Conv2dArguments):
workspace_bytes = 0
launch_config = arguments.launch_config
self.conv_kind = self.operation.conv_kind
if arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
problem_size = arguments.problem_size
workspace_bytes = DataTypeSize[self.operation.C.element] \
* launch_config.grid[2] * cutlass_bindings.conv.implicit_gemm_tensor_c_size(
self.conv_kind, problem_size
) // 8
elif arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial and \
arguments.split_k_slices > 1:
workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4
return workspace_bytes
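The two workspace formulas above can be restated in plain arithmetic: parallel split-K stores one full C-sized partial accumulator per K-slice (grid.z of them), while serial split-K needs one 4-byte semaphore per threadblock tile. A hedged sketch (the function name and string modes are illustrative):

```python
# Illustrative restatement of the split-K workspace sizing above.
def conv2d_workspace_bytes(split_k_mode, grid, split_k_slices,
                           element_bits, tensor_c_size):
    if split_k_mode == "parallel":
        # grid[2] K-slices, each writing a full C-sized partial result
        return element_bits * grid[2] * tensor_c_size // 8
    if split_k_mode == "serial" and split_k_slices > 1:
        # one 4-byte semaphore per (grid.x, grid.y) threadblock tile
        return grid[0] * grid[1] * 4
    return 0

# fp32 output, 2 K-slices, 1024-element C tensor -> 2 * 1024 * 4 bytes
assert conv2d_workspace_bytes("parallel", (4, 4, 2), 2, 32, 1024) == 8192
assert conv2d_workspace_bytes("serial", (4, 4, 1), 2, 32, 1024) == 64
assert conv2d_workspace_bytes("serial", (4, 4, 1), 1, 32, 1024) == 0
```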
# @typechecked
def plan(self, arguments: Conv2dArguments):
tile_size = cutlass_bindings.gemm.GemmCoord(
tile_size = GemmCoord(
self.operation.tile_description.threadblock_shape[0],
self.operation.tile_description.threadblock_shape[1],
self.operation.tile_description.threadblock_shape[2],
)
grid = self.swizzle_functor.get_grid_shape(
self.swizzle_functor.get_tiled_shape(
self.conv_kind, arguments.problem_size,
tile_size, arguments.split_k_slices
)
grid = self.get_grid_shape(
int(self.conv_kind),
arguments.problem_size.ctype,
tile_size.ctype,
arguments.split_k_slices
)
return LaunchConfiguration(
[grid.x, grid.y, grid.z], [self.threads, 1, 1],
self.shared_memory_capacity)
@@ -364,7 +361,7 @@ extern "C" {
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
value=self.shared_memory_capacity)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
raise RuntimeError(f"CUDA Error: {err}")
class Conv2dOperation:
@@ -372,11 +369,11 @@ class Conv2dOperation:
CUTLASS Conv2d operation description.
:param conv_kind: convolution operator
:type conv_kind: :class:`cutlass_bindings.conv.Operator`
:type conv_kind: :class:`cutlass_library.library.ConvKind`
:param iterator_algorithm: Selects among several implementation
:param iterator_algorithm: Selects among several implementation
variants trading off performance with simplicity
:type iterator_algorithm: :class:`cutlass_bindings.conv.IteratorAlgorithm`
:type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm`
:param arch: GPU compute capability (sm_xx)
:type arch: int
@@ -397,12 +394,11 @@ class Conv2dOperation:
:type D: :class:`cutlass.backend.TensorDescription`
:param element_epilogue: element type for computation in epilogue \
:type element_epilogue: cutlass_bindings.int8 | cutlass_bindings.int32 | cutlass_bindings.float16 | \
cutlass_bindings.bfloat16 | cutlass_bindings.float32 | cutlass_bindings.float64
:type element_epilogue: cutlass_library.library.DataType
:param stride_support: distinguish among partial specializations that \
accelerate certain problems where convolution stride is unit \
:type stride_support: :class:`cutlass_bindings.conv.StrideSupport`
:type stride_support: :class:`cutlass_library.library.StrideSupport`
:param epilogue_functor: convolution epilogue functor
:type epilogue_functor: :class:`EpilogueFunctor`
@@ -411,8 +407,8 @@ class Conv2dOperation:
"""
def __init__(
self,
conv_kind: cutlass_bindings.conv.Operator,
iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm,
conv_kind,
iterator_algorithm,
arch: int,
tile_description: TileDescription,
A: TensorDescription,
@@ -420,7 +416,7 @@ class Conv2dOperation:
C: TensorDescription,
stride_support,
epilogue_functor,
swizzling_functor=cutlass_bindings.IdentitySwizzle1,
swizzling_functor=SwizzlingFunctor.Identity1,
emission_type=EmissionType.Kernel,
**kwargs
):
@@ -434,8 +430,8 @@ class Conv2dOperation:
self.epilogue_functor = epilogue_functor
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor()
self.swizzling_functor = swizzling_functor
self.emission_type = emission_type
self.rt_module: Conv2dRT = Conv2dRT(self)
@@ -458,7 +454,7 @@ class Conv2dOperation:
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
raise RuntimeError(f"CUDA Error {err}")
return err
@@ -470,8 +466,6 @@ class Conv2dOperation:
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
return self.configuration_name()
#
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
@@ -503,7 +497,6 @@ class Conv2dOperation:
},
)
#
def extended_name(self):
"""Append data types if they differ from compute type."""
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
@@ -523,17 +516,15 @@ class Conv2dOperation:
return extended_name
#
def layout_name(self):
return "%s" % (ShortLayoutTypeNames[self.A.layout])
#
def core_name(self):
"""The basic operation kind is prefixed with a letter indicating the accumulation type."""
intermediate_type = ""
if self.tile_description.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp:
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
inst_shape = "%dx%dx%d" % tuple(
self.tile_description.math_instruction.instruction_shape)
if self.tile_description.math_instruction.element_a != self.A.element and \
@@ -550,7 +541,6 @@ class Conv2dOperation:
IteratorAlgorithmNames[self.iterator_algorithm]
)
#
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
@@ -558,7 +548,6 @@ class Conv2dOperation:
]
return self.tile_description.math_instruction.math_operation in complex_operators
#
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator
@@ -570,16 +559,17 @@ class Conv2dOperation:
def device_op(self):
"""
Returns a new Conv2dOperation object that is constructed with emission type
``EmissionType.Device``.
``EmissionType.Device``.
:return: operation ready for device-level code emission
:rtype: Conv2dOperation
"""
return Conv2dOperation(
self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description,
self.A, self.B, self.C, self.stride_support, self.epilogue_functor, type(self.swizzling_functor),
self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor,
emission_type=EmissionType.Device)
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
@@ -594,17 +584,18 @@ class EmitConv2dInstance:
"cutlass/cutlass.h",
"cutlass/conv/kernel/default_conv2d_fprop.h",
"cutlass/conv/kernel/default_conv2d_dgrad.h",
"cutlass/conv/kernel/default_conv2d_wgrad.h"
"cutlass/conv/kernel/default_conv2d_wgrad.h",
"cutlass/conv/device/implicit_gemm_convolution.h"
]
self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${element_a},
${layout_a},
${element_b},
${element_b},
${layout_b},
${element_c},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
@@ -631,11 +622,11 @@ struct ${operation_name}${operation_suffix}:
// Conv2d operation ${operation_name}
using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${element_a},
${layout_a},
${element_b},
${element_b},
${layout_b},
${element_c},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
@@ -689,7 +680,7 @@ using DeviceKernel =
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
"epilogue_vector_length": str(epilogue_vector_length),
"epilogue_functor": operation.epilogue_functor.emit(),
"swizzling_functor": operation.swizzling_functor.tag(),
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
"stages": str(operation.tile_description.stages),
"iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
"iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
@@ -698,7 +689,7 @@ using DeviceKernel =
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
}
if operation.emission_type == EmissionType.Kernel:
conv2d_template = self.template
else:



@@ -1,6 +1,6 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -30,7 +30,5 @@
#
################################################################################
from cutlass.backend.test.conv2d_testbed import *
from cutlass.backend.test.gemm_grouped_testbed import *
from cutlass.backend.test.gemm_testbed import *
from cutlass.backend.test.profiler import *
from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor
from cutlass.backend.evt.frontend import PythonASTFrontend


@@ -0,0 +1,36 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass.backend.evt.backend.sm80_emitter import Sm80Emitter
import cutlass.backend.evt.backend.sm80_nodes as sm80_nodes
from cutlass.backend.evt.backend.sm90_emitter import Sm90Emitter
import cutlass.backend.evt.backend.sm90_nodes as sm90_nodes


@@ -0,0 +1,158 @@
"""
Base class for Epilogue Visitor Emitter
"""
from cutlass import DataTypeTag
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
class FusionCallbacks:
def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
"""
Emit the EVT fusion callbacks
:param dag_ir: the DAG IR holding the epilogue visitor
:param cc: compute capability
:param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
For Sm90, set emit_CD=False, as tensors C & D are hardcoded in the collective API
so that their shared memory can be explicitly reused.
For Sm89, set emit_CD=True, as they are treated as normal AuxLoad & AuxStore nodes.
"""
self.dag_ir = dag_ir
self.emit_CD = emit_CD
self.cc = cc
if self.cc < 90:
self.namespace = "threadblock"
else:
self.namespace = "fusion"
#
# Helper functions
#
def get_visitor_name(self, node: str):
"""
Get the visitor name
"""
meta = self.dag_ir.get_node_meta(node)
if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
return f"EVT{meta.name_camel}"
else:
return meta.name_camel
def emit(self):
node_metas = self.dag_ir.node_metas_topological_order()
epilogue_str = ""
# Step 1: emit individual node type decl
# emit the EVT & DAG connector
for meta in node_metas:
if not meta.disabled:
epilogue_str += self.emit_node(meta)
if not self.emit_CD and meta.name == "D":
continue
if isinstance(meta, TopoVisitorNode):
epilogue_str += self.emit_dag(meta)
else:
epilogue_str += self.emit_evt(meta)
# Step 2: post-processing & get callback name
if not self.emit_CD:
if not self.dag_ir.has_node("C"):
epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
output_node = self.dag_ir.get_all_inputs("D")[0]
# The callback is the src of node D
callback_name = self.get_visitor_name(output_node)
else:
# The callback is the last node in the topological order
callback_name = self.get_visitor_name(node_metas[-1].name)
return epilogue_str, callback_name
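The callback-name selection above follows a simple rule: when C and D are not emitted as part of the callbacks (emit_CD=False), the callback is the producer of node D; otherwise it is the last node in topological order. A toy restatement of that rule (function and node names are illustrative):

```python
# Toy restatement of the callback-selection rule (not the CUTLASS API).
def callback_node(topo_order, inputs_of_d, emit_cd):
    # emit_CD=True: callback is the final node of the topological order.
    # emit_CD=False: callback is the source feeding node D.
    return topo_order[-1] if emit_cd else inputs_of_d[0]

topo = ["accum", "alpha", "mul", "D"]
assert callback_node(topo, ["mul"], emit_cd=False) == "mul"
assert callback_node(topo, ["mul"], emit_cd=True) == "D"
```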
def emit_evt(self, node):
if self.dag_ir.in_degree(node.name) == 0:
return ""
evt_tmp = f"""
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}EVT<
{node.name_camel},
"""
sorted_children = self.dag_ir.get_all_inputs(node.name)
evt_node_strs = [f" {self.get_visitor_name(child_name)}" for child_name in sorted_children]
evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
return evt_tmp
def emit_dag(self, node):
subgraph = node.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Emit the Edge Tuple
edge_tuples = "cute::tuple<\n"
for n in subgraph_nodes[:-1]:
in_edges = subgraph.in_edges(n)
edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
edge_tuple = " cute::seq<"
edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
edge_tuple += ", ".join(edge_str) + ">,\n"
edge_tuples += edge_tuple
edge_tuples += " >"
# Emit the node list
dag_nodes = ""
dag_node_strs = []
for n in subgraph_nodes[:-1]:
n_meta = subgraph.get_node_meta(n)
if n_meta.disabled:
dag_node_strs.append(f" {self.get_visitor_name(n)}")
else:
dag_node_strs.append(f" {n_meta.name_camel}")
dag_nodes = ",\n".join(dag_node_strs)
return f"""
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}TopologicalVisitor<
{DataTypeTag[node.subgraph.element_compute]},
{edge_tuples},
{dag_nodes}
>;
"""
def emit_node(self, node):
if isinstance(node, TopoVisitorNode):
emission = ""
for node in node.subgraph.node_metas_topological_order():
if not node.disabled:
emission += self.emit_node(node)
return emission
else:
return node.underlying_impl.type_decl


@@ -0,0 +1,47 @@
"""
Emitter for Sm80 Epilogue Visitor
"""
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass.backend import GemmOperationUniversal
class Sm80Emitter:
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
self.fusion_callbacks = FusionCallbacks(graph, cc=80)
def emit(self):
callback_decl, callback_name = self.fusion_callbacks.emit()
return callback_name, callback_decl


@@ -0,0 +1,258 @@
from cutlass import DataTypeTag
from cutlass.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)
from cutlass.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
class Sm80AccumulatorImpl(AccumulatorImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
return self._type_decl
class Sm80AuxLoadImpl(AuxLoadImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
return self._type_decl
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
pass
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
def __init__(self, node: LoadNode) -> None:
super().__init__(node)
self.broadcast_count = 1
self.reduction_fn = FunctionalOp.Multiplies
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
return self._type_decl
class Sm80RowBroadcastImpl(RowBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ComputeImpl(ComputeImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
return self._type_decl
class Sm80AuxStoreImpl(AuxStoreImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80StoreDImpl(Sm80AuxStoreImpl):
pass
class Sm80ColumnReductionImpl(ColumnReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80RowReductionImpl(RowReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ScalarReductionImpl(ScalarReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl


@@ -0,0 +1,98 @@
"""
Emitter for Sm90 Epilogue Visitor
"""
from cutlass import DataTypeTag, EpilogueScheduleTag
from cutlass.backend import GemmOperationUniversal
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
class CollectiveEpilogue:
def __init__(self, tile_description,
schedule,
element_c,
element_d,
fusion_callbacks) -> None:
self.cta_tile_mnk = tile_description.threadblock_shape
self.element_c = element_c
self.element_d = element_d
self.schedule = schedule
self.fusion_callbacks = fusion_callbacks
@property
def CtaTileMNK(self) -> str:
"""
The threadblock shape
"""
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
@property
def EpilogueTileType(self) -> str:
"""
The epilogue tile type
"""
return "cutlass::epilogue::collective::EpilogueTileAuto"
@property
def Schedule(self) -> str:
return EpilogueScheduleTag[self.schedule]
def emit(self):
callback_decl, callback_name = self.fusion_callbacks.emit()
return callback_name, f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}
>;
{callback_decl}
"""
class Sm90Emitter:
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
self.collective_epilogue = CollectiveEpilogue(
tile_description=operation.tile_description,
schedule=operation.tile_description.epilogue_schedule,
element_c=operation.C.element,
element_d=operation.C.element,
fusion_callbacks=fusion_callbacks
)
def emit(self):
return self.collective_epilogue.emit()
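The `CtaTileMNK` property above turns a threadblock shape such as `[128, 128, 64]` into a `cute::Shape` type string for the generated C++. A self-contained sketch of that formatting (free function for illustration only):

```python
def cta_tile_mnk(shape):
    """Format a threadblock shape [M, N, K] as a cute::Shape type string."""
    m, n, k = shape
    return f"cute::Shape<_{m}, _{n}, _{k}>"
```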


@ -0,0 +1,351 @@
# (BSD-3-Clause license header, identical to the one above)
from pycute import product
from cutlass import DataTypeSize, DataTypeTag
from cutlass.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
ComputeNode,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl,
StoreNode,
StoreDImpl,
)
from cutlass.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
class Sm90AccumulatorImpl(AccumulatorImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
return self._type_decl
class Sm90LoadSrcImpl(LoadSrcImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch;
"""
return self._type_decl
class Sm90AuxLoadImpl(AuxLoadImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Load
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
def __init__(self, node: LoadNode) -> None:
super().__init__(node)
self.broadcast_count = 1
self.reduction_fn = FunctionalOp.Multiplies
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
return self._type_decl
class Sm90RowBroadcastImpl(RowBroadcastImpl):
@property
def descriptor(self) -> str:
"""
        Descriptor for Row Broadcast
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::RowBroadcastDescriptor<EpilogueDescriptor, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
{self.descriptor}::Stages, typename EpilogueDescriptor::TileShape,
typename {self.descriptor}::Element, {self.stride_mnl}
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
stages = (stages_c + epi_tiles - 1) // epi_tiles + 1
return (DataTypeSize[self.element] * cta_tile_mnk[1] * stages // 8, 16)
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ComputeImpl(ComputeImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
return self._type_decl
class Sm90AuxStoreImpl(AuxStoreImpl):
@property
def descriptor(self) -> str:
"""
        Descriptor for Aux Store
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
class Sm90StoreDImpl(StoreDImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
return f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""
class Sm90ColumnReductionImpl(ColumnReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90RowReductionImpl(RowReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ScalarReductionImpl(ScalarReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
{DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
return self._type_decl
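Several of the Impl classes above report their shared-memory footprint as a `(bytes, alignment)` pair. A small standalone sketch of the row-broadcast arithmetic from `Sm90RowBroadcastImpl.get_smem_size` (illustrative function, not the CUTLASS API):

```python
def row_broadcast_smem(element_bits, cta_tile_n, stages_c, epi_tiles):
    """Mirror of the row-broadcast smem sizing: one row of cta_tile_n elements
    per pipeline stage, plus one extra stage; bits are converted to bytes and
    the buffer is 16-byte aligned."""
    stages = (stages_c + epi_tiles - 1) // epi_tiles + 1  # ceil(stages_c / epi_tiles) + 1
    return (element_bits * cta_tile_n * stages // 8, 16)
```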


@ -0,0 +1,165 @@
# (BSD-3-Clause license header, identical to the one above)
"""
Epilogue Visitor interface for compiling and running visitor-based epilogues.
"""
import ctypes
from cuda import cuda
import numpy as np
from cutlass import DataType
from cutlass.backend.epilogue import EpilogueFunctorBase
import cutlass.backend.evt.backend
from cutlass.backend.frontend import TensorFrontend
class EpilogueFunctorVisitor(EpilogueFunctorBase):
"""
Apply an epilogue functor described by the epilogue EVT
:param cc: compute capability
    :param visitor: the user-provided visitor frontend
"""
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
# Type of Emitter based on CC
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc}Emitter")
# Visitor Types
self.visitor = visitor
self.graph = visitor.dag_ir
# Data types
self.element_epilogue = element_compute # element compute
self.element_output = self.graph.get_node_meta('D').underlying_impl.element
# Epilogue Thread Type
epilogue_thread_type = self.visitor.epilogue_thread_type
if cc == 90:
self.arg_c_type = self.visitor.arg_c_type
self.arg_d_type = self.visitor.arg_d_type
output_names = self.visitor.return_names
reduction_names = self.visitor.reduction_names
# Epilogue stages specialized for sm80 kernel
if cc == 80:
if hasattr(self.visitor, "epilogue_stages"):
self.epilogue_stages = self.visitor.epilogue_stages
assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"
# Epilogue Argument Type
class _Arguments(ctypes.Structure):
"""
Concepts:
class _EpilogueArguments(ctypes.Structure):
_fields_ = [
("epilogue", _Arguments), <- this class
("ptr_C", ctypes.c_void_p),
("stride_C", StrideBatched_),
("ptr_D", ctypes.c_void_p),
("stride_D", StrideBatched_)
]
"""
_fields_ = [
("output_op", epilogue_thread_type)
]
def __init__(self, kwargs: dict) -> None:
# The user-input kwargs is a dict of (name: tensors)
# We first convert all of them to device pointers
ptr_kwargs = {}
for key in kwargs.keys():
is_output = key in output_names and key not in reduction_names
ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
# Initialize the thread arguments
self.output_op = epilogue_thread_type(ptr_kwargs)
def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
"""
Helper function for extracting device pointer
"""
# Skip the special tensors
if cc == 90:
if tensor_name in ["C", "D"]:
return 0
if tensor_name not in kwargs.keys():
raise ValueError(f"Tensor {tensor_name} is not provided.")
tensor = kwargs[tensor_name]
# For float scalar constant, directly return the value
if isinstance(tensor, float):
return tensor
# The tensor frontend returns a device buffer for np.ndarray
# and device ptr for other frontends
buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
if isinstance(tensor, np.ndarray):
# Remember the host tensor for later synchronization
setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
setattr(self, f"{tensor_name}_host", tensor)
return int(buffer_or_ptr.ptr)
else:
return int(buffer_or_ptr)
def sync(self):
"""
Synchronize the results from device to host
"""
for name in output_names:
if hasattr(self, f"{name}_host"):
host_tensor = getattr(self, f"{name}_host")
tensor_ptr = getattr(self, f"{name}_buffer").ptr
(err,) = cuda.cuMemcpyDtoH(
host_tensor,
tensor_ptr,
host_tensor.size * host_tensor.itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
self.epilogue_type = _Arguments
def emit(self, operation):
"""
Emit the C++ code
"""
emitter = self.emit_cls(operation, self.graph)
return emitter.emit()
def get_smem_size(self, tile_description):
"""
Get the shared memory size in bytes
"""
return self.visitor.get_smem_size(tile_description)
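The `_Arguments` class above wraps the generated epilogue thread arguments in a `ctypes.Structure` so they can be nested inside the kernel's argument struct, as the Concepts docstring sketches. A toy illustration of that nesting with stand-in field names (nothing here is the actual generated struct):

```python
import ctypes

class ThreadArgs(ctypes.Structure):
    """Stand-in for the generated epilogue thread argument struct."""
    _fields_ = [("alpha", ctypes.c_float), ("ptr_aux", ctypes.c_void_p)]

class EpilogueArguments(ctypes.Structure):
    # Mirrors the nesting shown in the _Arguments Concepts docstring
    _fields_ = [("output_op", ThreadArgs)]

args = EpilogueArguments(ThreadArgs(alpha=2.0, ptr_aux=0))
```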


@ -0,0 +1,33 @@
# (BSD-3-Clause license header, identical to the one above)
from cutlass.backend.evt.frontend.python_ast import PythonASTFrontend


@ -0,0 +1,262 @@
# (BSD-3-Clause license header, identical to the one above)
"""
Base class for Python EVT Frontend
"""
from typing import Union
from cutlass import DataType
from cutlass.backend.evt.ir import (
ComputeNode,
DAGIR,
LayoutNode,
LoadNode,
StoreNode,
)
from cutlass.backend.evt.passes import (
EVTGraphDrawer,
EVTPassManager,
GetSmemSize,
PassDAG2Tree,
PassGetArgumentType,
PassGetImpl,
PassFixElementD,
PassLayoutManipulateElimination,
PassPreprocessRed,
PassShapeTypePropagation,
)
from cutlass.backend.utils import device_cc
from cutlass.epilogue.evt_ops import permute, reshape
from cutlass.utils.datatypes import library_type
class EVTFrontendBase:
layout_fns = {
"permute": permute,
"reshape": reshape
}
def __init__(self, element_compute=DataType.f32, cc=None, additional_passes=[], **kwargs) -> None:
self.cc = cc if cc else device_cc()
self.element_compute = library_type(element_compute)
self.dag_ir = DAGIR(self.element_compute, self.cc)
self.compute_cnt = 0
self.layout_cnt = 0
self.pass_manager = EVTPassManager(
self.dag_ir,
[
PassPreprocessRed,
PassGetArgumentType,
PassShapeTypePropagation,
PassLayoutManipulateElimination,
PassGetImpl,
PassDAG2Tree,
PassFixElementD
] + additional_passes)
if self.cc == 80:
self._epilogue_stages = 1
else:
self._epilogue_stages = None
@property
def epilogue_stages(self):
return self._epilogue_stages
@epilogue_stages.setter
def epilogue_stages(self, stages):
self._epilogue_stages = stages
def parse(self, *args, **kwargs):
        raise NotImplementedError("The 'parse' method must be overridden by the frontend class")
def trace(self, *args, **kwargs):
# Parse the input
self.parse(*args, **kwargs)
# Run the passes
self.pass_manager()
# Set the epilogue type
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
if self.cc == 90:
self.arg_c_type = self.dag_ir.arg_c_type
self.arg_d_type = self.dag_ir.arg_d_type
self.reduction_names = self.dag_ir.reduction_names
#
# Helper functions for DAG IR manipulation
#
def add_node(self, node):
self.dag_ir.add_node(node)
def add_edge(self, src, tgt, weight=0):
self.dag_ir.add_edge(src, tgt, weight=weight)
def set_tensor(self, node_name, example):
"""
Add an example tensor to node {node_name} in the DAG IR
"""
meta = self.dag_ir.get_node_meta(node_name)
meta.tensor = {"tensor": example}
def set_store_tensor(self, node_name, example):
"""
        Add an example output (store) tensor to node {node_name} in the DAG IR
"""
meta = self.dag_ir.get_node_meta(node_name)
meta.store_tensor = {"tensor": example}
def mark_output(self, node_name):
"""
Mark a store node as output
"""
meta = self.dag_ir.get_node_meta(node_name)
if not isinstance(meta, StoreNode):
raise ValueError(
f"Only StoreNodes can be marked as output. "
f"Got {type(meta).__name__}: {node_name}")
meta.is_output = True
# Add node with specific type
def add_load_node(self, name, example):
"""
Add a Load node to DAG IR
:param name: name of the loaded variable
:type name: str
:param example: example input
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
"""
if name is None:
            raise ValueError("Name is not provided.")
if example is None:
raise ValueError(f"Example input for {name} is not provided.")
load_node = LoadNode(name)
load_node.tensor = {"tensor": example}
        # Special logic for the accumulator
if name == "accum":
if load_node.tensor.rank == 2:
new_shape = tuple([1, ] + list(load_node.tensor.shape))
load_node.tensor.broadcast(new_shape)
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
                raise ValueError(f"Expected the example input for 'accum' to be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
self.add_node(load_node)
def add_imm(self, value: Union[float,int]):
"""
Add an immediate scalar value to DAG IR
:param value: the value of the immediate scalar
:type value: float
"""
try:
value = float(value)
        except (TypeError, ValueError):
            raise ValueError(f"{type(value).__name__} cannot be converted to float.") from None
name = f"imm_{value}".replace('.', '_')
load_node = LoadNode(name)
load_node.tensor = {"tensor": value, "is_constant": True}
self.add_node(load_node)
return name
def add_compute_node(self, op, name=None):
"""
Add a compute node.
:param op: the computation op
:param name: the node name (optional)
:type name: str
:return: the name of the compute node
"""
if name is None:
name = f"compute_{self.compute_cnt}"
self.compute_cnt += 1
compute_node = ComputeNode(
name=name, fn=op,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
return compute_node.name
def add_layout_node(self, op, kwargs, name=None):
"""
Add a layout node.
:param op: the layout op
:type op: evt_ops
:param name: the node name (optional)
:type name: str
:return: the name of the layout node
"""
if name is None:
name = f"layout_{self.layout_cnt}"
self.layout_cnt += 1
layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
self.add_node(layout_node)
return layout_node.name
def add_store_node(self, name):
store_node = StoreNode(name)
self.add_node(store_node)
#
    # Visualizing the DAG IR
#
def visualize(self, name="dag_ir"):
"""
        Visualize the DAG IR as an SVG file
:param name: the name of the graph
"""
drawer = EVTGraphDrawer(self.dag_ir, name)
if drawer.dot_available:
for name, graph in drawer.get_dot_graph():
graph.write_svg(f"./{name}.svg")
else:
raise RuntimeError(
"'dot' is not found in path. GraphDrawer is disabled. "
"Please install it with 'sudo apt-get install graphviz'."
)
#
# Get shared memory size
#
def get_smem_size(self, tile_description):
"""
Get the shared memory size of the epilogue
"""
smem_size = GetSmemSize(self.dag_ir)(tile_description)
return smem_size
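The helper methods above (`add_load_node`, `add_compute_node`, `add_edge`, `mark_output`) record the epilogue as a weighted DAG that the passes later lower. A toy illustration of the same bookkeeping for `D = accum + bias`, using a plain dict-based graph rather than the real `DAGIR` (illustrative only):

```python
class ToyDAG:
    """Dict-based stand-in for DAGIR, tracking nodes, weighted edges, outputs."""

    def __init__(self):
        self.nodes = {}      # name -> kind
        self.edges = []      # (src, dst, weight); weights order compute inputs
        self.outputs = set()

    def add_node(self, name, kind):
        self.nodes[name] = kind

    def add_edge(self, src, dst, weight=0):
        self.edges.append((src, dst, weight))

    def mark_output(self, name):
        if self.nodes.get(name) != "store":
            raise ValueError("Only store nodes can be marked as output.")
        self.outputs.add(name)


# D = accum + bias, mirroring what the frontend would record
dag = ToyDAG()
dag.add_node("accum", "load")
dag.add_node("bias", "load")
dag.add_node("compute_0", "compute")   # the '+' op
dag.add_node("D", "store")
dag.add_edge("accum", "compute_0", weight=0)
dag.add_edge("bias", "compute_0", weight=1)
dag.add_edge("compute_0", "D")
dag.mark_output("D")
```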


@ -0,0 +1,184 @@
# (BSD-3-Clause license header, identical to the one above)
"""
Python AST frontend that parses input into DAG IR
"""
import ast
import inspect
import textwrap
import cutlass
from cutlass import DataType
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass.backend.epilogue import relu
from cutlass.backend.library import FunctionalOp
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
def __init__(self, element_compute=DataType.f32, **kwargs):
super().__init__(element_compute, **kwargs)
# Flags
        # If this flag is True, visit_Constant returns the raw value without creating an imm load node
self.no_imm = False
self.visiting_return = False
def parse(self, example_inputs):
self.example_inputs = example_inputs
self.source = textwrap.dedent(inspect.getsource(self.__call__))
self.ast = ast.parse(self.source)
self.visit(self.ast)
#
# Helper functions
#
@staticmethod
def ast_op_to_bindings(op):
mapping = {
ast.Add: FunctionalOp.Plus,
ast.Sub: FunctionalOp.Minus,
ast.Mult: FunctionalOp.Multiplies,
ast.Div: FunctionalOp.Divides,
"relu": relu.binding_type,
"multiply_add": FunctionalOp.MultiplyAdd,
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum)
}
return mapping[op]
#
# Visiting different node types
#
def visit_FunctionDef(self, node: ast.FunctionDef):
# Visit args and register load nodes
for arg in node.args.args:
self.visit(arg)
for expr in node.body:
self.visit(expr)
def visit_arg(self, node: ast.arg):
# Name of the argument
name = node.arg
try:
example_tensor = self.example_inputs[name]
        except KeyError:
            raise RuntimeError(f"Example input for {name} is not provided.") from None
self.add_load_node(name, example_tensor)
def visit_Name(self, node: ast.Name):
return node.id
def visit_Constant(self, node: ast.Constant):
if self.no_imm:
return node.value
else:
name = self.add_imm(node.value)
return name
def visit_Tuple(self, node: ast.Tuple):
results = []
for elt in node.elts:
results.append(self.visit(elt))
return tuple(results)
def visit_keyword(self, node: ast.keyword):
return {node.arg: self.visit(node.value)}
def visit_BinOp(self, node: ast.BinOp):
if self.visiting_return:
raise SyntaxError("Return value cannot be an expression")
lhs = self.visit(node.left)
rhs = self.visit(node.right)
op = self.ast_op_to_bindings(type(node.op))
name = self.add_compute_node(op)
# Add edges
# The edge weights are used to sort the input args
self.add_edge(lhs, name, weight=0)
self.add_edge(rhs, name, weight=1)
return name
    def visit_Assign(self, node: ast.Assign):
target = self.visit(node.targets[0])
value = self.visit(node.value)
# Create the assign node
self.add_store_node(target)
# Add edges
self.add_edge(value, target)
return target
def visit_Call(self, node: ast.Call):
if self.visiting_return:
raise SyntaxError("Return value cannot be an expression")
func = self.visit(node.func)
args = [self.visit(arg) for arg in node.args]
if func in self.layout_fns.keys():
# Parse kwargs
# By default, visiting imm automatically creates a load node
# However, in function call, keyword args are used to set
# specific function attributes such as indices for permute
# So no_imm is set to True temporarily
self.no_imm = True
kwargs = {}
for kw in node.keywords:
kwargs.update(self.visit(kw))
self.no_imm = False
op = self.layout_fns[func]
name = self.add_layout_node(op, kwargs)
else:
op = self.ast_op_to_bindings(func)
name = self.add_compute_node(op)
# Add edges
for idx, arg in enumerate(args):
self.add_edge(arg, name, weight=idx)
return name
def visit_Return(self, node: ast.Return):
self.visiting_return = True
results = self.visit(node.value)
self.visiting_return = False
self.return_names = results
if not isinstance(results, tuple):
results = (results,)
for rst in results:
try:
example_tensor = self.example_inputs[rst]
except KeyError:
raise RuntimeError(f"Example input for {rst} is not provided.")
self.set_store_tensor(rst, example_tensor)
self.mark_output(rst)
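The frontend above is built on `ast.NodeVisitor`, where Python's `ast` module dispatches a `visit_<NodeType>` method per node. A minimal standalone sketch (stdlib only, no CUTLASS types; `BinOpCounter` is a hypothetical stand-in) of the same dispatch pattern on an epilogue expression:

```python
import ast

class BinOpCounter(ast.NodeVisitor):
    """Records binary operators in preorder, mimicking how the EVT
    frontend's visit_BinOp recurses into node.left and node.right."""
    def __init__(self):
        self.ops = []

    def visit_BinOp(self, node):
        # Record the operator type, then descend into both operands
        self.ops.append(type(node.op).__name__)
        self.generic_visit(node)

tree = ast.parse("D = alpha * acc + beta * C")
counter = BinOpCounter()
counter.visit(tree)
print(counter.ops)  # ['Add', 'Mult', 'Mult']
```

The preorder visit (operator before operands) matches how the frontend creates a compute node first and then wires edges from its visited operands.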


@@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass.backend.evt.ir.dag_ir import DAGIR
from cutlass.backend.evt.ir.layout_nodes import LayoutNode
from cutlass.backend.evt.ir.load_nodes import (
LoadNode,
AccumulatorImpl,
LoadSrcImpl,
AuxLoadImpl,
RowBroadcastImpl,
ColumnBroadcastImpl,
ScalarBroadcastImpl
)
from cutlass.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass.backend.evt.ir.store_nodes import (
StoreNode,
StoreDImpl,
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)


@@ -0,0 +1,91 @@
################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
"""
Python registration for compute nodes in EVT
"""
from cutlass.backend.evt.ir.node import NodeBase, ImplBase
from cutlass.backend.library import FloatRoundStyle
class ComputeImplBase(ImplBase):
"""
Base class for compute implementation
"""
def __init__(self, node) -> None:
super().__init__(node)
class ComputeImpl(ComputeImplBase):
"""
Implementation for Compute Node
"""
def __init__(self, node) -> None:
super().__init__(node)
self.fn = node.fn
self.element_output = node.element_output
self.element_compute = node.element_compute
self.round_style = node.round_style
@staticmethod
def match(node, problem_size: tuple):
return True
class ComputeNode(NodeBase):
"""
Compute Node in DAG IR
"""
possible_impls = [
ComputeImpl
]
def __init__(
self, name: str, fn, element_output,
element_compute,
round_style=FloatRoundStyle.ToNearest) -> None:
super().__init__(name)
self.op = "compute"
self.fn = fn
self.element_compute = element_compute
self.round_style = round_style
def type_propagation(self, *args, **kwargs):
"""
Compute nodes perform the computation in `element_compute` and, by default, produce output of the same type.
"""
self.element = self.element_compute
# In general, the compute nodes have element_output = element_compute
# In certain cases like producer of D it is overwritten by other passes
if not hasattr(self, "element_output"):
self.element_output = self.element


@@ -0,0 +1,235 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
DAG IR used by Python EVT
"""
import networkx as nx
from cutlass import DataType
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.utils import device_cc
class DAGIR:
"""
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
It consists of a series of ``Node`` s, each representing an epilogue visitor node.
In the DAG IR, a ``node`` is identified by its name (a string); ``node_meta`` is the underlying node class instance.
"""
def __init__(self, element_compute=DataType.f32, cc: int=None) -> None:
# The EVT DAG IR is managed through the networkx DiGraph class
self._graph = nx.DiGraph()
self.element_compute = element_compute
self.reduction_names = []
self.cc = cc if cc else device_cc()
#
# IR manipulator
#
def add_node(self, meta: NodeBase):
"""
Add a node to dag ir
"""
if self.has_node(meta.name):
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
self._graph.add_node(meta.name, meta=meta)
def add_edge(self, src: str, dst: str, weight: int=0):
"""
Add an edge src -> dst to dag ir with weight
"""
if not self.has_node(src):
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
self._graph.add_edge(src, dst, weight=weight)
def remove_node(self, node: str):
"""
Remove node from dag ir
"""
self._graph.remove_node(node)
def remove_edge(self, src: str, dst: str):
"""
Remove edge src -> dst
"""
self._graph.remove_edge(src, dst)
#
# Helper functions for getting attrs
#
def has_node(self, node: str) -> bool:
"""
Check if the node is in the graph
"""
return self._graph.has_node(node)
def in_degree(self, node: str):
"""
Get the input degree of node
"""
return self._graph.in_degree(node)
def in_edges(self, node: str):
"""
Get the input edges of node
"""
return [edge for edge in self._graph.in_edges(node)]
def out_degree(self, node: str):
"""
Get the output degree of node
"""
return self._graph.out_degree(node)
def out_edges(self, node: str):
"""
Get the output edges of node
"""
return [edge for edge in self._graph.out_edges(node)]
def get_node_meta(self, node: str):
"""
Get the meta data of the node
"""
return self._graph.nodes[node]["meta"]
def get_edge_weight(self, src, dst):
"""
Get the edge weight of edge src->dst
"""
return self._graph.get_edge_data(src, dst)["weight"]
#
# High-level helper functions
#
def all_reachable_nodes(self, node: str):
"""
Get all the nodes reachable from the current node (the returned list includes the node itself)
"""
return list(nx.dfs_preorder_nodes(self._graph, source=node))
def get_users(self, node: str):
"""
Get all users of the current node
"""
return [edge[1] for edge in self.out_edges(node)]
def get_all_inputs(self, node: str):
"""
Get all the input nodes sorted by edge weight
"""
in_edges = self.in_edges(node)
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
def get_all_inputs_meta(self, node: str):
"""
Get all the input node metas sorted by edge weight
"""
return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]
def replace_all_uses_with(self, node1, node2):
"""
Replace all uses of node1 with node2
"""
for edge in self.out_edges(node1):
weight = self.get_edge_weight(*edge)
user = edge[1]
self.add_edge(node2, user, weight)
self.remove_edge(node1, user)
self.remove_node(node1)
#
# Node accessor
#
def nodes_topological_order(self):
"""
Get the nodes in the unique lexicographical topological order
It generates a unique ordering of nodes by first sorting topologically
and then additionally by sorting lexicographically.
Although topological_sort alone also works, this generates a unique key
for each epilogue visitor pattern and ensures the compilation cache can be reused.
:return: list[str]
"""
return list(nx.lexicographical_topological_sort(self._graph))
def node_metas_topological_order(self):
"""
Get the node metas in topological order
:return: list[NodeBase]
"""
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
@property
def nodes(self):
"""
Get all nodes
:return: list[str]
"""
return list(self._graph.nodes)
@property
def nodes_meta(self):
"""
Get all node metas
:return: list[NodeBase]
"""
return [data[1]['meta'] for data in self._graph.nodes.data()]
@property
def edges(self):
"""
Get all edges
:return: list[(str, str)]
"""
return list(self._graph.edges)
#
# Path
#
def has_path(self, src: str, target: str) -> bool:
"""
Return True if a path exists from src to target
"""
return nx.has_path(self._graph, src, target)
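`DAGIR` is a thin wrapper over `networkx.DiGraph`. The two properties it leans on, a deterministic lexicographical topological order and weight-sorted inputs, can be sketched directly against networkx (assumed installed, as the module itself requires it):

```python
import networkx as nx

# Mirror the structure DAGIR builds for D = alpha * acc + beta * C:
# load nodes feed compute nodes; edge weights order the operands.
g = nx.DiGraph()
g.add_edge("alpha", "mul0", weight=0); g.add_edge("acc", "mul0", weight=1)
g.add_edge("beta", "mul1", weight=0); g.add_edge("C", "mul1", weight=1)
g.add_edge("mul0", "add", weight=0); g.add_edge("mul1", "add", weight=1)
g.add_edge("add", "D", weight=0)

# Lexicographical topological sort is deterministic, so the node order
# can serve as a stable compilation-cache key (cf. nodes_topological_order)
order = list(nx.lexicographical_topological_sort(g))
assert order.index("mul0") < order.index("add") < order.index("D")

# Inputs of a node, sorted by edge weight (cf. get_all_inputs)
in_edges = g.in_edges("add")
weights = [g.get_edge_data(*e)["weight"] for e in in_edges]
inputs = [e[0] for _, e in sorted(zip(weights, in_edges))]
print(inputs)  # ['mul0', 'mul1']
```

A plain `topological_sort` would also produce a valid schedule, but its order depends on insertion order, which would defeat cache reuse across otherwise identical epilogue patterns.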


@@ -0,0 +1,324 @@
################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
"""
Layout algebras
"""
from pycute import Layout, composition, make_layout, flatten, product
def _infer_split(old_shape, new_shape):
old_shape = _tuple_to_list(old_shape)
new_shape = _tuple_to_list(new_shape)
if len(old_shape) == 0 and len(new_shape) == 0:
return []
if len(old_shape) == 0:
if product(tuple(new_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return new_shape
if len(new_shape) == 0:
if product(tuple(old_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return old_shape
# This is done recursively, processing only the last dimension at each step
old_dim = old_shape[-1]
new_dim = new_shape[-1]
# Exact match
if old_dim == new_dim:
return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
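As a worked example, a standalone re-implementation of the recursion above (hypothetical helper `infer_split`, stdlib only) shows how a common flat shape is inferred for two shapes of the same 24-element tensor:

```python
from math import prod

def infer_split(old, new):
    """Recursive last-dimension matching, as in _infer_split above."""
    if not old and not new:
        return []
    if not old:
        assert prod(new) == 1, "Invalid reshape size"
        return list(new)
    if not new:
        assert prod(old) == 1, "Invalid reshape size"
        return list(old)
    o, n = old[-1], new[-1]
    if o == n:                    # exact match
        return infer_split(old[:-1], new[:-1]) + [n]
    if o > n and o % n == 0:      # split the old dimension
        return infer_split(old[:-1] + [o // n], new[:-1]) + [n]
    if o < n and n % o == 0:      # merge into the new dimension
        return infer_split(old[:-1], new[:-1] + [n // o]) + [o]
    raise NotImplementedError(f"Unsupported split: {old} -> {new}")

# (6, 4) and (2, 12) share the flat split (2, 3, 4):
# (6, 4) regroups as ((2*3), 4) while (2, 12) regroups as (2, (3*4)).
print(infer_split([6, 4], [2, 12]))  # [2, 3, 4]
```

The resulting flat shape is exactly what `_infer_merge` then regroups back into each of the two nested shapes.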
def _infer_merge(flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = 0
merged_shape = []
for dim in shape:
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat += 1
# Need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while residual > 1:
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat += 1
merged_shape.append(group)
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape
def _list_to_tuple(nested_list):
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
return tuple(_list_to_tuple(item) for item in nested_list)
return nested_list
def _tuple_to_list(nested_tuple):
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
return list(_tuple_to_list(item) for item in nested_tuple)
return nested_tuple
def _reverse_tuple(nested_tuple: tuple):
if isinstance(nested_tuple, tuple):
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
return nested_tuple
def _get_first_lhs_nonzero_stride(stride_list, idx):
for i in reversed(range(idx)):
if stride_list[i] != 0:
return i
return None
def _get_first_rhs_nonzero_stride(stride_list, idx):
for i in range(idx+1, len(stride_list)):
if stride_list[i] != 0:
return i
return None
def reshape(layout, new_shape):
"""
General reshape of input layout.
It takes two steps:
1. split the dimensions of the old layout
2. merge the splitted dimensions according to the new shape
"""
#
# Step 1: Split the dimensions of the old layout
#
# 1.1 Flatten the old and new shapes
old_flatten_shape = list(flatten(layout.shape))
new_flatten_shape = list(flatten(new_shape))
# 1.2 Infer the flattened split shape
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
# 1.3 Unflatten the split shape based on the old shape
split_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
# 1.4 Infer the type of each split
# If a split is row-major (R), its dimension list is reversed because
# cute::composition only supports column-major splits
split_type = [] # the type of each split (ColumnMajor or RowMajor)
permuted_splitted_shape = []
old_flatten_stride = list(flatten(layout.stride))
for idx, dim in enumerate(split_shape):
if not isinstance(dim, list):
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
# Special case for single tuple
# Use column-major by default
if lhs_stride is None and rhs_stride is None:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
if lhs_stride is not None and rhs_stride is not None:
# We consider shape[idx]:stride[idx]
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
raise NotImplementedError()
elif lhs_stride is None:
# No nonzero stride on the left: decide by the right neighbor
if old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
# No nonzero stride on the right: decide by the left neighbor
if old_flatten_stride[idx] < lhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
# 1.5 Generate the splitted layout
permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
# 1.6 Reverse the permutation applied in 1.4 before merging
splitted_shape = []
splitted_stride = []
for shape_dim, stride_dim, type in zip(
permuted_splitted_layout.shape,
permuted_splitted_layout.stride,
split_type):
if type == "C":
splitted_shape.append(shape_dim)
splitted_stride.append(stride_dim)
else:
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
#
# Step 2: Merge the splitted dimensions according to the new shape
#
# 2.1 Merge layout
merged_layout = composition(splitted_layout, Layout(new_shape))
# 2.2 Cleaning up
output_layout = composition(merged_layout, Layout(new_shape))
return output_layout
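For intuition on why a split-and-merge dance is needed at all: by analogy in numpy (an assumption of this sketch, not a dependency of this module), reshaping a transposed, non-contiguous view cannot preserve strides and forces element reordering. This is the same situation the split/merge steps above resolve symbolically at the layout level:

```python
import numpy as np

# A (2, 3) row-major buffer viewed through its transpose has the
# non-canonical layout shape (3, 2) : stride (1, 3).
x = np.arange(6).reshape(2, 3).T   # strides no longer C-contiguous
flat = x.reshape(6)                # numpy must reorder the elements
print(flat.tolist())  # [0, 3, 1, 4, 2, 5]
```

Where numpy copies data, the layout algebra instead re-expresses the reshape as a composition, so the epilogue can keep addressing the original buffer.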
def permutation(layout, permutation):
"""
Permute the layout
"""
new_shape = tuple([layout.shape[idx] for idx in permutation])
new_stride = tuple([layout.stride[idx] for idx in permutation])
return Layout(new_shape, new_stride)
def _broadcast(layout, new_shape):
if len(layout) == 1 and isinstance(new_shape, int):
old_dim = layout.shape
old_stride = layout.stride
new_dim = new_shape
if old_dim == new_dim:
return Layout(old_dim, old_stride)
elif old_dim == 1:
return Layout(new_dim, 0)
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
# Align the dimensions
old_shape = layout.shape
if isinstance(old_shape, int):
old_shape = (old_shape,)
sub_layouts = [layout,]
else:
sub_layouts = [sub_layout for sub_layout in layout]
rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
# Get the broadcasted layout
broadcast_layouts = []
try:
layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
broadcast_layouts = []
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
except NotImplementedError:
layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
return make_layout(*broadcast_layouts)
def broadcast(layout, new_shape):
"""
Broadcast the input layout to the new shape.
The broadcast shape equals the new shape; each broadcast dimension gets stride 0.
"""
return _broadcast(layout, new_shape)
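The stride-0 convention for broadcast dimensions is the same trick numpy uses internally; a sketch with `numpy.lib.stride_tricks.as_strided` (numpy assumed available) shows every broadcast column aliasing the same memory, mirroring the `Layout(new_dim, 0)` case in `_broadcast`:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

# Broadcast a (4, 1) column to (4, 3): the broadcast dimension gets
# stride 0, so all three columns read the same underlying elements.
col = np.arange(4, dtype=np.int64).reshape(4, 1)
view = as_strided(col, shape=(4, 3), strides=(col.strides[0], 0))
print(view[:, 0].tolist())  # [0, 1, 2, 3]
assert (view[:, 0] == view[:, 2]).all()
```

`debroadcast` below is the inverse operation: it may only drop dimensions whose stride is already 0, since any other stride would discard real addressing information.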
def debroadcast(layout, dims):
"""
Squeeze the 0-stride
"""
for dim in dims:
if layout.stride[dim] != 0:
raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
return Layout(new_shape, new_stride)
def canonicalization_(shapes, strides):
if isinstance(shapes, tuple):
c_shapes = []
c_strides = []
for shape, stride in zip(shapes, strides):
c_shape, c_stride = canonicalization_(shape, stride)
c_shapes.append(c_shape)
c_strides.append(c_stride)
return tuple(c_shapes), tuple(c_strides)
else:
if shapes == 1:
return 1, 0
else:
return shapes, strides
def canonicalization(layout):
"""
Canonicalize the input layout by setting the stride of every extent-1 mode to 0
"""
new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
return Layout(new_shape, new_stride)
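A standalone sketch of the same normalization (hypothetical `canonicalize` helper mirroring `canonicalization_` above, operating on nested shape/stride tuples without pycute):

```python
def canonicalize(shape, stride):
    """Set the stride of any extent-1 mode to 0, recursing into tuples."""
    if isinstance(shape, tuple):
        pairs = [canonicalize(s, d) for s, d in zip(shape, stride)]
        return tuple(p[0] for p in pairs), tuple(p[1] for p in pairs)
    # An extent-1 mode can never move the offset, so its stride is
    # irrelevant; zeroing it gives every layout a unique canonical form.
    return (1, 0) if shape == 1 else (shape, stride)

print(canonicalize((1, (4, 1)), (4, (1, 4))))
# ((1, (4, 1)), (0, (1, 0)))
```

Canonicalizing makes layout comparisons well defined: two layouts that address memory identically also compare equal mode by mode.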


@@ -0,0 +1,336 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout manipulation nodes and implementations
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
"""
from copy import deepcopy
from pycute import product, flatten
import cutlass
from cutlass import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.ir.tensor import Tensor
class PermutationImpl:
"""
Detailed implementation and helper functions for permutation
"""
def __init__(self, node) -> None:
assert "indices" in node.kwargs.keys()
self.indices = list(node.kwargs["indices"])
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.indices = self.inverse_indices
inverse_impl.inverse_indices = self.indices
return inverse_impl
def update(self, shape):
num_dim = len(shape)
indices = self.indices
num_old_dim = len(indices)
# Add offset
for i, idx in enumerate(indices):
indices[i] = idx + num_dim - num_old_dim
# Add broadcast dims
for i in range(num_dim - num_old_dim):
indices = [i,] + indices
self.indices = indices
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_indices(self, indices):
"""
Get the indices for inverse permutation
"""
num_dim = len(indices)
inverse_indices = [0] * num_dim
for i in range(num_dim):
inverse_indices[indices[i]] = i
return inverse_indices
def shape_propagation(self, input_node_meta):
input_shape = input_node_meta.tensor.shape
output_shape = tuple([input_shape[idx] for idx in self.indices])
return output_shape
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape
"""
self.update(shape)
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
node_meta.tensor.broadcast(inverse_shape)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the permutation to the users of the current nodes
"""
usr_meta.tensor.permute(self.inverse_indices)
if hasattr(usr_meta, "store_tensor"):
if usr_meta.store_tensor is not None:
usr_meta.store_tensor.permute(self.inverse_indices)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the permutation to inputs of the current nodes
"""
input_meta.tensor.permute(self.indices)
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.permute(self.indices)
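The index bookkeeping in `get_inverse_indices` can be checked with a small standalone sketch (hypothetical `inverse_indices` helper): applying a permutation and then its inverse restores the original order, which is what lets the impl push a permutation forward to users and backward to inputs consistently.

```python
def inverse_indices(indices):
    """Invert a permutation: axis indices[i] of the output came from
    axis i of the input (cf. get_inverse_indices above)."""
    inv = [0] * len(indices)
    for i, idx in enumerate(indices):
        inv[idx] = i
    return inv

shape = (8, 16, 32)
perm = [2, 0, 1]
permuted = tuple(shape[i] for i in perm)     # (32, 8, 16)
inv = inverse_indices(perm)                  # [1, 2, 0]
restored = tuple(permuted[i] for i in inv)
print(restored)  # (8, 16, 32)
```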
class ReshapeImpl:
"""
Detailed implementation and helper functions for reshape
"""
def __init__(self, node) -> None:
self.node = node
assert "new_shape" in node.kwargs.keys()
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.output_shape = self.input_shape
inverse_impl.input_shape = self.output_shape
return inverse_impl
def shape_propagation(self, input_node_meta):
self.input_shape = input_node_meta.tensor.shape
return _list_to_tuple(self.output_shape)
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape.
"""
# Step 1: infer split
flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
# broadcast shape -> split_output_shape -> flatten_split_shape
if len(shape) - len(split_output_shape) > 0:
for _ in range(len(shape) - len(split_output_shape)):
split_output_shape = [1,] + split_output_shape
flatten_split_shape = [1,] + flatten_split_shape
split_input_shape = [1,] + split_input_shape
broadcast_factor = []
for dim, old_dim in zip(shape, split_output_shape):
if not isinstance(dim, list):
dim = [dim,]
if not isinstance(old_dim, list):
old_dim = [old_dim,]
if product(tuple(dim)) == product(tuple(old_dim)):
broadcast_factor += [1] * len(old_dim)
elif product(tuple(old_dim)) == 1:
assert len(dim) == 1
broadcast_factor.append(dim[0])
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
# flatten_split_shape -> split_input_shape
factor_idx = 0
broadcast_split_input_shape = []
for dim in split_input_shape:
if isinstance(dim, list):
new_dim = []
for d in dim:
new_dim.append(d * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape.append(new_dim)
else:
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
node_meta.tensor.broadcast(broadcast_split_input_shape)
# Last reshape op to clean up
broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
node_meta.tensor.reshape(broadcast_input_shape)
# Update the input shape and output shape
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
self.output_shape = _list_to_tuple(shape)
def apply_to_user(self, user_meta: NodeBase):
"""
Propagate the reshape to user nodes
"""
user_meta.tensor.reshape(tuple(self.input_shape))
if hasattr(user_meta, "store_tensor"):
if user_meta.store_tensor is not None:
user_meta.store_tensor.reshape(tuple(self.input_shape))
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the reshape to input nodes
"""
input_meta.tensor.reshape(tuple(self.output_shape))
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.reshape(tuple(self.output_shape))
#
# Helper functions
#
def infer_split(self, input_shape, output_shape):
"""
Infer the flattened split shape that can be merged into both input_shape and output_shape
"""
input_shape = _tuple_to_list(input_shape)
output_shape = _tuple_to_list(output_shape)
if len(input_shape) == 0 and len(output_shape) == 0:
return []
if len(input_shape) == 0:
if product(tuple(output_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return output_shape
if len(output_shape) == 0:
if product(tuple(input_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return input_shape
# This is done recursively, processing only the last dimension at each step
old_dim = input_shape[-1]
new_dim = output_shape[-1]
# Exact match
if old_dim == new_dim:
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
def infer_merge(self, flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = len(flatten_shape) - 1
merged_shape = []
for dim in reversed(shape):
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat -= 1
# need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while residual > 1:
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat -= 1
merged_shape.append(group[::-1])
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape[::-1]
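The recursive shape factoring performed by `infer_split` can be sketched as a standalone function (a hypothetical free-function version of the method above, not part of the CUTLASS API), which makes the invariant explicit: the returned factor list has the same product as both shapes.

```python
def infer_split(input_shape, output_shape):
    # Recursively factor the trailing dimensions until both shapes
    # flatten to a common list of factors (the product is preserved).
    input_shape, output_shape = list(input_shape), list(output_shape)
    if not input_shape and not output_shape:
        return []
    if not input_shape or not output_shape:
        remaining = input_shape or output_shape
        size = 1
        for dim in remaining:
            size *= dim
        if size != 1:
            raise ValueError("Invalid reshape size")
        return remaining
    old_dim, new_dim = input_shape[-1], output_shape[-1]
    if old_dim == new_dim:  # exact match
        return infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim]
    if old_dim > new_dim and old_dim % new_dim == 0:  # split old_dim
        return infer_split(input_shape[:-1] + [old_dim // new_dim],
                           output_shape[:-1]) + [new_dim]
    if new_dim > old_dim and new_dim % old_dim == 0:  # merge into new_dim
        return infer_split(input_shape[:-1],
                           output_shape[:-1] + [new_dim // old_dim]) + [old_dim]
    raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
```

For example, reshaping (4, 6) to (8, 3) factors to [4, 2, 3], which groups as (4, 2·3) on one side and (4·2, 3) on the other.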
class LayoutNode(NodeBase):
"""
Layout manipulation nodes
"""
fn_to_impl = {
"permute": PermutationImpl,
"reshape": ReshapeImpl
}
def __init__(self, name: str, fn, kwargs: dict) -> None:
super().__init__(name)
self.op = "layout"
self.fn = fn
self.kwargs = kwargs
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
def get_inverse_node(self):
inverse_node = deepcopy(self)
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
return inverse_node
def shape_propagation(self, input_node_metas):
if self._tensor is not None:
return
assert len(input_node_metas) == 1, "Layout node can only have one input node"
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
self._tensor = Tensor(
element=self.element_output,
shape=output_shape, layout_tag=LayoutType.RowMajor
)
return super().shape_propagation(input_node_metas)
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Layout nodes have element_output = element_input
"""
assert len(input_node_metas) == 1, "Layout node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Propagate the broadcast in the reversed topological order
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
shape = self.tensor.shape
for child in input_node_metas:
self.underlying_impl.broadcast(shape, child)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the permutation to user nodes
"""
self.underlying_impl.apply_to_user(usr_meta)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the permutation to input nodes
"""
self.underlying_impl.apply_to_input(input_meta)


@ -0,0 +1,294 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Load nodes and implementations
"""
import ctypes
from cutlass.backend.c_types import tuple_factory
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass.backend.evt.ir.node import NodeBase, ImplBase
class LoadImplBase(ImplBase):
"""
Base class for load node implementations
"""
reserved_names = ["accum", "C"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.tensor.stride
class AccumulatorImpl(LoadImplBase):
"""
Accumulator node implementation
"""
@staticmethod
def match(node, problem_size: tuple):
return node.name == "accum" and node.tensor.shape == problem_size
class LoadSrcImpl(LoadImplBase):
"""
Load C implementation
"""
@property
def name_camel(self) -> str:
return "TensorC"
@property
def argument_type_c(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_C", ctypes.c_void_p),
("stride_C", tuple_type)
]
def __init__(self, ptr) -> None:
self.ptr_C = ptr
self.stride_C = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
return node.name == "C" and node.tensor.shape == problem_size
class AuxLoadImpl(LoadImplBase):
"""
Load arbitrary tensor
"""
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dAux", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.null_default = to_ctype_value(0, element_type)
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if (strideMN[0] == 1 and strideMN[1] != 0 or
strideMN[0] != 0 and strideMN[1] == 1 ):
return True
else:
return False
class RowBroadcastImpl(LoadImplBase):
"""
Broadcast a row vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_row", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dRow", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_row = ptr
self.null_default = to_ctype_value(0, element_type)
self.dRow = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ColumnBroadcastImpl(LoadImplBase):
"""
Broadcast a column vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_col", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dCol", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_col = int(ptr)
self.null_default = to_ctype_value(0, element_type)
self.dCol = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class ScalarBroadcastImpl(LoadImplBase):
"""
Broadcast a scalar
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
if self.tensor.is_constant:
value = self.tensor.value
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type)
]
def __init__(self, kwargs) -> None:
self.scalars = to_ctype_value(value, element_type)
self.scalar_ptrs = 0
self.dScalar = tuple_type(stride_mnl)
else:
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type)
]
def __init__(self, kwargs) -> None:
scalar_or_ptr = kwargs[name]
if isinstance(scalar_or_ptr, float):
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
self.scalar_ptrs = 0
else:
self.scalar_ptrs = int(scalar_or_ptr)
self.dScalar = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class LoadNode(NodeBase):
"""
Load Node
"""
cnt = 0
possible_impls = [
AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
RowBroadcastImpl, ColumnBroadcastImpl,
ScalarBroadcastImpl
]
def __init__(self, name: str) -> None:
if name is None:
name = f"load{LoadNode.cnt}"
LoadNode.cnt += 1
super().__init__(name)
self.op = "load"
def type_propagation(self, *args, **kwargs):
"""
A load node reads the tensor as type `tensor.element` and returns an array of the same type.
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
self.element = self.tensor.element
self.element_output = self.tensor.element
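The `match` predicates of the load implementations above amount to classifying a load by the last two (M, N) entries of its stride. A minimal sketch of that dispatch (a hypothetical helper, not part of the interface; the predicates are mutually exclusive, so the check order does not matter):

```python
def classify_load(stride_mn):
    # Mirrors the match() predicates above: (0, 0) is a scalar broadcast,
    # (0, 1) a row broadcast, (1, 0) a column broadcast, and a unit stride
    # paired with a non-zero companion is an arbitrary aux load.
    m, n = stride_mn
    if (m, n) == (0, 0):
        return "scalar_broadcast"
    if (m, n) == (0, 1):
        return "row_broadcast"
    if (m, n) == (1, 0):
        return "column_broadcast"
    if (m == 1 and n != 0) or (m != 0 and n == 1):
        return "aux_load"
    return "unmatched"
```

For instance, a row-major auxiliary tensor with stride (64, 1) is an aux load, while the degenerate strides map to the broadcast variants.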


@ -0,0 +1,292 @@
"""
Base & visitor classes of DAGIR Nodes
"""
import ctypes
from re import sub
from cutlass import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass.backend.evt.ir.tensor import Tensor
class ImplBase:
"""
Base class for Node Implementation
"""
def __init__(self, node) -> None:
self.node = node
self.name = node.name
self.tensor = node.tensor
self._type_decl = None
self.stride_dtype = "int64_t"
@staticmethod
def match(node, problem_size: tuple):
"""
Match function used in get_underlying_impl
"""
raise NotImplementedError("The `match` function is not defined.")
@property
def argument_type(self):
"""
Default class for Argument Type
"""
class _Argument(ctypes.Structure):
_fields_ = []
def __init__(self, *args, **kwargs) -> None:
pass
return _Argument
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
def _emit_cute_tuple(self, py_tuple):
"""
Emit the cute tuple to C++ code
"""
if isinstance(py_tuple, int):
if py_tuple in [0, 1]:
return f"cute::Int<{py_tuple}>"
else:
return f"{self.stride_dtype}"
elif isinstance(py_tuple, tuple):
decl = "cute::Stride<"
for item in py_tuple:
decl += self._emit_cute_tuple(item) + ", "
return decl[:-2] + ">"
else:
raise ValueError(f"_emit_cute_tuple only accepts tuple or int, got {type(py_tuple).__name__}")
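`_emit_cute_tuple` maps a Python stride tuple to a `cute::Stride` type string: 0 and 1 become compile-time `cute::Int<>` types, any other integer becomes the runtime stride dtype, and tuples nest. A standalone sketch of the same recursion (hypothetical free function, not the class method):

```python
def emit_cute_tuple(py_tuple, stride_dtype="int64_t"):
    # 0 and 1 become compile-time cute::Int<>; other ints use the runtime
    # stride dtype; tuples nest into cute::Stride<...>.
    if isinstance(py_tuple, int):
        return f"cute::Int<{py_tuple}>" if py_tuple in (0, 1) else stride_dtype
    if isinstance(py_tuple, tuple):
        return ("cute::Stride<"
                + ", ".join(emit_cute_tuple(t, stride_dtype) for t in py_tuple)
                + ">")
    raise ValueError(f"emit_cute_tuple only accepts tuple or int, got {type(py_tuple).__name__}")
```

For example, the stride (1, 0, (64, 1)) emits `cute::Stride<cute::Int<1>, cute::Int<0>, cute::Stride<int64_t, cute::Int<1>>>`.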
@property
def stride_mnl(self):
"""
Typename StrideMNL
"""
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
return self._emit_cute_tuple(stride)
def get_non_constant_stride(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple not in [0, 1]:
return py_tuple
else:
return None
non_constant_stride = []
for item in py_tuple:
item_out = self.get_non_constant_stride(item)
if item_out:
non_constant_stride.append(item_out)
return tuple(non_constant_stride)
def get_stride_mnl(self):
"""
Get the non-zero stride mnl. This is used in argument construction
"""
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
return stride
def get_smem_size(self, *args, **kwargs):
"""
Get the shared memory size and alignment of current node
"""
return (0, 1)
class NoOpImpl(ImplBase):
"""
The NoOpImpl does nothing but forward its input to users
"""
def __init__(self, node) -> None:
super().__init__(node)
@staticmethod
def match(node, problem_size: tuple):
    if node.op == "store":
        # A store node that is not an output is a no-op
        return not node.is_output
    return False
class NodeBase:
"""
Base class of DAG Node
"""
def __init__(self, name: str) -> None:
self.name = name
self.underlying_impl = None
self._tensor = None
# Whether the node is disabled for emit
self.disabled = False
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return self.underlying_impl.name_camel
@property
def tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
"""
return self._tensor
@tensor.setter
def tensor(self, kwargs):
"""
Setting the tensor
"""
self._tensor = Tensor(**kwargs)
#
# Helper functions for type/shape propagation
#
def shape_propagation(self, input_node_metas):
"""
Infer shape from input nodes
General Broadcasting Rules from NumPy
When operating on two arrays, we compare their shapes element-wise.
It starts with the trailing (i.e. rightmost) dimension and works its
way left. Two dimensions are compatible when
1. they are equal
2. one of them is 1
"""
if self._tensor is not None:
return
shape = None
for src in input_node_metas:
src_shape = src.tensor.shape
if shape is None:
shape = src_shape
else:
len_difference = len(shape) - len(src_shape)
if len_difference > 0:
for _ in range(len_difference):
src_shape = [1, ] + list(src_shape)
elif len_difference < 0:
for _ in range(-len_difference):
shape = [1, ] + list(shape)
broadcasted_shape = []
# Infer broadcast shape
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
if shape_dim == 1:
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
elif src_dim == 1:
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
elif shape_dim == src_dim:
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
else:
error_msg = "Dimension mismatch between "
for src_ in input_node_metas:
error_msg += f"{src_.name}{src_.tensor.shape}, "
error_msg = error_msg[:-2] + "."
raise RuntimeError(error_msg)
shape = tuple(broadcasted_shape)
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
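The docstring above follows NumPy's general broadcasting rules; the shape inference can be sketched standalone (a hypothetical helper equivalent to the loop above, not part of the interface):

```python
def broadcast_shapes(*shapes):
    # Pad shorter shapes with leading 1s, then reconcile dimensions
    # right-to-left: dimensions are compatible when equal or when one is 1.
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        non_unit = {d for d in dims if d != 1}
        if len(non_unit) > 1:
            raise RuntimeError(f"Dimension mismatch: {shapes}")
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(result)
```

For example, broadcasting (4, 1) against (3, 4, 5) pads the first shape to (1, 4, 1) and yields (3, 4, 5).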
def type_propagation(self, *args, **kwargs):
"""
Each node is associated with two data types: `element` and `element_output`.
The `element_output` is the type of return array of the node. The `element`
has specific meaning for different node types.
* Load Node: data type of tensor in gmem
* Compute Node: element compute
* Store Node: data type of tensor in gmem
This function must be overloaded in the derived classes
"""
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Propagate the broadcast in the reversed topological order.
For example:
C[l, m, n] = A[m, 1] + B[l, m, n]
After the broadcast propagation, it becomes
C[l, m, n] = A[l, m, n] + B[l, m, n]
and each tensor will have a proper stride accessing the underlying tensor
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
for child in input_node_metas:
child.tensor.broadcast(self.tensor.shape)
def get_underlying_impl(self, problem_size: tuple):
"""
Get the underlying implementation of the current node.
"""
if self.tensor is None:
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
for impl in self.possible_impls:
if impl.match(self, problem_size):
self.underlying_impl = impl(self)
break
if self.underlying_impl is None:
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
#
# Visitor Nodes & Impls
#
class TopoVisitorImpl(ImplBase):
"""
Impl for topological visitor
"""
def __init__(self, node) -> None:
super().__init__(node.output_node)
self.name = node.name
self.element_output = node.output_node.element_output
class TopoVisitorNode(NodeBase):
def __init__(self, name: str, subgraph, output_node) -> None:
super().__init__(name)
self.subgraph = subgraph
self.output_node = output_node
self.op = "dag"
self.underlying_impl = TopoVisitorImpl(self)


@ -0,0 +1,276 @@
"""
Store node and implementations
"""
import ctypes
from cutlass import DataType
from cutlass.backend.c_types import tuple_factory
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass.backend.evt.ir.tensor import Tensor
from cutlass.backend.library import FloatRoundStyle, FunctionalOp
class StoreImplBase(ImplBase):
"""
Base class for store node implementation
"""
reserved_names = ["D"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.store_tensor.stride
class StoreDImpl(StoreImplBase):
"""
Store D implementation
"""
@property
def argument_type_d(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_D", ctypes.c_void_p),
("stride_D", tuple_type)
]
def __init__(self, ptr: int) -> None:
self.ptr_D = ptr
self.stride_D = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name == "D" and node.store_tensor.shape == problem_size:
return True
return False
class AuxStoreImpl(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.round_style = FloatRoundStyle.ToNearest
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("dAux", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if (strideMN[0] == 1 and strideMN[1] != 0 or
strideMN[0] != 0 and strideMN[1] == 1 ):
return True
else:
return False
class ReductionImplBase(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.store_tensor.element
self.element_compute = node.element_compute
self.reg_reduce_fn = self.node.reg_reduce_fn
self.gmem_reduce_fn = self.node.gmem_reduce_fn
self.round_style = node.round_style
self.stride_dtype = "int"
def get_reduce_identity(self):
"""
Return the reduction identity of the current reduce_fn
"""
maxes = {
DataType.f32: (2 ** 31) - 1,
DataType.f16: (2 ** 15),
DataType.s32: (2 ** 31) - 1,
DataType.s8: (2 ** 7) - 1
}
mins = {
DataType.f32: -maxes[DataType.f32],
DataType.f16: -maxes[DataType.f16],
DataType.s32: -maxes[DataType.s32],
DataType.s8: -maxes[DataType.s8]
}
if self.reg_reduce_fn == FunctionalOp.Maximum:
if self.element_compute not in mins:
raise Exception(f"No min entry for data type {self.element_compute}")
return to_ctype_value(mins[self.element_compute], self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
return to_ctype_value(1., self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Minimum:
if self.element_compute not in maxes:
raise Exception(f"No max entry for data type {self.element_compute}")
return to_ctype_value(maxes[self.element_compute], self.element_compute)
else:
return to_ctype_value(0., self.element_compute)
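`get_reduce_identity` picks the neutral element of the register-level reduction so that inactive lanes do not perturb the result. The mapping can be sketched with float sentinels instead of the per-dtype min/max tables above (hypothetical helper keyed by name, not the `FunctionalOp` enum):

```python
def reduce_identity(reduce_fn):
    # Neutral elements: max -> -inf, min -> +inf, product -> 1, sum -> 0,
    # so reducing the identity with any value x returns x unchanged.
    identities = {
        "maximum": float("-inf"),
        "minimum": float("inf"),
        "multiplies": 1.0,
        "plus": 0.0,
    }
    return identities.get(reduce_fn, 0.0)
```

The production code bounds the sentinels per data type (e.g. `2**15` for f16) because `-inf` is not representable in every element type.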
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_compute = self.element_compute
reduce_identity = self.get_reduce_identity()
class _Argument(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("reduce_identity", dtype2ctype[element_compute]),
("dMNL", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr = ptr
self.reduce_identity = reduce_identity
self.dMNL = tuple_type(stride_mnl)
return _Argument
class ColumnReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class RowReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ScalarReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class StoreNode(NodeBase):
"""
Store node
"""
possible_impls = [
AuxStoreImpl, RowReductionImpl,
ColumnReductionImpl, ScalarReductionImpl,
NoOpImpl, StoreDImpl
]
def __init__(self, name: str) -> None:
super().__init__(name)
self.op = "store"
self.is_output = False
self._store_tensor = None
@property
def store_tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
"""
return self._store_tensor
@store_tensor.setter
def store_tensor(self, kwargs):
"""
Setting the tensor
"""
self._store_tensor = Tensor(**kwargs)
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Store nodes have element_output = element_input
"""
if self.is_output:
if self.store_tensor is None:
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
self.element = self.store_tensor.element
assert len(input_node_metas) == 1, "Store node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
super().broadcast_propagation(input_node_metas)
if self.is_output:
self._store_tensor.broadcast(self.tensor.shape)


@ -0,0 +1,130 @@
"""
High-level class for tensor
"""
from cutlass import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
permutation,
reshape,
_reverse_tuple
)
from cutlass.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
class Tensor:
"""
A tensor abstraction carrying the element data type, shape, and stride layout
"""
def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None:
if element is not None and tensor is not None:
    raise Exception("Must not specify both element and tensor")
elif shape is not None and tensor is not None:
    raise Exception("Must not specify both shape and tensor")
elif layout_tag is not None and tensor is not None:
    raise Exception("Must not specify both layout_tag and tensor")
elif (element is None or layout_tag is None or shape is None) and tensor is None:
    raise Exception("Must specify either (element, shape, layout_tag) or tensor")
if isinstance(tensor, Tensor):
# Directly copy all the attributes
self.__dict__.update(vars(tensor))
else:
if tensor is None:
self.element = library_type(element)
else:
self.element, layout_tag = get_datatype_and_layout(tensor)
shape = get_tensor_shape(tensor)
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
self.layout = canonicalization(self.layout)
self.is_constant = is_constant
# Save the tensor value if it is constant
if is_constant and tensor is not None:
self.value = tensor
@property
def shape(self):
"""
Returns the RowMajor layout shape
"""
return _reverse_tuple(self.layout.shape)
@property
def stride(self):
"""
Returns the RowMajor layout stride
"""
return _reverse_tuple(self.layout.stride)
@property
def rank(self):
"""
Returns the rank of the tensor
"""
return len(self.shape)
#
# Layout Algorithms
#
def broadcast(self, shape):
"""
Broadcast self.layout to shape
"""
assert isinstance(shape, tuple)
self.layout = broadcast(self.layout, _reverse_tuple(shape))
def reshape(self, shape):
"""
Reshape self.layout to shape
"""
assert isinstance(shape, tuple)
reverse_shape = _reverse_tuple(shape)
self.layout = reshape(self.layout, reverse_shape)
def permute(self, indices):
"""
Permute self.layout according to indices
"""
length = len(indices)
indices = [length - idx - 1 for idx in indices]
self.layout = permutation(self.layout, indices[::-1])
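Because `Tensor` stores its layout with the dimensions reversed (the `shape` property reverses them back), `permute` mirrors the user-facing indices before applying them. The equivalence with a plain shape transpose can be sketched as follows (a hypothetical helper, assuming `permutation` gathers dimensions the way `numpy.transpose` does):

```python
def permuted_shape_via_reversed_layout(shape, indices):
    # Mirror the user-facing indices into the reversed-layout domain,
    # gather the reversed shape, then reverse back to user-facing order.
    length = len(indices)
    rev_indices = [length - idx - 1 for idx in indices][::-1]
    rev_shape = shape[::-1]
    return tuple(rev_shape[i] for i in rev_indices)[::-1]
```

The result matches gathering the shape directly, i.e. `tuple(shape[i] for i in indices)`, which is why the double reversal in `permute` is safe.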


@ -0,0 +1,42 @@
from cutlass.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass.backend.evt.passes.pass_manager import EVTPassManager
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass.backend.evt.passes.smem_size_calculator import GetSmemSize


@ -0,0 +1,158 @@
import subprocess
import pydot
from cutlass import DataTypeTag
from cutlass.backend.evt.ir.dag_ir import DAGIR
_COLOR_MAP = {
"load": '"AliceBlue"',
"compute": "LemonChiffon1",
"accumulator": "LightGrey",
"store": "PowderBlue",
"layout": "lightseagreen",
"dag": "darkorange"
}
class EVTGraphDrawer:
"""
Visualize an EVT DAGIR with graphviz
"""
def __init__(
self,
graph: DAGIR,
name: str
):
self._name = name
self._dot_graphs = {}
self._dot_graphs[name] = self._to_dot(graph, name)
self.dot_available = self._check_dot_availability()
def _check_dot_availability(self):
"""
Check if graphviz is installed
"""
try:
# Run the 'dot' command and capture its output
result = subprocess.run(
["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Check if the command was successful and the output contains version information
if result.returncode == 0 and "dot - graphviz" in result.stderr:
return True
except FileNotFoundError:
pass
return False
def _get_node_style(self, node):
template = {
"shape": "record",
"fillcolor": "#CAFFE3",
"style": '"filled,rounded"',
"fontcolor": "#000000",
}
if node.op in _COLOR_MAP:
template["fillcolor"] = _COLOR_MAP[node.op]
else:
raise NotImplementedError("unknown node op")
if node.disabled:
template["fontcolor"] = "grey"
template["fillcolor"] = "white"
return template
def _get_node_label(self, node):
label = "{" + f"name={node.name}|op={node.op}"
if node.op == "layout":
label += f"|fn={node.fn.__name__}"
for key in node.kwargs:
label += f"|{key}={node.kwargs[key]}"
if node.underlying_impl is not None:
label += f"|impl={type(node.underlying_impl).__name__}"
if node.op == "load":
label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
elif node.op == "compute":
label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
elif node.op == "store":
label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
elif node.op == "dag":
label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
if node.tensor is not None:
shape = node.tensor.shape
stride = node.tensor.stride
label += f"|shape={shape}|stride={stride}"
if hasattr(node, "store_tensor"):
if node.store_tensor is not None:
store_shape = node.store_tensor.shape
store_stride = node.store_tensor.stride
label += f"|store_shape={store_shape}|store_stride={store_stride}"
label += "}"
return label
def _to_dot(
self,
graph: DAGIR,
name: str
):
dot_graph = pydot.Dot(name, rankdir="TB")
for node in graph.nodes_meta:
style = self._get_node_style(node)
label = self._get_node_label(node)
dot_node = pydot.Node(
node.name, label=label, **style
)
dot_graph.add_node(dot_node)
if node.op == "dag":
dot_subgraph = self._to_dot(node.subgraph, name=node.name)
self._dot_graphs[node.name] = dot_subgraph
# Add edges
for src, dst in graph.edges:
weight = graph.get_edge_weight(src, dst)
dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
return dot_graph
def get_dot_graph(self) -> list:
return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]
def get_dot_graph_by_name(self, name) -> pydot.Dot:
return self._dot_graphs[name]
def get_main_dot_graph(self) -> pydot.Dot:
return self._dot_graphs[self._name]
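The `dot` probe above is self-contained; a minimal sketch of the same availability check (stdlib only, `check_dot_availability` is a hypothetical free-function variant, not part of the class):

```python
import subprocess

def check_dot_availability() -> bool:
    """Return True if the Graphviz 'dot' binary is on PATH and answers -V."""
    try:
        result = subprocess.run(
            ["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        # Graphviz prints its version banner to stderr, e.g. "dot - graphviz version 2.43"
        return result.returncode == 0 and "graphviz" in result.stderr
    except FileNotFoundError:
        # 'dot' is not installed; callers can fall back to emitting raw .dot text
        return False
```

If `dot` is unavailable, the drawer can still build the `pydot` graphs; only rendering to an image requires Graphviz.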


@ -0,0 +1,116 @@
"""
Construct the epilogue visitor argument type
"""
from cutlass.backend.c_types import visitor_factory
from cutlass.backend.evt.ir import TopoVisitorNode
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassGetArgumentType(EVTPassBase):
"""
Construct the epilogue visitor argument type
"""
dependencies = [
PassShapeTypePropagation, # The Layout of all nodes must be set
PassDAG2Tree, # The DAG subgraphs must be set
PassGetImpl # The underlying impl of each node must be set
]
def requires(self) -> None:
# Check "D" is in the node list
if self.cc == 90 and (not self.dag_ir.has_node("D")):
raise SyntaxError(
"Sm90 EVT requires the epilogue to have a returned tensor D, "
"but the variable 'D' is not found in the return values.")
def call(self):
nodes = self.dag_ir.nodes_topological_order()
self.argument_types = {}
for node in nodes:
meta = self.dag_ir.get_node_meta(node)
if not meta.disabled:
self.argument_types[node] = meta.underlying_impl.argument_type
if node == "D" and self.cc == 90:
continue
if isinstance(meta, TopoVisitorNode):
self.get_dag_argument_type(node)
else:
self.get_evt_argument_type(node)
self.cc_specific_method(self.set_argument_type)()
def get_evt_argument_type(self, node):
# Sort the input nodes by edge weight
input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
if len(input_types) > 0:
self.argument_types[node] = visitor_factory(
input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])
def get_dag_argument_type(self, node):
meta = self.dag_ir.get_node_meta(node)
subgraph = meta.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Visit the unvisited nodes in subgraph
for n in subgraph_nodes:
m = subgraph.get_node_meta(n)
if m.disabled:
continue
else:
self.argument_types[n] = m.underlying_impl.argument_type
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
if len(input_types) > 0:
self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])
def set_argument_type(self):
pass
def sm90_set_argument_type(self):
self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
# Get the tensorD argument type
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
# Get the tensorC argument type
if self.dag_ir.has_node("C"):
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
else:
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
def sm80_set_argument_type(self):
nodes = self.dag_ir.nodes_topological_order()
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]


@ -0,0 +1,147 @@
"""
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
"""
from copy import deepcopy
from cutlass.backend.evt.ir import DAGIR, TopoVisitorNode
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassDAG2Tree(EVTPassBase):
"""
Convert the DAG IR to Tree by fusing subgraphs
"""
dependencies = [
PassShapeTypePropagation,
PassGetImpl
]
def call(self):
# Step 1: find the nodes that have multiple parents
multi_parent_nodes = []
for node in self.dag_ir.nodes_topological_order():
if self.dag_ir.out_degree(node) > 1:
multi_parent_nodes.append(node)
# Step 2: find the lowest common ancestor (LCA) of all its parents
for node in multi_parent_nodes:
# A multi-parent node could be already fused by the previous node
if not self.dag_ir.has_node(node):
continue
# A node uncovered by the previous fusions can have out degree change
# Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
# Case 2: it has more than one edges to the previously fused subgraph, degree drops
if self.dag_ir.out_degree(node) <= 1:
continue
# Otherwise, the node still has out-degree > 1 and needs to be fused
reachable_nodes = []
# Complexity: O(Dout*N)
for parent in self.dag_ir.get_users(node):
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
# get the common reachable objects
common_items = set.intersection(*reachable_nodes)
# If common ancestor exists, find the lowest one
if len(common_items) > 0:
topo_order = self.dag_ir.nodes_topological_order()
lca = None
topo_idx = -1
for item in common_items:
if lca is None:
lca = item
topo_idx = topo_order.index(item)
else:
if topo_idx > topo_order.index(item):
lca = item
topo_idx = topo_order.index(item)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR()
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}", subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca))
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
else:
raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.")
def ensures(self) -> None:
# Ensure that after the pass, the resulting DAG becomes a tree
for node in self.dag_ir.nodes:
out_degree = self.dag_ir.out_degree(node)
if out_degree > 1:
raise RuntimeError(f"PassDAG2Tree failed. Node {node} still has out-degree {out_degree}")
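The LCA search in `call` (Step 2) intersects the reachable sets of all users of a multi-parent node and picks the earliest member in topological order. A minimal sketch of that selection on a plain successor-dict DAG (hypothetical helpers, not the `DAGIR` API):

```python
from graphlib import TopologicalSorter

def reachable(succ: dict, start: str) -> set:
    """All nodes reachable from start (excluding start itself)."""
    seen, stack = set(), [start]
    while stack:
        for v in succ.get(stack.pop(), []):
            if v not in seen:
                seen.add(v)
                stack.append(v)
    return seen

def topo_order(succ: dict) -> list:
    # TopologicalSorter expects {node: predecessors}; invert the successor map
    preds = {n: [] for n in succ}
    for u, vs in succ.items():
        for v in vs:
            preds[v].append(u)
    return list(TopologicalSorter(preds).static_order())

def lowest_common_ancestor(succ: dict, node: str) -> str:
    """Earliest node, in topological order, reachable from every user of `node`."""
    common = set.intersection(*[reachable(succ, u) for u in succ[node]])
    order = topo_order(succ)
    return min(common, key=order.index)

# 'a' has two users; both reach {'d', 'e'}, and 'd' comes first topologically
succ = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": ["e"], "e": []}
lca = lowest_common_ancestor(succ, "a")  # "d"
```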


"""
Fix the element_output of the producer of D.
In the Sm90 epilogue visitor, the node writing D to gmem does not have an internal
element converter, so the compute node producing D must have element_output = type(D).
"""
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass.backend.evt.passes.pass_manager import EVTPassBase


class PassFixElementD(EVTPassBase):
    """
    In the Sm90 epilogue visitor, the node writing D to gmem does not have an internal
    element converter, so the compute node producing D must have
    element_output = type(D)
    """
    dependencies = [
        PassLayoutManipulateElimination
    ]

    def get_producer(self, node, element_D):
        node_meta = self.dag_ir.get_node_meta(node)
        if node_meta.op == "compute":
            node_meta.element_output = element_D
        elif node_meta.op == "store":
            self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)

    def call(self):
        if self.dag_ir.has_node("D"):
            node_d_meta = self.dag_ir.get_node_meta("D")
            element_D = node_d_meta.store_tensor.element
            self.get_producer("D", element_D)


@ -0,0 +1,89 @@
"""
Infer the underlying implementation of each node.
While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have a different underlying implementation based
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying impl of each node
"""
import cutlass.backend.evt.backend as evt_backend
from cutlass.backend.evt.ir import DAGIR, LoadNode
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassGetImpl(EVTPassBase):
"""
While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have a different underlying implementation based
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying impl of each node
"""
dependencies = [
PassShapeTypePropagation, # The shape and type info are required for inference
PassFixElementD
]
def __init__(self, dag_ir: DAGIR) -> None:
super().__init__(dag_ir)
self.no_op_elimination = PassNoOpElimination(dag_ir)
def requires(self) -> None:
# Verify "accum" is in the arg list
if not self.dag_ir.has_node("accum"):
raise SyntaxError("Cannot find 'accum' in the argument list.")
def call(self):
# The loop structure of the epilogue is determined by the
# accumulator shape
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
problem_size = accumulator.tensor.shape
for node_meta in self.dag_ir.node_metas_topological_order():
node_meta.get_underlying_impl(problem_size)
def ensures(self) -> None:
# Some nodes will be lowered to NoOp, eliminate them
self.no_op_elimination()
# Lower to cc-specific impl
for node_meta in self.dag_ir.nodes_meta:
node_impl_ccs = getattr(evt_backend, f"sm{self.cc}_nodes")
node_meta.underlying_impl = getattr(
node_impl_ccs,
f"Sm{self.cc}" + node_meta.underlying_impl.__class__.__name__
)(node_meta)
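The cc-specific lowering in `ensures` resolves backend classes purely by naming convention (`Sm{cc}` prepended to the class name). A minimal sketch of that dispatch with hypothetical stand-in classes (the real ones live in `cutlass.backend.evt.backend`):

```python
class RowBroadcast:
    """Architecture-agnostic impl chosen by get_underlying_impl (stand-in)."""

class Sm80RowBroadcast:
    def __init__(self, node_meta):  # hypothetical Sm80 counterpart
        self.node_meta = node_meta

class Sm90RowBroadcast:
    def __init__(self, node_meta):  # hypothetical Sm90 counterpart
        self.node_meta = node_meta

def lower_impl(impl, cc: int, namespace: dict):
    """Resolve the cc-specific class by prepending 'Sm{cc}' to the impl's class name."""
    cls = namespace[f"Sm{cc}" + impl.__class__.__name__]
    return cls(impl)

lowered = lower_impl(RowBroadcast(), 90, globals())  # an Sm90RowBroadcast
```

The same string-based lookup is what `getattr(node_impl_ccs, f"Sm{self.cc}" + ...)` performs against the backend module.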


@ -0,0 +1,217 @@
"""
Eliminate layout manipulation nodes
"""
from copy import deepcopy
from cutlass.backend.evt.ir import DAGIR, LayoutNode
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassLayoutManipulateElimination(EVTPassBase):
"""
Eliminate layout manipulation nodes
"""
dependencies = [PassShapeTypePropagation]
def __init__(self, dag_ir: DAGIR) -> None:
super().__init__(dag_ir)
self.copy_cnt = 0
def call(self):
self.layout_nodes_worklist = self.get_all_layout_nodes()
# Run the while loop until all layout nodes are eliminated
while(len(self.layout_nodes_worklist) > 0):
node = self.layout_nodes_worklist.pop(0)
# for node in layout_nodes:
# Step 1: get the propagation direction
direction = self.get_propagation_direction(node)
self.visited = []
getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
# Eliminate the current node
input_node = self.dag_ir.get_all_inputs(node)[0]
self.dag_ir.replace_all_uses_with(node, input_node)
# layout_nodes = self.get_all_layout_nodes()
def get_all_layout_nodes(self):
layout_nodes = []
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
if isinstance(node_meta, LayoutNode):
layout_nodes.append(node_meta.name)
return layout_nodes
def get_propagation_direction(self, node: str):
"""
The logic propagates all layout nodes away from the accumulator node.
"""
self.visited = []
self.get_influenced_users(node)
nodes_influenced_dir_users = self.visited
self.visited = []
self.get_influenced_inputs(node)
nodes_influenced_dir_inputs = self.visited
if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
return "inputs"
elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
return "users"
else:
raise RuntimeError("Unresolved propagation direction")
# Get all influenced nodes if we propagate along the user direction
def get_influenced_users(self, node: str):
if node in self.visited:
return
self.visited.append(node)
users = self.dag_ir.get_users(node)
for user in users:
self.get_influenced_users(user)
user_inputs = []
for user in users:
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
if len(user_inputs) > 0:
user_inputs = set.union(*user_inputs)
user_inputs.remove(node)
for input in user_inputs:
self.get_influenced_inputs(input)
# Get all influenced nodes if we propagate along the input direction
def get_influenced_inputs(self, node: str):
if node in self.visited:
return
self.visited.append(node)
inputs = self.dag_ir.get_all_inputs(node)
for input in inputs:
self.get_influenced_inputs(input)
input_users = []
for input in inputs:
input_users.append(set(self.dag_ir.get_users(input)))
if len(input_users) > 0:
input_users = set.union(*input_users)
input_users.remove(node)
for user in input_users:
self.get_influenced_users(user)
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
copied_node_meta = deepcopy(layout_node_meta)
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
self.copy_cnt += 1
copied_node_meta.name = copied_node
self.dag_ir.add_node(copied_node_meta)
# Add edges
target_inputs = self.dag_ir.get_all_inputs(target)
for src in target_inputs:
self.dag_ir.remove_edge(src, target)
self.dag_ir.add_edge(src, copied_node)
self.dag_ir.add_edge(copied_node, target)
self.layout_nodes_worklist.append(copied_node)
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
copied_node_meta = deepcopy(layout_node_meta)
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
self.copy_cnt += 1
copied_node_meta.name = copied_node
self.dag_ir.add_node(copied_node_meta)
# Add edges
users = self.dag_ir.get_users(target)
for user in users:
self.dag_ir.remove_edge(target, user)
self.dag_ir.add_edge(copied_node, user)
self.dag_ir.add_edge(target, copied_node)
self.layout_nodes_worklist.append(copied_node)
# Propagate the layout `node` along the user direction
def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
"""
Propagate layout node to users
"""
if node in self.visited:
# Avoid applying twice
return
self.visited.append(node)
node_meta = self.dag_ir.get_node_meta(node)
if layout_node_meta.name != node:
if isinstance(node_meta, LayoutNode):
# A layout node is not transparent to another layout node
self.add_copy_before(layout_node_meta, node)
return
else:
layout_node_meta.apply_to_user(node_meta)
users = self.dag_ir.get_users(node)
user_inputs = []
for user in users:
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
for user in users:
self.propagate_to_users(layout_node_meta, user)
if len(user_inputs) > 0:
user_inputs = set.union(*user_inputs)
user_inputs.remove(node)
for input in user_inputs:
self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
# Propagate the layout `node` along the input direction
def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
"""
Propagate layout node to inputs
"""
if node in self.visited:
# Avoid applying twice
return
self.visited.append(node)
node_meta = self.dag_ir.get_node_meta(node)
if layout_node_meta.name != node:
if isinstance(node_meta, LayoutNode):
# A layout node is not transparent to another layout node
self.add_copy_after(layout_node_meta, node)
return
else:
layout_node_meta.apply_to_input(node_meta)
inputs = self.dag_ir.get_all_inputs(node)
input_users = []
for input in inputs:
input_users.append(set(self.dag_ir.get_users(input)))
for input in inputs:
self.propagate_to_inputs(layout_node_meta, input)
if len(input_users) > 0:
input_users = set.union(*input_users)
input_users.remove(node)
for user in input_users:
self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
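`get_propagation_direction` pushes each layout node away from the accumulator: if `accum` would be hit by propagating toward users, the node is pushed toward inputs, and vice versa. A minimal sketch of that decision on successor/predecessor dicts (hypothetical helpers, not the `DAGIR` API; the real pass also crosses over to sibling inputs/users, while this sketch keeps only the core reachability test):

```python
def walk(adj: dict, start: str) -> set:
    """Nodes reachable from start by following the given adjacency map."""
    seen, stack = set(), [start]
    while stack:
        for v in adj.get(stack.pop(), []):
            if v not in seen:
                seen.add(v)
                stack.append(v)
    return seen

def propagation_direction(succ: dict, pred: dict, layout_node: str) -> str:
    """Return the direction that moves the layout node away from 'accum'."""
    if "accum" in walk(succ, layout_node):   # accum is downstream
        return "inputs"
    if "accum" in walk(pred, layout_node):   # accum is upstream
        return "users"
    raise RuntimeError("Unresolved propagation direction")

# accum -> permute -> store: the permute must move toward its users
succ = {"accum": ["permute"], "permute": ["store"], "store": []}
pred = {"store": ["permute"], "permute": ["accum"], "accum": []}
direction = propagation_direction(succ, pred, "permute")  # "users"
```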


@ -0,0 +1,163 @@
"""
Pass manager for DAG IR.
"""
from typing import Any
import networkx as nx
from cutlass.backend.evt.ir import DAGIR
class EVTPassBase:
"""
Base class for EVT Passes
"""
dependencies = []
def __init__(self, dag_ir: DAGIR) -> None:
self.dag_ir = dag_ir
self.cc = self.dag_ir.cc
def requires(self) -> None:
"""
This function will be called before the pass is run.
"""
pass
def call(self) -> None:
"""
The pass that is run through the self.dag_ir
"""
raise NotImplementedError(
f"__call__ is not overwritten in Pass {self.__class__.__name__}")
def ensures(self) -> None:
"""
This function will be called after the pass is run.
"""
pass
def __call__(self) -> Any:
self.requires()
self.call()
self.ensures()
def cc_specific_method(self, func):
"""
This enables defining function that behaves differently under different cc
The simplest example of using this function is the following
.. highlight:: python
.. code-block:: python
class ExamplePass(EVTPassBase):
def call(sekf):
# This automatically select the smXX_func based on current cc
self.cc_specific_method(self.func)()
# Interface func, can be empty
def func(self):
pass
# Sm90 specific func
def sm90_func(self):
// sm90 specific method
return
# Sm80 specific func
def sm80_func(self):
// sm80 specific method
return
"""
func_name = f"sm{self.cc}_{func.__name__}"
if hasattr(self, func_name):
return getattr(self, func_name)
else:
raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
class EVTPassManager(nx.DiGraph):
"""
Topological-based Pass Manager.
Each registered pass has a list of dependencies. The pass manager organizes
the passes as a DAG and launch the compiler passes under topological order.
"""
def __init__(self, dag_ir: DAGIR, pass_list):
super().__init__()
self.dag_ir = dag_ir
for pass_cls in pass_list:
self.add_pass(pass_cls)
self.sorted_passes = self.schedule()
def get_callable(self, pass_name):
"""
Return the callable of the pass
"""
return self.nodes[pass_name]["callable"]
def add_pass(self, pass_cls):
"""
Add a pass to the pass manager
:param pass_cls: the class of pass
:type pass_cls: derived class of EVTPassBase
"""
name = pass_cls.__name__
pass_callable = pass_cls(self.dag_ir)
self.add_node(name, callable=pass_callable)
def schedule(self):
"""
Schedule the added passes under topological order
"""
# Add edges
for pass_name in self.nodes:
callable = self.get_callable(pass_name)
for dependency_cls in callable.dependencies:
self.add_edge(
dependency_cls.__name__,
type(callable).__name__)
# Topological sort
return list(nx.topological_sort(self))
def __call__(self) -> Any:
"""
Launch the registered passes
"""
for pass_name in self.sorted_passes:
callable = self.get_callable(pass_name)
callable()

View File

@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
No op elimination node
"""
from typing import Any
from cutlass.backend.evt.ir import NoOpImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
class PassNoOpElimination(EVTPassBase):
"""
The dead node elimination pass removes nodes with NoOpImpl in DAG IR
"""
dependencies = []
def call(self) -> Any:
for node in self.dag_ir.nodes_topological_order():
node_meta = self.dag_ir.get_node_meta(node)
if isinstance(node_meta.underlying_impl, NoOpImpl):
self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])

View File

@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Preprocess the reduction nodes.
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
This pass fuses these into a single store node, and then replaces all uses of the
current node with the new store node.
"""
from cutlass.backend.evt.ir import ComputeNode, StoreNode
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
class PassPreprocessRed(EVTPassBase):
"""
Preprocess red nodes
"""
def call(self):
# Step 1: find the compute nodes with op=red
red_compute_nodes = []
for node_meta in self.dag_ir.nodes_meta:
if isinstance(node_meta, ComputeNode):
if type(node_meta.fn) == tuple:
# To keep the frontend simple, the reduction nodes
# are parsed into compute nodes by default
# The simple heuristic to distinguish between compute
# and reduction node is that compute node is a single function,
# while the reduction node is a tuple of functions for
# in-register reduction and atomic global memory reduction
red_compute_nodes.append(node_meta.name)
# Step 2: for each compute, merge it with the succeeding store
for node in red_compute_nodes:
# Verify
users = self.dag_ir.get_users(node)
inputs = self.dag_ir.get_all_inputs(node)
# Has a single user
assert len(users) == 1
assert len(inputs) == 1
user = users[0]
input = inputs[0]
user_meta = self.dag_ir.get_node_meta(user)
# Must be a store node
assert isinstance(user_meta, StoreNode)
# With output degree == 0
assert self.dag_ir.out_degree(user) == 0
# Register the reduce op
node_meta = self.dag_ir.get_node_meta(node)
user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
user_meta.element_compute = node_meta.element_compute
user_meta.round_style = node_meta.round_style
# Replace all uses
self.dag_ir.remove_edge(input, node)
input_users = self.dag_ir.get_users(input)
for iu in input_users:
weight = self.dag_ir.get_edge_weight(input, iu)
self.dag_ir.add_edge(user, iu, weight)
self.dag_ir.remove_edge(input, iu)
self.dag_ir.add_edge(input, user)
self.dag_ir.remove_node(node)
# Register the reduction name
self.dag_ir.reduction_names.append(user)

View File

@ -0,0 +1,59 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Shape and type propagation pass
"""
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
class PassShapeTypePropagation(EVTPassBase):
"""
Propagate the shape and type of all nodes
"""
dependencies = [PassPreprocessRed]
def call(self):
# Propagate the node shape and type
for node in self.dag_ir.nodes_topological_order():
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
node_meta.type_propagation(input_node_metas)
node_meta.shape_propagation(input_node_metas)
for node in reversed(self.dag_ir.nodes_topological_order()):
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
node_meta.broadcast_propagation(input_node_metas)

View File

@ -0,0 +1,200 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Compute the shared memory size in bytes
"""
from pycute import shape_div, product
import cutlass
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass.backend.library import DataTypeSize
class GetSmemSize:
"""
Get the size in byte of shared memory used by the kernel
"""
def __init__(self, dag_ir: DAGIR) -> None:
self.dag_ir = dag_ir
self.cc = self.dag_ir.cc
#
# Sm90 epilogue specific
#
def sm90_epilogue_tile(self, tile_description):
# Get the epilogue tile size
schedule = tile_description.epilogue_schedule
if schedule == cutlass.EpilogueScheduleType.TmaWarpSpecialized:
epilogue_tile_mn = (64, 32)
elif schedule == cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative:
epilogue_tile_mn = (128, 32)
else:
raise NotImplementedError(f"Unsupported schedule: {schedule}")
# Get the pipeline stages
stages_d = 2
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
if self.dag_ir.has_node("C"):
element_c = self.dag_ir.get_node_meta("C").element
else:
element_c = None
element_d = self.dag_ir.get_node_meta("D").element
if element_c == element_d:
reuse_smem_c = True
else:
reuse_smem_c = False
stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
# Record the epilogue tile
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
self.epilogue_tile_mn = epilogue_tile_mn
self.epi_tiles = epi_tiles
self.stages_c = stages_c
self.stages_d = stages_d
self.reuse_smem_c = reuse_smem_c
self.element_c = element_c
self.element_d = element_d
self.is_source_supported = element_c is not None
def sm90_epilogue_smem_size(self, tile_description):
"""
Compute the shared memory size of sm90 collective epilogue
"""
self.sm90_epilogue_tile(tile_description)
# Get the Fusion Storage
nodes = self.dag_ir.nodes_topological_order()
self.smem_types = {}
for node in nodes:
meta = self.dag_ir.get_node_meta(node)
if not meta.disabled:
self.smem_types[node] = meta.underlying_impl.get_smem_size(
self.cta_tile_mnk, self.epilogue_tile_mn,
self.stages_c, self.stages_d, self.epi_tiles)
if node == "D":
continue
if isinstance(meta, TopoVisitorNode):
self.get_dag_smem_type(node)
else:
self.get_evt_smem_type(node)
thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
# Get the Tensor Storage
tensors = []
if self.is_source_supported:
smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
tensors.append((smem_C, 128))
else:
tensors.append((0, 1))
if self.reuse_smem_c:
tensors.append((0, 128))
else:
smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
tensors.append((smem_D, 128))
tensors.append((thread_smem_size, 128))
tensor_smem_size = self.get_struct_size(tensors)
# Get pipeline storage size
# sizeof(uint64_t * stages_c * 2), alignment of uint64_t
# 2 is for FullBarrier and EmptyBarrier
pipeline_smem_size = (8 * self.stages_c * 2, 8)
# get SharedStorage size
smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
return smem_size[0]
def __call__(self, tile_description):
return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
#
# Helper functions
#
@staticmethod
def get_visitor_size(members: list, ebo: bool):
"""
Get the size of struct in bytes
"""
offset = 0
max_alignment = 1
if len(members) > 0:
# Get alignment
for _, alignment in members:
max_alignment = max(max_alignment, alignment)
for type_size, _ in members:
if type_size != 0:
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
if type_size == 0 and not ebo:
offset += 1
else:
offset += type_size
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
return (offset, max_alignment)
else:
# Struct size is at least 1
return (1, 1)
def get_struct_size(self, members: list):
"""
Get the size of struct in bytes
"""
return self.get_visitor_size(members, False)
def get_evt_smem_type(self, node):
# Sort the input nodes by edge weight
input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
input_types.append(self.smem_types[node])
if len(input_types) > 1:
ebo = len(input_types) > 4
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
def get_dag_smem_type(self, node):
meta = self.dag_ir.get_node_meta(node)
subgraph = meta.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Visit the unvisited nodes in subgraph
for n in subgraph_nodes:
m = subgraph.get_node_meta(n)
if m.disabled:
continue
else:
self.smem_types[n] = m.underlying_impl.get_smem_size(
self.cta_tile_mnk, self.epilogue_tile_mn,
self.stages_c, self.stages_d, self.epi_tiles)
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
if len(input_types) > 0:
ebo = len(input_types) > 4
self.smem_types[node] = self.get_visitor_size(input_types, ebo)

View File

@ -36,10 +36,12 @@ import numpy as np
from cutlass.backend.memory_manager import device_mem_alloc, todevice
from cutlass.backend.utils.software import CheckPackages
if CheckPackages().check_torch():
torch_available = CheckPackages().check_torch()
if torch_available:
import torch
if CheckPackages().check_cupy():
cupy_available = CheckPackages().check_cupy()
if cupy_available:
import cupy as cp
@ -94,3 +96,19 @@ class CupyFrontend:
@staticmethod
def argument(cupy_ndarray: "cp.ndarray"):
return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))
class TensorFrontend:
"""
Universal Frontend for client-provide tensors
"""
@staticmethod
def argument(tensor, is_output=False):
if isinstance(tensor, np.ndarray):
return NumpyFrontend.argument(tensor, is_output)
elif torch_available and isinstance(tensor, torch.Tensor):
return TorchFrontend.argument(tensor)
elif cupy_available and isinstance(tensor, cp.ndarray):
return CupyFrontend.argument(tensor)
else:
raise NotImplementedError("Unknown Tensor Type")

File diff suppressed because it is too large Load Diff

View File

@ -31,14 +31,21 @@
#################################################################################################
"""
Common data types and string names for them. This file is similar to /tools/library/scripts/library.py,
but uses the Pybind-bound CUTLASS data types as many keys to the dictionary.
Common data types and string names/tags for them
"""
import enum
import cutlass_bindings
from cutlass import EpilogueScheduleType, KernelScheduleType, TileSchedulerType
from cutlass import (
ComplexTransform,
DataType,
DataTypeSize,
EpilogueScheduleType,
KernelScheduleType,
MathOperation,
OpcodeClass,
TileSchedulerType
)
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
@ -58,121 +65,6 @@ except ImportError:
return i
ShortDataTypeNames = {
cutlass_bindings.int32: "i",
cutlass_bindings.float16: "h",
cutlass_bindings.float32: "s",
cutlass_bindings.float64: "d",
cutlass_bindings.dtype.cf32: "c",
cutlass_bindings.dtype.cf64: "z",
}
DataTypeNames = {
cutlass_bindings.dtype.b1: "b1",
cutlass_bindings.dtype.u4: "u4",
cutlass_bindings.dtype.u8: "u8",
cutlass_bindings.dtype.u16: "u16",
cutlass_bindings.dtype.u32: "u32",
cutlass_bindings.dtype.u64: "u64",
cutlass_bindings.dtype.s4: "s4",
cutlass_bindings.int8: "s8",
cutlass_bindings.dtype.s16: "s16",
cutlass_bindings.int32: "s32",
cutlass_bindings.dtype.s64: "s64",
cutlass_bindings.float16: "f16",
cutlass_bindings.bfloat16: "bf16",
cutlass_bindings.float32: "f32",
cutlass_bindings.tfloat32: "tf32",
cutlass_bindings.float64: "f64",
cutlass_bindings.dtype.cf16: "cf16",
cutlass_bindings.dtype.cbf16: "cbf16",
cutlass_bindings.dtype.cf32: "cf32",
cutlass_bindings.dtype.ctf32: "ctf32",
cutlass_bindings.dtype.cf64: "cf64",
cutlass_bindings.dtype.cu4: "cu4",
cutlass_bindings.dtype.cu8: "cu8",
cutlass_bindings.dtype.cu16: "cu16",
cutlass_bindings.dtype.cu32: "cu32",
cutlass_bindings.dtype.cu64: "cu64",
cutlass_bindings.dtype.cs4: "cs4",
cutlass_bindings.dtype.cs8: "cs8",
cutlass_bindings.dtype.cs16: "cs16",
cutlass_bindings.dtype.cs32: "cs32",
cutlass_bindings.dtype.cs64: "cs64",
}
DataTypeTag = {
cutlass_bindings.dtype.b1: "cutlass::uint1b_t",
cutlass_bindings.dtype.u4: "cutlass::uint4b_t",
cutlass_bindings.dtype.u8: "uint8_t",
cutlass_bindings.dtype.u16: "uint16_t",
cutlass_bindings.dtype.u32: "uint32_t",
cutlass_bindings.dtype.u64: "uint64_t",
cutlass_bindings.dtype.s4: "cutlass::int4b_t",
cutlass_bindings.int8: "int8_t",
cutlass_bindings.dtype.s16: "int16_t",
cutlass_bindings.int32: "int32_t",
cutlass_bindings.dtype.s64: "int64_t",
cutlass_bindings.float16: "cutlass::half_t",
cutlass_bindings.bfloat16: "cutlass::bfloat16_t",
cutlass_bindings.float32: "float",
cutlass_bindings.tfloat32: "cutlass::tfloat32_t",
cutlass_bindings.float64: "double",
cutlass_bindings.dtype.cf16: "cutlass::complex<cutlass::half_t>",
cutlass_bindings.dtype.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
cutlass_bindings.dtype.cf32: "cutlass::complex<float>",
cutlass_bindings.dtype.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
cutlass_bindings.dtype.cf64: "cutlass::complex<double>",
cutlass_bindings.dtype.cu4: "cutlass::complex<cutlass::uint4b_t>",
cutlass_bindings.dtype.cu8: "cutlass::complex<cutlass::uint8_t>",
cutlass_bindings.dtype.cu16: "cutlass::complex<cutlass::uint16_t>",
cutlass_bindings.dtype.cu32: "cutlass::complex<cutlass::uint32_t>",
cutlass_bindings.dtype.cu64: "cutlass::complex<cutlass::uint64_t>",
cutlass_bindings.dtype.cs4: "cutlass::complex<cutlass::int4b_t>",
cutlass_bindings.dtype.cs8: "cutlass::complex<cutlass::int8_t>",
cutlass_bindings.dtype.cs16: "cutlass::complex<cutlass::int16_t>",
cutlass_bindings.dtype.cs32: "cutlass::complex<cutlass::int32_t>",
cutlass_bindings.dtype.cs64: "cutlass::complex<cutlass::int64_t>",
}
DataTypeSize = {
cutlass_bindings.dtype.b1: 1,
cutlass_bindings.dtype.u4: 4,
cutlass_bindings.dtype.u8: 8,
cutlass_bindings.dtype.u16: 16,
cutlass_bindings.dtype.u32: 32,
cutlass_bindings.dtype.u64: 64,
cutlass_bindings.dtype.s4: 4,
cutlass_bindings.int8: 8,
cutlass_bindings.dtype.s16: 16,
cutlass_bindings.int32: 32,
cutlass_bindings.dtype.s64: 64,
cutlass_bindings.float16: 16,
cutlass_bindings.bfloat16: 16,
cutlass_bindings.float32: 32,
cutlass_bindings.tfloat32: 32,
cutlass_bindings.float64: 64,
cutlass_bindings.dtype.cf16: 32,
cutlass_bindings.dtype.cbf16: 32,
cutlass_bindings.dtype.cf32: 64,
cutlass_bindings.dtype.ctf32: 32,
cutlass_bindings.dtype.cf64: 128,
cutlass_bindings.dtype.cu4: 8,
cutlass_bindings.dtype.cu8: 16,
cutlass_bindings.dtype.cu16: 32,
cutlass_bindings.dtype.cu32: 64,
cutlass_bindings.dtype.cu64: 128,
cutlass_bindings.dtype.cs4: 8,
cutlass_bindings.dtype.cs8: 16,
cutlass_bindings.dtype.cs16: 32,
cutlass_bindings.dtype.cs32: 64,
cutlass_bindings.dtype.cs64: 128,
}
class DataTypeSizeBytes:
"""
Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
@ -193,193 +85,15 @@ class DataTypeSizeBytes:
bits = DataTypeSize[datatype]
if bits < 8:
raise Exception(
"Data type {} is less than one byte in size.".format(datatype)
f"Data type {datatype} is less than one byte in size."
)
elif bits % 8 != 0:
raise Exception(
"Data type {} is not an integer number of bytes.".format(datatype)
f"Data type datatype is not an integer number of bytes."
)
return bits // 8
ComplexTransformTag = {
cutlass_bindings.complex_transform.none: "cutlass::ComplexTransform::kNone",
cutlass_bindings.complex_transform.conj: "cutlass::ComplexTransform::kConjugate",
}
RealComplexBijection = [
(cutlass_bindings.float16, cutlass_bindings.dtype.cf16),
(cutlass_bindings.float32, cutlass_bindings.dtype.cf32),
(cutlass_bindings.float64, cutlass_bindings.dtype.cf64),
]
def is_complex(data_type):
for r, c in RealComplexBijection:
if data_type == c:
return True
return False
def get_complex_from_real(real_type):
for r, c in RealComplexBijection:
if real_type == r:
return c
return cutlass_bindings.dtype.invalid
def get_real_from_complex(complex_type):
for r, c in RealComplexBijection:
if complex_type == c:
return r
return cutlass_bindings.dtype.invalid
class ComplexMultiplyOp(enum.Enum):
multiply_add = enum_auto()
gaussian = enum_auto()
class MathOperation(enum.Enum):
multiply_add = enum_auto()
multiply_add_saturate = enum_auto()
xor_popc = enum_auto()
multiply_add_fast_bf16 = enum_auto()
multiply_add_fast_f16 = enum_auto()
multiply_add_fast_f32 = enum_auto()
multiply_add_complex_fast_f32 = enum_auto()
multiply_add_complex = enum_auto()
multiply_add_complex_gaussian = enum_auto()
MathOperationNames = {
MathOperation.multiply_add: "multiply_add",
MathOperation.multiply_add_saturate: "multiply_add_saturate",
MathOperation.xor_popc: "xor_popc",
MathOperation.multiply_add_fast_bf16: "multiply_add_fast_bf16",
MathOperation.multiply_add_fast_f16: "multiply_add_fast_f16",
MathOperation.multiply_add_fast_f32: "multiply_add_fast_f32",
MathOperation.multiply_add_complex_fast_f32: "multiply_add_complex_fast_f32",
MathOperation.multiply_add_complex: "multiply_add_complex",
MathOperation.multiply_add_complex_gaussian: "multiply_add_complex_gaussian",
}
MathOperationTag = {
MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32",
MathOperation.multiply_add_complex_fast_f32: "cutlass::arch::OpMultiplyAddComplexFastF32",
MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
}
LayoutTag = {
cutlass_bindings.ColumnMajor: "cutlass::layout::ColumnMajor",
cutlass_bindings.RowMajor: "cutlass::layout::RowMajor",
cutlass_bindings.layout.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
cutlass_bindings.layout.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
cutlass_bindings.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
cutlass_bindings.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
cutlass_bindings.layout.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
cutlass_bindings.layout.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
cutlass_bindings.TensorNHWC: "cutlass::layout::TensorNHWC",
cutlass_bindings.layout.TensorNDHWC: "cutlass::layout::TensorNDHWC",
cutlass_bindings.layout.TensorNCHW: "cutlass::layout::TensorNCHW",
cutlass_bindings.layout.TensorNGHWC: "cutlass::layout::TensorNGHWC",
cutlass_bindings.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
cutlass_bindings.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
cutlass_bindings.layout.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
cutlass_bindings.layout.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
}
TransposedLayout = {
cutlass_bindings.ColumnMajor: cutlass_bindings.RowMajor,
cutlass_bindings.RowMajor: cutlass_bindings.ColumnMajor,
cutlass_bindings.layout.ColumnMajorInterleaved2: cutlass_bindings.layout.RowMajorInterleaved2,
cutlass_bindings.layout.RowMajorInterleaved2: cutlass_bindings.layout.ColumnMajorInterleaved2,
cutlass_bindings.ColumnMajorInterleaved32: cutlass_bindings.RowMajorInterleaved32,
cutlass_bindings.RowMajorInterleaved32: cutlass_bindings.ColumnMajorInterleaved32,
cutlass_bindings.layout.ColumnMajorInterleaved64: cutlass_bindings.layout.RowMajorInterleaved64,
cutlass_bindings.layout.RowMajorInterleaved64: cutlass_bindings.layout.ColumnMajorInterleaved64,
cutlass_bindings.TensorNHWC: cutlass_bindings.TensorNHWC,
}
ShortLayoutTypeNames = {
cutlass_bindings.ColumnMajor: "n",
cutlass_bindings.layout.ColumnMajorInterleaved2: "n2",
cutlass_bindings.ColumnMajorInterleaved32: "n32",
cutlass_bindings.layout.ColumnMajorInterleaved64: "n64",
cutlass_bindings.RowMajor: "t",
cutlass_bindings.layout.RowMajorInterleaved2: "t2",
cutlass_bindings.RowMajorInterleaved32: "t32",
cutlass_bindings.layout.RowMajorInterleaved64: "t64",
cutlass_bindings.TensorNHWC: "nhwc",
cutlass_bindings.layout.TensorNDHWC: "ndhwc",
cutlass_bindings.layout.TensorNCHW: "nchw",
cutlass_bindings.layout.TensorNGHWC: "nghwc",
cutlass_bindings.TensorNC32HW32: "nc32hw32",
cutlass_bindings.layout.TensorNC64HW64: "nc64hw64",
cutlass_bindings.TensorC32RSK32: "c32rsk32",
cutlass_bindings.layout.TensorC64RSK64: "c64rsk64",
}
ShortComplexLayoutNames = {
(cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.none): "n",
(cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.conj): "c",
(cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.none): "t",
(cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.conj): "h",
}
OpcodeClassNames = {
cutlass_bindings.OpClass.Simt: "simt",
cutlass_bindings.OpClass.TensorOp: "tensorop",
cutlass_bindings.OpClass.WmmaTensorOp: "wmma_tensorop",
cutlass_bindings.OpClass.SparseTensorOp: "sptensorop",
}
OpcodeClassTag = {
cutlass_bindings.OpClass.Simt: "cutlass::arch::OpClassSimt",
cutlass_bindings.OpClass.TensorOp: "cutlass::arch::OpClassTensorOp",
cutlass_bindings.OpClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
cutlass_bindings.OpClass.SparseTensorOp: "cutlass::arch::OpClassSparseTensorOp",
}
class OperationKind(enum.Enum):
Gemm = enum_auto()
Conv2d = enum_auto()
Conv3d = enum_auto()
OperationKindNames = {
OperationKind.Gemm: "gemm",
OperationKind.Conv2d: "conv2d",
OperationKind.Conv3d: "conv3d",
}
ArchitectureNames = {
50: "maxwell",
60: "pascal",
61: "pascal",
70: "volta",
75: "turing",
80: "ampere",
90: "hopper",
}
SharedMemPerCC = {
70: 96 << 10, # 96KB of SMEM
72: 96 << 10, # 96KB of SMEM
@ -392,52 +106,8 @@ SharedMemPerCC = {
}
class GemmKind(enum.Enum):
Gemm = enum_auto()
Sparse = enum_auto()
Universal = enum_auto()
PlanarComplex = enum_auto()
PlanarComplexArray = enum_auto()
Grouped = enum_auto()
GemmKindNames = {
GemmKind.Gemm: "gemm",
GemmKind.Sparse: "spgemm",
GemmKind.Universal: "gemm",
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped",
}
class SwizzlingFunctor(enum.Enum):
Identity1 = enum_auto()
Identity2 = enum_auto()
Identity4 = enum_auto()
Identity8 = enum_auto()
Horizontal = enum_auto()
BatchedIdentity1 = enum_auto()
StridedDgradIdentity1 = enum_auto()
StridedDgradIdentity4 = enum_auto()
StridedDgradHorizontal = enum_auto()
SwizzlingFunctorTag = {
cutlass_bindings.IdentitySwizzle1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
SwizzlingFunctor.Horizontal: "cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle",
SwizzlingFunctor.BatchedIdentity1: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.StridedDgradHorizontal: "cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle",
}
class SchedulerMode(enum.Enum):
Device = (enum_auto(),)
Device = enum_auto()
Host = enum_auto()
@ -450,61 +120,98 @@ SchedulerModeTag = {
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
ConvKindTag = {
cutlass_bindings.conv.Operator.fprop: "cutlass::conv::Operator::kFprop",
cutlass_bindings.conv.Operator.dgrad: "cutlass::conv::Operator::kDgrad",
cutlass_bindings.conv.Operator.wgrad: "cutlass::conv::Operator::kWgrad",
class FunctionalOp(enum.Enum):
AtomicAdd = enum_auto()
AtomicMaximum = enum_auto()
Divides = enum_auto()
Maximum = enum_auto()
Minimum = enum_auto()
Minus = enum_auto()
Multiplies = enum_auto()
MultiplyAdd = enum_auto()
Plus = enum_auto()
FunctionalOpTag = {
FunctionalOp.AtomicAdd: "cutlass::atomic_add",
FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
FunctionalOp.Divides: "cutlass::divides",
FunctionalOp.Maximum: "cutlass::maximum",
FunctionalOp.Minimum: "cutlass::minimum",
FunctionalOp.Minus: "cutlass::minus",
FunctionalOp.Multiplies: "cutlass::multiplies",
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
FunctionalOp.Plus: "cutlass::plus",
}
ConvKindNames = {
cutlass_bindings.conv.Operator.fprop: "fprop",
cutlass_bindings.conv.Operator.dgrad: "dgrad",
cutlass_bindings.conv.Operator.wgrad: "wgrad",
class ActivationOp(enum.Enum):
DGelu = enum_auto()
Gelu = enum_auto()
GeluTaylor = enum_auto()
HardSwish = enum_auto()
Identity = enum_auto()
LeakyReLU = enum_auto()
ReLU = enum_auto()
Sigmoid = enum_auto()
SiLU = enum_auto()
Tanh = enum_auto()
ActivationOpTag = {
ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}
IteratorAlgorithmTag = {
cutlass_bindings.conv.IteratorAlgorithm.analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
cutlass_bindings.conv.IteratorAlgorithm.optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "cutlass::conv::IteratorAlgorithm::kFixedChannels",
cutlass_bindings.conv.IteratorAlgorithm.few_channels: "cutlass::conv::IteratorAlgorithm::kFewChannels",
}
def op_tag(op) -> str:
"""
Dispatches `op` to the appropriate *Tag dictionary depending on whether
`op` is an ActivationOp or FunctionalOp. This is useful for cases in which
either type can be used.
:param op: operation to emit a tag for
:type op: ActivationOp | FunctionalOp
:return: tag corresponding to op
:rtype: str
"""
if isinstance(op, ActivationOp):
return ActivationOpTag[op]
elif isinstance(op, FunctionalOp):
return FunctionalOpTag[op]
else:
raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
IteratorAlgorithmNames = {
cutlass_bindings.conv.IteratorAlgorithm.analytic: "analytic",
cutlass_bindings.conv.IteratorAlgorithm.optimized: "optimized",
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "fixed_channels",
cutlass_bindings.conv.IteratorAlgorithm.few_channels: "few_channels",
}
class FloatRoundStyle(enum.Enum):
ToNearest = enum_auto()
ToNearestSatfinite = enum_auto()
Indeterminate = enum_auto()
TowardZero = enum_auto()
TowardInfinity = enum_auto()
TowardNegInfinity = enum_auto()
HalfUlpTruncDntz = enum_auto()
HalfUlpTruncate = enum_auto()
class StrideSupport(enum.Enum):
Strided = enum_auto()
Unity = enum_auto()
StrideSupportTag = {
StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
}
StrideSupportNames = {
StrideSupport.Strided: "",
StrideSupport.Unity: "unity_stride",
}
class ConvMode(enum.Enum):
CrossCorrelation = enum_auto()
Convolution = enum_auto()
ConvModeTag = {
ConvMode.CrossCorrelation: "cutlass::conv::Mode::kCrossCorrelation",
ConvMode.Convolution: "cutlass::conv::Mode::kConvolution",
FloatRoundStyleTag = {
FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
}
@ -519,7 +226,7 @@ class MathInstruction:
element_a,
element_b,
element_accumulator,
opcode_class=cutlass_bindings.OpClass.Simt,
opcode_class=OpcodeClass.Simt,
math_operation=MathOperation.multiply_add,
):
"""
@ -529,7 +236,7 @@ class MathInstruction:
:param element_b: data type of operand B
:param element_accumulator: data type used in accumulation
:param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
:type opcode_class: cutlass_bindings.OpClass
:type opcode_class: cutlass_library.library.OpcodeClass
:param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
:type math_operation: MathOperation
"""
@ -556,7 +263,7 @@ class TileDescription:
cluster_shape=[1, 1, 1],
kernel_schedule: KernelScheduleType = None,
epilogue_schedule: EpilogueScheduleType = None,
tile_scheduler: TileSchedulerType = None,
tile_scheduler: TileSchedulerType = None
):
"""
:param threadblock_shape: shape of a threadblock tyle
@ -610,7 +317,7 @@ class TileDescription:
else:
attrs[key] = getattr(self, key)
mi = MathInstruction(
attrs["math_instruction"] = MathInstruction(
attrs["instruction_shape"],
self.math_instruction.element_a,
self.math_instruction.element_b,
@ -619,11 +326,10 @@ class TileDescription:
self.math_instruction.math_operation
)
return TileDescription(
attrs["threadblock_shape"], attrs["stages"],
attrs["warp_count"], mi, attrs["cluster_shape"],
attrs["kernel_schedule"], attrs["epilogue_schedule"]
)
# Remove the instruction shape
del attrs["instruction_shape"]
return TileDescription(**attrs)
@property
def num_threads(self):
@ -660,6 +366,15 @@ class TileDescription:
return name
def procedural_name_2x(self):
"""
Returns a name identifying the tile description
:return: name identifying the tile description
:rtype: int
"""
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
def __str__(self):
"""
Returns a string with containing each of the tile description's values
@ -695,8 +410,7 @@ class TileDescription:
class TensorDescription:
def __init__(self, element, layout, alignment=1,
complex_transform=cutlass_bindings.complex_transform.none):
def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
self.element = element
self.layout = layout
self.alignment = min(128 // DataTypeSize[self.element], alignment)
@ -751,7 +465,7 @@ class ApiVersion(enum.Enum):
v3x = enum_auto()
def api_version(arch, opclass, datatype):
def api_version(arch, opclass, dtype):
"""
Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
or 3.x for code emission.
@ -759,15 +473,16 @@ def api_version(arch, opclass, datatype):
:param arch: compute capability of device on which to run
:type arch: int
:param opclass: class of the operation being performed
:type opclass: cutlass_bindings.OpClass
:param datatype: data type to be used in operation (assumes that ElementA and ElementB are the same)
:type opclass: cutlass.OpcodeClass
:param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
:type dtype: cutlass.DataType
:return: API version to be used in code emission
:rtype: ApiVersion
"""
if (arch >= 90 and
opclass == cutlass_bindings.OpClass.TensorOp and
(datatype != cutlass_bindings.float64)):
opclass == OpcodeClass.TensorOp and
(dtype != DataType.f64)):
return ApiVersion.v3x
else:
return ApiVersion.v2x

View File

@ -1,877 +0,0 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import ast
import ctypes
import inspect
import textwrap
from typing import Generic, TypeVar
from cuda import cuda, cudart
import numpy as np
from treelib import Tree
from cutlass.backend.epilogue import (
AccumulatorOp,
BinaryOp,
ColumnBroadcastOp,
ColumnReductionOp,
RowBroadcastOp,
RowReductionOp,
TensorInputOp,
TensorOutputOp,
UnaryOp,
)
from cutlass.backend.frontend import NumpyFrontend
from cutlass.backend.utils.software import SubstituteTemplate
import cutlass.backend as backend
################################################################################
# Type annotation for input arguments
################################################################################
Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")
class NDArray(np.ndarray, Generic[Ttype, Dtype]):
pass
################################################################################
# Operations
################################################################################
operators = {
ast.Add: "Add",
ast.Div: "Div",
ast.Eq: "Equal",
ast.Mult: "Mult",
}
################################################################################
# AST Node abstractions
################################################################################
class UnaryNode:
cnt = 0
# Concept: this is created by the BinOp Node in python ast
def __init__(
self,
element_accumulator,
element_compute,
elements_per_access,
node,
args,
) -> None:
if isinstance(node, BinOpNode):
self.op = node.op
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
self.op = node.func.id
elif isinstance(node.func, ast.Attribute):
self.op = node.func.value.id
else:
raise TypeError
else:
raise TypeError
self.tag = "Unary" + self.op + str(UnaryNode.cnt)
self.id = self.op + str(UnaryNode.cnt)
self.args = args
UnaryNode.cnt += 1
self.type = "tensor"
self.epilogue_op = getattr(backend, self.op)(element_compute)
# data types
self.element_accumulator = element_accumulator
self.element_compute = element_compute
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = UnaryOp(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
*visitors,
self.epilogue_op,
)
def get_argument(self, visitor_args, kwargs):
epilogue_ops = []
for arg in self.args:
try:
epilogue_ops.append(kwargs[arg])
except:
epilogue_ops.append(arg) # direct arguments like constant
self.argument = self.epilogue_node.argument_type(
self.epilogue_op.argument_type(*epilogue_ops),
*visitor_args,
)
class BinOpNode:
cnt = 0
# Concept: this is created by the BinOp Node in python ast
def __init__(
self,
element_accumulator,
element_compute,
elements_per_access,
node,
) -> None:
self.op = operators[type(node.op)]
self.tag = "Binary" + self.op + str(BinOpNode.cnt)
self.id = self.op + str(BinOpNode.cnt)
self.args = None
BinOpNode.cnt += 1
self.type = "tensor"
self.epilogue_op = getattr(backend, "Vector" + self.op)(element_compute)
# data types
self.element_accumulator = element_accumulator
self.element_compute = element_compute
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = BinaryOp(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
*visitors,
self.epilogue_op,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
self.epilogue_op.argument_type(self.args),
*visitor_args,
)
class NameNode:
# Concept: this is created by the Name Node in python ast
def __init__(self, node) -> None:
try:
self.id = node.id
except:
self.id = node.targets[0].id
self.tag = self.id
class ScalarInputNode(NameNode):
# Concept: scalar
def __init__(self, node) -> None:
super().__init__(node)
self.tag = "Scalar:" + self.tag
self.type = "scalar"
class AccumulatorNode(NameNode):
# Concept: VisitorOpAccumulator
def __init__(
self,
element_accumulator,
elements_per_access,
node,
) -> None:
super().__init__(node)
self.tag = "Accum:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = AccumulatorOp(
self.element_accumulator,
self.elements_per_access,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type()
class TensorInputNode(NameNode):
# Concept: VisitorOpTensorInput
def __init__(self, element_accumulator, node) -> None:
super().__init__(node)
self.tag = "TensorInput:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
def get_epilogue_node(self, *args):
self.epilogue_node = TensorInputOp(self.element_accumulator)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][1],
kwargs["problem_size"][0] * kwargs["problem_size"][1],
)
class RowBroadcastNode(NameNode):
# Concept: VisitorOpRowBroadcast
def __init__(
self,
element_accumulator,
element_fragment,
node,
) -> None:
super().__init__(node)
#
self.tag = "RowBroadcast:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_fragment = element_fragment
def get_epilogue_node(self, *args):
self.epilogue_node = RowBroadcastOp(
self.element_accumulator,
self.element_fragment,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][1],
)
class ColumnBroadcastNode(NameNode):
# Concept: VisitorOpColumnBroadcast
def __init__(
self,
element_accumulator,
element_fragment,
node,
) -> None:
super().__init__(node)
self.tag = "ColumnBroadcast:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_fragment = element_fragment
def get_epilogue_node(self, *args):
self.epilogue_node = ColumnBroadcastOp(
self.element_accumulator,
self.element_fragment,
)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][0],
)
class TensorOutputNode(NameNode):
# Concept: VisitorOpTensorOutput
def __init__(self, element_accumulator, node) -> None:
super().__init__(node)
self.tag = "TensorOutput:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
def get_epilogue_node(self, visitors):
self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
kwargs["problem_size"][1],
*visitor_args,
kwargs["problem_size"][0] * kwargs["problem_size"][1],
)
class RowReductionNode:
# Concept: RowReductionOp
def __init__(
self,
element_accumulator,
element_reduction,
element_reduction_accumulator,
id,
factor,
) -> None:
#
self.id = id
self.tag = "RowReduction:" + self.id
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_reduction = element_reduction
self.element_reduction_accumulator = element_reduction_accumulator
self.factor = factor
def get_epilogue_node(self, visitors):
self.epilogue_node = RowReductionOp(
self.element_accumulator,
self.element_reduction,
self.element_reduction_accumulator,
*visitors,
)
def get_batch_stride(self, problem_size):
return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
*visitor_args,
self.get_batch_stride(kwargs["problem_size"]),
)
class ColumnReductionNode:
# Concept: ColumnReductionOp
def __init__(
self,
element_accumulator,
element_reduction,
element_reduction_accumulator,
id,
factor,
) -> None:
#
self.id = id
self.tag = "ColumnReduction:" + self.id
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_reduction = element_reduction
self.element_reduction_accumulator = element_reduction_accumulator
self.factor = factor
def get_epilogue_node(self, visitors):
self.epilogue_node = ColumnReductionOp(
self.element_accumulator,
self.element_reduction,
self.element_reduction_accumulator,
*visitors,
)
def get_batch_stride(self, problem_size):
return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"],
*visitor_args,
self.get_batch_stride(kwargs["problem_size"]),
)
################################################################################
# Epilogue parser function
################################################################################
class EpilogueAST(ast.NodeVisitor):
def __init__(
self,
epilogue,
tile_description,
element_accumulator,
elements_per_access,
element_compute,
element_output,
) -> None:
#
self.tile_description = tile_description
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
self.element_compute = element_compute
self.element_output = element_output
self.epilogue = epilogue
self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
self.ast_tree = ast.parse(self.source)
self.epilogue_tree = Tree()
# print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose
# input arguments
self.input_args = {}
# return nodes
self.returns = []
# reduction source nodes
self.reduction_source = {}
# stack used to keep the parent node id
self.stack = []
# visit the AST
self.visit(self.ast_tree)
# visit the name node
def visit_Name(self, node):
# append the return ids into self.returns
if self.stack[-1] == "return":
self.returns.append(node.id)
else:
# accum is produced from accumulator node
if node.id == "accum":
name_node = AccumulatorNode(
self.element_accumulator,
self.elements_per_access,
node,
)
else:
# for input nodes
if node.id in self.input_args.keys():
type = self.input_args[node.id][0]
if type == "tensor":
name_node = TensorInputNode(
self.element_accumulator,
node,
)
elif type == "row":
name_node = RowBroadcastNode(
self.element_accumulator,
self.element_compute,
node,
)
elif type == "column":
name_node = ColumnBroadcastNode(
self.element_accumulator,
self.element_compute,
node,
)
elif type == "scalar":
name_node = ScalarInputNode(node)
else:
raise ValueError(type)
# for output nodes
else:
name_node = TensorOutputNode(
self.element_accumulator,
node,
)
self.epilogue_tree.create_node(
name_node.tag,
name_node.id,
data=name_node,
parent=self.stack[-1],
)
def visit_Assign(self, node):
pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
if pre_assign_node is None:
# The assign is to a root node
# skip the reduction nodes
if isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Name):
func_type = node.value.func.id
elif isinstance(node.value.func, ast.Attribute):
func_type = node.value.func.value.id
else:
raise TypeError
if func_type == "reduction_op":
self.reduction_source[node.value.args[0].id] = [
node.value.args[1].value,
node.value.args[2].value,
node.targets[0].id,
]
return
name_node = TensorOutputNode(self.element_accumulator, node)
self.epilogue_tree.create_node(
name_node.tag,
name_node.id,
data=name_node,
)
self.stack.append(name_node.id)
else:
if (
node.targets[0].id in self.returns
or node.targets[0].id in self.reduction_source.keys()
):
self.stack.append(node.targets[0].id)
else:
self.stack.append(
pre_assign_node.predecessor(self.epilogue_tree.identifier)
)
self.epilogue_tree.remove_node(node.targets[0].id)
# get child tag
self.visit(node.value)
self.stack.pop()
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
func_type = node.func.id
elif isinstance(node.func, ast.Attribute):
func_type = node.func.value.id
else:
raise TypeError
if func_type == "reduction_op":
self.visit(node.args[0])
else:
arg_list = []
for idx, arg in enumerate(node.args):
if idx == 0:
continue
if isinstance(arg, ast.Constant):
arg_list.append(arg.value)
elif isinstance(arg, ast.Name):
arg_list.append(arg.id)
else:
raise TypeError
unary_node = UnaryNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node,
arg_list,
)
self.epilogue_tree.create_node(
unary_node.tag,
unary_node.id,
parent=self.stack[-1],
data=unary_node,
)
self.stack.append(unary_node.id)
self.visit(node.args[0])
self.stack.pop()
def visit_BinOp(self, node):
binop = BinOpNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node,
)
self.epilogue_tree.create_node(
binop.tag,
binop.id,
data=binop,
parent=self.stack[-1],
)
self.stack.append(binop.id)
self.visit(node.left)
self.visit(node.right)
self.stack.pop()
def visit_Return(self, node):
self.stack.append("return")
self.visit(node.value)
self.stack.pop()
# # A function definition
def visit_FunctionDef(self, node: ast.FunctionDef):
# visit args
for arg in node.args.args:
if arg.arg == "self":
continue
if isinstance(arg.annotation, ast.Constant):
self.input_args[arg.arg] = [
arg.annotation.value,
]
# visit the assign in the reverse order
for idx in range(len(node.body)):
self.visit(node.body[-1 - idx])
#
# Tree optimization pass
#
# pass 1: lower Binary to Unary
def pass_binary_2_unary(self, tree, nid):
node = tree.get_node(nid)
if isinstance(node.data, BinOpNode):
lhs_node = tree.get_node(node.successors(tree.identifier)[0])
left_type = lhs_node.data.type
rhs_node = tree.get_node(node.successors(tree.identifier)[1])
right_type = rhs_node.data.type
if left_type == "scalar" and right_type == "tensor":
node.data = UnaryNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node.data,
[
lhs_node.data.id,
],
)
node.tag = node.data.tag
tree.remove_node(lhs_node.data.id)
self.pass_binary_2_unary(tree, rhs_node.data.id)
elif left_type == "tensor" and right_type == "scalar":
node.data = UnaryNode(
self.element_accumulator,
self.element_compute,
self.elements_per_access,
node.data,
[
rhs_node.id,
],
)
node.tag = node.data.tag
tree.remove_node(rhs_node.data.id)
self.pass_binary_2_unary(tree, lhs_node.data.id)
else:
self.pass_binary_2_unary(tree, lhs_node.data.id)
self.pass_binary_2_unary(tree, rhs_node.data.id)
else:
for child in node.successors(tree.identifier):
self.pass_binary_2_unary(tree, child)
# pass 2: inject reduction nodes
def pass_inject_reduction(self, tree, nid):
node = tree.get_node(nid)
if isinstance(node.data, TensorOutputNode):
if node.data.id in self.reduction_source.keys():
direction = self.reduction_source[node.data.id][0]
target = self.reduction_source[node.data.id][-1]
if direction == "row":
reduction_node = RowReductionNode(
self.element_accumulator,
self.element_output,
self.element_accumulator,
target,
self.tile_description.threadblock_shape[1],
)
elif direction == "column":
reduction_node = ColumnReductionNode(
self.element_accumulator,
self.element_output,
self.element_accumulator,
target,
self.tile_description.threadblock_shape[0],
)
else:
raise ValueError(direction)
child_nid = node.successors(tree.identifier)[0]
# if this output node is injected only for reduction
if node.data.id not in self.returns:
# get reduction config from disc
node.data = reduction_node
node.tag = reduction_node.tag
self.pass_inject_reduction(tree, child_nid)
# if this output node is also a tensor output, inject reduction as its children
else:
# get child node
tree.create_node(
reduction_node.tag,
reduction_node.id,
data=reduction_node,
parent=node.data.id,
)
tree.move_node(
child_nid,
reduction_node.id,
)
child = tree.get_node(child_nid)
for grand_child in child.successors(tree.identifier):
self.pass_inject_reduction(tree, grand_child)
else:
for child in node.successors(tree.identifier):
self.pass_inject_reduction(tree, child)
else:
for child in node.successors(tree.identifier):
self.pass_inject_reduction(tree, child)
def pass_inject_epilogue_op(self, tree, nid):
node = tree.get_node(nid)
visitors = []
for child in node.successors(tree.identifier):
visitors.append(self.pass_inject_epilogue_op(tree, child))
node.data.get_epilogue_node(visitors)
return node.data.epilogue_node
def get_arguments(self, tree, nid, kwargs):
node = tree.get_node(nid)
visitor_args = []
for child in node.successors(tree.identifier):
visitor_args.append(self.get_arguments(tree, child, kwargs))
node.data.get_argument(visitor_args, kwargs)
return node.data.argument
class EpilogueVisitTree:
KernelTemplate = """
${visitor}
using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""
def __init__(
self,
elementwise_functor,
tile_description,
element_accumulator,
elements_per_access,
element_compute,
element_output,
) -> None:
#
# data types
self.tile_description = tile_description
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
self.element_compute = element_compute
self.element_output = element_output
self.elementwise_functor = elementwise_functor
pass
def initialize(self):
function = EpilogueAST(
self,
self.tile_description,
self.element_accumulator,
self.elements_per_access,
self.element_compute,
self.element_output,
)
#
tree = function.epilogue_tree
self.tree = tree
function.pass_binary_2_unary(self.tree, self.tree.root)
function.pass_inject_reduction(self.tree, self.tree.root)
function.pass_inject_epilogue_op(self.tree, self.tree.root)
visitor = self.tree.get_node(self.tree.root).data.epilogue_node
self.visitor = visitor
class _Argument(ctypes.Structure):
_fields_ = [
(
"visitor_arg",
visitor.argument_type,
)
]
def __init__(self, **kwargs) -> None:
# process input args
_kwargs = {}
for input_key in function.input_args.keys():
if input_key == "accum":
continue
if function.input_args[input_key][0] == "scalar":
continue
# tensor input
else:
setattr(
self,
"buffer_tensor_" + input_key,
NumpyFrontend.argument(
kwargs[input_key],
False,
),
)
setattr(
self,
input_key + "_ptr",
int(
getattr(
self,
"buffer_tensor_" + input_key,
).ptr
),
)
_kwargs[input_key + "_ptr"] = getattr(
self,
input_key + "_ptr",
)
# process the return args
for ret in function.returns:
setattr(
self,
"buffer_tensor_" + ret,
NumpyFrontend.argument(kwargs[ret], True),
)
setattr(
self,
ret + "_ptr",
int(
getattr(
self,
"buffer_tensor_" + ret,
).ptr
),
)
_kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr")
setattr(
self,
"host_tensor_" + ret,
kwargs[ret],
)
_kwargs.update(kwargs)
function.get_arguments(tree, tree.root, _kwargs)
self.visitor_arg = tree.get_node(tree.root).data.argument
def sync(self, stream_sync=True):
if stream_sync:
(err,) = cudart.cudaDeviceSynchronize()
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
for ret in function.returns:
(err,) = cuda.cuMemcpyDtoH(
getattr(
self,
"host_tensor_" + ret,
),
cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
getattr(
self,
"host_tensor_" + ret,
).size
* getattr(
self,
"host_tensor_" + ret,
).itemsize,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
pass
self.epilogue_type = _Argument
def emit(self, operation):
values = {
"visitor": self.visitor.emit(operation),
"operation_name": operation.procedural_name(),
"visitor_name": self.visitor.instance_name,
}
return SubstituteTemplate(self.KernelTemplate, values)

View File

@ -30,24 +30,24 @@
#
################################################################################
import ctypes
from typing import Union
import ctypes
from cuda import cuda, cudart
import cutlass_bindings
import numpy as np
from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass.backend.library import (
from cutlass import (
DataTypeNames,
DataTypeSize,
DataTypeTag,
TensorDescription,
LayoutType
)
from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass.backend.library import TensorDescription
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.shape import MatrixCoord
if CheckPackages().check_torch():
import torch
@ -80,10 +80,9 @@ class ReductionArguments:
self.bias = False
self.operation = operation
#: pointer to the workspace
self.ptr_workspace = workspace
#: number of split-k partitions
# number of split-k partitions
self.partitions = partitions
if isinstance(destination, np.ndarray):
@ -112,19 +111,18 @@ class ReductionArguments:
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
# get arguments
self.get_arguments()
@staticmethod
def get_tensor_ref(
extent: "tuple[int]",
device_ptr: cuda.CUdeviceptr,
layout: cutlass_bindings.layout,
layout: LayoutType,
):
if layout == cutlass_bindings.RowMajor:
if layout == LayoutType.RowMajor:
return TensorRef2D_(int(device_ptr), extent[1])
else:
raise ValueError("unknown layout type")
raise ValueError(f"Unknown layout type {layout}")
def get_arguments(self):
ref_workspace = ReductionArguments.get_tensor_ref(
@ -133,13 +131,13 @@ class ReductionArguments:
self.problem_size.column,
],
device_ptr=self.ptr_workspace,
layout=cutlass_bindings.RowMajor,
layout=LayoutType.RowMajor,
)
if self.bias:
ref_source = ReductionArguments.get_tensor_ref(
extent=[0, 0],
device_ptr=self.ptr_source,
layout=cutlass_bindings.RowMajor,
layout=LayoutType.RowMajor,
)
else:
ref_source = ReductionArguments.get_tensor_ref(
@ -148,7 +146,7 @@ class ReductionArguments:
self.problem_size.column,
],
device_ptr=self.ptr_source,
layout=cutlass_bindings.RowMajor,
layout=LayoutType.RowMajor,
)
ref_destination = ReductionArguments.get_tensor_ref(
@ -157,7 +155,7 @@ class ReductionArguments:
self.problem_size.column,
],
device_ptr=self.ptr_destination,
layout=cutlass_bindings.RowMajor,
layout=LayoutType.RowMajor,
)
self.c_arguments = self.operation.argument_type(
@ -176,7 +174,7 @@ class ReductionArguments:
def sync(self):
(err,) = cudart.cudaDeviceSynchronize()
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
raise RuntimeError(f"CUDA Error {str(err)}")
if hasattr(self, "host_D"):
(err,) = cuda.cuMemcpyDtoH(
@ -258,15 +256,15 @@ extern "C" {
def plan(self, arguments: ReductionArguments):
block_shape = [
self.operation.shape.column() // self.elements_per_access,
self.operation.shape.row(),
self.operation.shape.column // self.elements_per_access,
self.operation.shape.row,
1,
]
grid_shape = [
(arguments.problem_size.row + self.operation.shape.row() - 1)
// self.operation.shape.row(),
(arguments.problem_size.column + self.operation.shape.column() - 1)
// self.operation.shape.column(),
(arguments.problem_size.row + self.operation.shape.row - 1)
// self.operation.shape.row,
(arguments.problem_size.column + self.operation.shape.column - 1)
// self.operation.shape.column,
1,
]
return LaunchConfiguration(
@ -282,20 +280,17 @@ extern "C" {
value=self.shared_memory_capacity,
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
raise RuntimeError(f"CUDA Error: {err}")
class ReductionOperation:
"""
CUTLASS Reduction Operation
shape: shape of CTA
outputop: output operator
r
CUTLASS reduction Operation
"""
def __init__(
self,
shape: cutlass_bindings.MatrixCoord,
shape: MatrixCoord,
C: TensorDescription,
element_accumulator,
element_workspace=None,
@ -304,45 +299,33 @@ class ReductionOperation:
count: int = 1,
partitions_per_stage: int = 4,
) -> None:
"""Constructor"""
self.shape = shape
#: epilogue functor (default: LinearCombination)
self.epilogue_functor = epilogue_functor
#: datatype of accumulator
self.element_accumulator = element_accumulator
if element_workspace is None:
#: datatype of workspace
self.element_workspace = element_accumulator
else:
#: datatype of workspace
self.element_workspace = element_workspace
if element_compute is None:
#: datatype of workspace
self.element_compute = element_accumulator
else:
#: datatype of workspace
self.element_compute = element_compute
#: datatype of output
self.element_output = C.element
#: operand C
self.C: TensorDescription = C
#: reduce op processing size
# Reduce op processing size
self.count: int = count
#: number of partitions to reduce per stage
# Number of partitions to reduce per stage
self.partitions_per_stage: int = partitions_per_stage
self.rt_module: ReductionRT = ReductionRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
#
def extended_name(self):
extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
@ -356,15 +339,14 @@ class ReductionOperation:
},
)
#
def configuration_name(self):
"""The full procedural name indicates architecture, extended name, tile size"""
configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
threadblock = "%dx%d" % (
self.shape.row(),
self.shape.column(),
self.shape.row,
self.shape.column,
)
return SubstituteTemplate(
@ -375,7 +357,6 @@ class ReductionOperation:
},
)
#
def procedural_name(self):
"""The full procedural name indicates architeture, extended name, tile size"""
return self.configuration_name()
@ -384,14 +365,11 @@ class ReductionOperation:
"""
Configure and launch the cuda kernel with input arguments
"""
# get launch configuration
launch_config = self.rt_module.plan(arguments)
# get the host and device workspace
host_workspace = arguments.host_workspace
device_workspace = None
# launch the kernel
err = self.rt_module.run(
host_workspace,
device_workspace,
@ -399,7 +377,7 @@ class ReductionOperation:
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
raise RuntimeError(f"CUDA Error {str(err)}")
return err
@ -421,7 +399,7 @@ class EmitReductionInstance:
]
self.template = """
// Reduction kernel instance
using ${operation_name}_base =
using ${operation_name}_base =
typename cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
${epilogue_functor},
@ -436,19 +414,14 @@ struct ${operation_name}${operation_suffix}:
"""
def emit(self, operation: ReductionOperation):
epilogue_vector_length = int(
min(
operation.C.alignment * DataTypeSize[operation.C.element],
128,
)
/ DataTypeSize[operation.C.element]
)
vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
values = {
"operation_name": operation.configuration_name(),
"operation_suffix": self.operation_suffix,
"shape_row": str(operation.shape.row()),
"shape_column": str(operation.shape.column()),
"shape_row": str(operation.shape.row),
"shape_column": str(operation.shape.column),
"epilogue_functor": operation.epilogue_functor.emit(),
"element_output": DataTypeTag[operation.element_output],
"epilogue_vector_length": str(epilogue_vector_length),

View File

@ -1,807 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import re
import subprocess
from time import sleep
from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np
from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.backend.test.profiler import GpuTimer
from cutlass.backend.utils.software import SubstituteTemplate
def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand):
ptr = tensor.__array_interface__["data"][0]
if operand == "a":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
conv_kind, problem_size
)
elif operand == "b":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
conv_kind, problem_size
)
elif operand in ["c", "d"]:
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
)
else:
raise ValueError("unknown operand: " + operand)
layout = tensor_layout.packed(tensor_coord)
if tensor.dtype == np.float64:
return cutlass_bindings.TensorRefF64NHWC(ptr, layout)
elif tensor.dtype == np.float32:
return cutlass_bindings.TensorRefF32NHWC(ptr, layout)
elif tensor.dtype == np.float16:
return cutlass_bindings.TensorRefF16NHWC(ptr, layout)
if tensor.dtype == bfloat16:
return cutlass_bindings.TensorRefBF16NHWC(ptr, layout)
elif tensor.dtype == np.int32:
return cutlass_bindings.TensorRefS32NHWC(ptr, layout)
elif tensor.dtype == np.int8:
if tensor_layout == cutlass_bindings.TensorNC32HW32:
return cutlass_bindings.TensorRefS8NC32HW32(ptr, layout)
elif tensor_layout == cutlass_bindings.TensorC32RSK32:
return cutlass_bindings.TensorRefS8C32RSK32(ptr, layout)
else:
return cutlass_bindings.TensorRefS8NHWC(ptr, layout)
else:
raise ValueError("unsupported data type")
def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand):
tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand)
if operand == "a":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
conv_kind, problem_size
)
elif operand == "b":
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
conv_kind, problem_size
)
elif operand in ["c", "d"]:
tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
)
else:
raise ValueError("unknown operand: " + operand)
if tensor.dtype == np.float64:
return cutlass_bindings.TensorViewF64NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.float32:
return cutlass_bindings.TensorViewF32NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.float16:
return cutlass_bindings.TensorViewF16NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == bfloat16:
return cutlass_bindings.TensorViewBF16NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.int32:
return cutlass_bindings.TensorViewS32NHWC(tensor_ref, tensor_coord)
elif tensor.dtype == np.int8:
if tensor_layout == cutlass_bindings.TensorNC32HW32:
return cutlass_bindings.TensorViewS8NC32HW32(tensor_ref, tensor_coord)
elif tensor_layout == cutlass_bindings.TensorC32RSK32:
return cutlass_bindings.TensorViewS8C32RSK32(tensor_ref, tensor_coord)
else:
return cutlass_bindings.TensorViewS8NHWC(tensor_ref, tensor_coord)
else:
raise ValueError("unsupported data type")
class Conv2dLauncher:
"""
Launcher that runs the operation on given problem size
"""
def __init__(
self,
operation: "Conv2dOperation",
seed: int = 2080,
interleaved=False,
verification=True,
profiling=False,
warmup_iterations=500,
iterations=500,
compilation_mode="nvcc",
**kwargs,
) -> None:
self.enable_cached_results = True
self.interleaved = interleaved
# create the reduction kernel
self.reduction_operation = ReductionOperation(
shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
#: verify the output result
self.verification = verification
#: profile the kernel's runtime
self.profiling = profiling
self.timer = GpuTimer()
self.warmup_iterations = warmup_iterations
self.iterations = iterations
if "sleep" in kwargs.keys():
self.sleep_time = kwargs["sleep"]
else:
self.sleep_time = 0
#
# Compile the operator
#
if compilation_mode == "nvcc":
compiler.nvcc()
elif compilation_mode == "nvrtc":
compiler.nvrtc()
else:
raise Exception(f"Unexpected compilation mode {compilation_mode}")
compiler.add_module([operation, self.reduction_operation])
self.operation = operation
self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element)
self.layout_A = operation.A.layout
self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element)
self.layout_B = operation.B.layout
self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element)
self.layout_C = operation.C.layout
self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element)
self.layout_D = operation.C.layout
accumulator_size = DataTypeSize[
operation.tile_description.math_instruction.element_accumulator
]
element_size = DataTypeSize[operation.A.element]
if element_size <= 8:
self.randomization_max = 1
elif element_size == 16:
if accumulator_size <= 16:
self.randomization_max = 2
else:
self.randomization_max = 4
else:
self.randomization_max = 7
# Seed
self.seed = seed
self.conv_kind = operation.conv_kind
#
# Get the host reference function
#
self.element_compute = operation.epilogue_functor.element_epilogue
self.host_conv2d = cutlass_bindings.test.conv.host.conv2d
self.timer = GpuTimer()
@staticmethod
def numpy_type(type):
if type == cutlass_bindings.float64:
return np.float64
elif type == cutlass_bindings.float32:
return np.float32
elif type == cutlass_bindings.float16:
return np.float16
elif type == cutlass_bindings.bfloat16:
return bfloat16
elif type == cutlass_bindings.int32:
return np.int32
elif type == cutlass_bindings.int8:
return np.int8
else:
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
def print_problem_size(self, p, split_k_mode=1):
print(
"nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d"
% (
p.N,
p.H,
p.W,
p.C,
p.K,
p.R,
p.S,
p.C,
p.pad_h,
p.pad_w,
p.stride_h,
p.stride_w,
p.dilation_h,
p.dilation_w,
p.split_k_slices,
split_k_mode,
)
)
def uniform_init(self, size, dtype):
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=-self.randomization_max - 0.5, high=self.randomization_max - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=-self.randomization_max - 1, high=self.randomization_max + 1, size=size
).astype(dtype)
def eq_gemm_size(self, problem_size):
n = problem_size.N
p = problem_size.P
q = problem_size.Q
k = problem_size.K
r = problem_size.R
s = problem_size.S
c = problem_size.C
h = problem_size.H
w = problem_size.W
if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
return cutlass_bindings.gemm.GemmCoord(n * p * q, k, r * s * c)
elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
return cutlass_bindings.gemm.GemmCoord(n * h * w, c, k * r * s)
else:
return cutlass_bindings.gemm.GemmCoord(k, r * s * c, n * p * q)
def bytes(self, problem_size, alpha, beta):
mnk = self.eq_gemm_size(problem_size)
bytes_ = (
(DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k()
+ (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k()
+ (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
)
if beta != 0:
bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
return bytes_
def flops(self, problem_size):
mnk = self.eq_gemm_size(problem_size)
flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2
flops_epilogue_ = mnk.m() * mnk.n() * 2
# Adjust mainloop flop for dgrad stride
if self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
flops_mainloop_ = flops_mainloop_ // (
problem_size.stride_h * problem_size.stride_w
)
flops_total_ = flops_mainloop_ + flops_epilogue_
return flops_total_
def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
if self.element_compute == cutlass_bindings.float16:
alpha = cutlass_bindings.float16(alpha)
beta = cutlass_bindings.float16(beta)
elif self.element_compute == cutlass_bindings.int32:
alpha = int(alpha)
beta = int(beta)
else:
alpha = alpha
beta = beta
# if cached result is loaded
cached_result_loaded = False
if self.enable_cached_results:
# get problem key
cached_test_key = cutlass_bindings.test.conv.host.CreateCachedConv2dTestKey(
self.conv_kind,
problem_size,
alpha,
beta,
getTensorView(
tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
),
getTensorView(
tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
),
getTensorView(
tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
),
)
cached_test_result = cutlass_bindings.test.conv.host.CachedTestResult()
conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (
self.operation.arch,
self.seed,
)
cached_results = cutlass_bindings.test.conv.host.CachedTestResultListing(
conv2d_result_cache_name
)
# CachedTestResultListing cached_results(conv2d_result_cache_name);
cached = cached_results.find(cached_test_key)
cached_result_loaded = cached[0]
if cached_result_loaded:
cached_test_result = cached[1]
if not cached_result_loaded:
# compute the conv2d on host
tensor_D_ref = np.ones_like(tensor_C)
tensor_ref_A = getTensorRef(
tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
)
tensor_ref_B = getTensorRef(
tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
)
tensor_ref_C = getTensorRef(
tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
)
tensor_ref_D_ref = getTensorRef(
tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
)
self.host_conv2d(
self.conv_kind,
problem_size,
tensor_ref_A,
tensor_ref_B,
tensor_ref_C,
tensor_ref_D_ref,
alpha,
beta,
)
tensor_view_D_ref = getTensorView(
tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
)
if self.enable_cached_results:
cached_test_result.D = cutlass_bindings.test.conv.host.TensorHash(
tensor_view_D_ref
)
cached_results = (
cutlass_bindings.test.conv.host.CachedTestResultListing(
conv2d_result_cache_name
)
)
cached_results.append(cached_test_key, cached_test_result)
cached_results.write(conv2d_result_cache_name)
else:
return tensor_D_ref
return cached_test_result.D
def equal(self, tensor_D, tensor_D_ref, problem_size):
if self.enable_cached_results:
tensor_view_D = getTensorView(
tensor_D, self.layout_D, self.conv_kind, problem_size, "d"
)
tensor_D_hash = cutlass_bindings.test.conv.host.TensorHash(tensor_view_D)
return tensor_D_hash == tensor_D_ref
else:
tensor_view_D = getTensorView(
tensor_D, self.layout_D, self.conv_kind, problem_size, "d"
)
tensor_view_D_ref = getTensorView(
tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
)
return cutlass_bindings.test.conv.host.equals(
tensor_view_D, tensor_view_D_ref
)
def run_cutlass_profiler(
self,
problem_size,
split_k_mode=cutlass_bindings.conv.SplitKMode.Serial,
alpha=1.0,
beta=0.0,
):
if split_k_mode == cutlass_bindings.conv.SplitKMode.Serial:
split_k_mode_ = "serial"
else:
split_k_mode_ = "parallel"
cutlass_path = os.getenv("CUTLASS_PATH")
assert (
cutlass_path is not None
), "Environment variable 'CUTLASS_PATH' is not defined."
values = {
"profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
"kernel_name": self.operation.procedural_name(),
"verification_providers": "device",
"provider": "cutlass",
"n": str(problem_size.N),
"h": str(problem_size.H),
"w": str(problem_size.W),
"c": str(problem_size.C),
"k": str(problem_size.K),
"r": str(problem_size.R),
"s": str(problem_size.S),
"p": str(problem_size.P),
"q": str(problem_size.Q),
"pad_h": str(problem_size.pad_h),
"pad_w": str(problem_size.pad_w),
"stride_h": str(problem_size.stride_h),
"stride_w": str(problem_size.stride_w),
"dilation_h": str(problem_size.dilation_h),
"dilation_w": str(problem_size.dilation_w),
"split_k_slices": str(problem_size.split_k_slices),
"split_k_mode": split_k_mode_,
"alpha": str(alpha),
"beta": str(beta),
"warmup": str(self.warmup_iterations),
"profile": str(self.iterations),
}
cmd_template = (
"${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}"
" --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}"
" --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h={stride_h} --stride_w=${stride_w}"
" --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}"
" --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}"
)
cmd = SubstituteTemplate(cmd_template, values)
result = subprocess.getoutput(cmd)
m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
runtime = float(m.group("runtime"))
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
bytes = int(m.group("bytes"))
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
flops = int(m.group("flops"))
# check if the problem size matches
assert bytes == self.bytes(problem_size, alpha, beta)
assert flops == self.flops(problem_size)
return runtime
def run(
self,
problem_size,
split_k_mode=cutlass_bindings.conv.SplitKMode.Serial,
alpha=1.0,
beta=0.0,
):
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released in previous run"
% get_allocated_size()
)
#
# Initialize input and output tensors
#
tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(
self.conv_kind, problem_size
)
tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(
self.conv_kind, problem_size
)
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
self.conv_kind, problem_size
)
np.random.seed(self.seed)
tensor_A = self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A)
tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B)
tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C)
tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D)
#
# Launch kernel
#
arguments = Conv2dArguments(
operation=self.operation,
problem_size=problem_size,
A=tensor_A,
B=tensor_B,
C=tensor_C,
D=tensor_D,
output_op=self.operation.epilogue_type(alpha, beta),
split_k_slices=problem_size.split_k_slices,
split_k_mode=split_k_mode,
)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(
self.operation.conv_kind, arguments.problem_size
)
reduction_arguments = ReductionArguments(
self.reduction_operation,
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
partitions=problem_size.split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op=self.reduction_operation.epilogue_type(alpha, beta),
)
self.operation.run(arguments)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
self.reduction_operation.run(reduction_arguments)
passed = True
if self.verification:
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
reduction_arguments.sync()
else:
arguments.sync()
tensor_D_ref = self.host_reference(
problem_size, tensor_A, tensor_B, tensor_C, alpha, beta
)
passed = self.equal(tensor_D, tensor_D_ref, problem_size)
try:
assert passed
except AssertionError:
self.print_problem_size(problem_size, split_k_mode)
if self.profiling:
sleep(self.sleep_time)
for _ in range(self.warmup_iterations):
self.operation.run(arguments)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
self.reduction_operation.run(reduction_arguments)
self.timer.start()
for _ in range(self.warmup_iterations):
self.operation.run(arguments)
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
self.reduction_operation.run(reduction_arguments)
self.timer.stop_and_wait()
runtime = self.timer.duration(self.iterations)
# free memory
del arguments
if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
del reduction_arguments
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released after current run"
% get_allocated_size()
)
if self.profiling:
return runtime
return passed
########################################################################################################
# TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
# TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
# Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
# (conv_blacklist_sizes)
############################################################################################################
def test_all_conv2d_from_compilation_mode(
operation: Conv2dOperation,
conv_test_sizes,
interleaved,
compilation_mode):
passed = True
testbed = Conv2dLauncher(operation, interleaved=interleaved, compilation_mode=compilation_mode)
#
# Get conv problem sizes to run conv operator
#
conv_problems = cutlass_bindings.test.conv.TestbedConv2dProblemSizes(64)
# Vector of conv2d problem sizes to avoid duplicate runs
conv_tested_sizes = []
# Flatten 2D problem_vectors into a 1D problem sizes
problem_sizes = conv_problems.conv2d_default_sizes
problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes
# Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0)
for conv_problem in problem_sizes:
if conv_problem in conv_tested_sizes:
continue
# skip channel dimension % 32 != 0 for interleaved case
if interleaved:
if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0:
continue
#
# Procedurally disable certain cases
#
# CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
if (
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
and operation.stride_support == StrideSupport.Unity
):
if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)):
continue
if not interleaved:
# Fixed channels algorithm requires channel count to match access size
if (
operation.iterator_algorithm
== cutlass_bindings.conv.IteratorAlgorithm.fixed_channels
):
if conv_problem.C != operation.A.alignment:
continue
# Few channels algorithm requires channel count to match access size
if (
operation.iterator_algorithm
== cutlass_bindings.conv.IteratorAlgorithm.few_channels
):
if conv_problem.C % operation.A.alignment:
continue
# CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
# Although strided dgrad works for all stride combinations, we are only going
# to run strided dgrad for non-unity strides
if (
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
and operation.stride_support == StrideSupport.Strided
):
if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1):
continue
#
# Test
#
# push back tested problem size to avoid re-running duplicates
conv_tested_sizes.append(conv_problem)
passed = testbed.run(conv_problem)
if not passed:
return False
if interleaved:
return True
#
# filter the cases for split K
#
# Small-channels convolution can't run here.
if operation.iterator_algorithm in [
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels,
cutlass_bindings.conv.IteratorAlgorithm.few_channels,
]:
return True
# CUTLASS DGRAD's *stride* specialization does not support split-k mode
if (
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
and operation.stride_support == StrideSupport.Strided
):
conv_problem = cutlass_bindings.conv.Conv2dProblemSize(
cutlass_bindings.Tensor4DCoord(1, 56, 56, 8),
cutlass_bindings.Tensor4DCoord(8, 1, 1, 8),
cutlass_bindings.Tensor4DCoord(0, 0, 0, 0),
cutlass_bindings.MatrixCoord(2, 2),
cutlass_bindings.MatrixCoord(1, 1),
cutlass_bindings.conv.Mode.cross_correlation,
1,
1,
)
passed = testbed.run(conv_problem)
return passed
# Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
# a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
# which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
# alpha and beta for local testing, but only runs one value for alpha and beta.
conv2d_split_k_test_size = cutlass_bindings.conv.Conv2dProblemSize(
cutlass_bindings.Tensor4DCoord(1, 17, 11, 288),
cutlass_bindings.Tensor4DCoord(160, 3, 3, 288),
cutlass_bindings.Tensor4DCoord(1, 1, 1, 1),
cutlass_bindings.MatrixCoord(1, 1),
cutlass_bindings.MatrixCoord(1, 1),
cutlass_bindings.conv.Mode.cross_correlation,
1,
1,
)
split_k_modes = [
cutlass_bindings.conv.SplitKMode.Parallel,
cutlass_bindings.conv.SplitKMode.Serial,
]
split_k_slices = [1, 2, 3, 4, 201]
problem_alpha = [
2.0,
]
problem_beta = [
2.0,
]
for split_k_mode in split_k_modes:
for split_k_slice in split_k_slices:
for alpha in problem_alpha:
for beta in problem_beta:
passed = testbed.run(
conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
split_k_mode,
alpha,
beta,
)
return passed
def test_all_conv2d(
operation: Conv2dOperation,
conv_test_sizes=[],
interleaved=False,
compilation_modes=["nvcc", "nvrtc"]):
for compilation_mode in compilation_modes:
passed = test_all_conv2d_from_compilation_mode(operation, conv_test_sizes, interleaved, compilation_mode)
if not passed:
return False
return True

View File

@ -1,276 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmGroupedArguments, GemmOperationGrouped
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.test.gemm_testbed import getTensorRef, getTensorView, transpose
class TestbedGrouped:
def __init__(self, operation: GemmOperationGrouped, seed: int = 2080) -> None:
compiler.add_module([operation])
self.seed = seed
self.operation = operation
element_size = DataTypeSize[operation.A.element]
self.dtype_A = self.numpy_type(operation.A.element)
self.dtype_B = self.numpy_type(operation.B.element)
self.dtype_C = self.numpy_type(operation.C.element)
self.dtype_D = self.numpy_type(operation.C.element)
if element_size == 1:
self.scope_max = 1
self.scope_min = 0
elif element_size <= 8:
self.scope_max = 1
self.scope_min = -1
elif element_size == 16:
self.scope_max = 4
self.scope_min = -4
else:
self.scope_max = 8
self.scope_min = -8
#: compute type
self.compute_type = operation.epilogue_functor.element_epilogue
self.accumulator_type = (
operation.tile_description.math_instruction.element_accumulator
)
@staticmethod
def numpy_type(type):
if type == cutlass_bindings.float64:
return np.float64
elif type == cutlass_bindings.float32:
return np.float32
elif type == cutlass_bindings.float16:
return np.float16
elif type == cutlass_bindings.bfloat16:
return bfloat16
elif type == cutlass_bindings.int32:
return np.int32
elif type == cutlass_bindings.int8:
return np.int8
else:
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
def uniform_init(self, size, dtype):
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=self.scope_min - 1, high=self.scope_max + 1, size=size
).astype(dtype)
def print_problem_size(self, p):
problem_size = "problem: %d, %d, %d\n" % (p.m(), p.n(), p.k())
print(problem_size)
def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool:
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released in previous run"
% get_allocated_size()
)
# initialize
passed = False
np.random.seed(self.seed)
# generate the problem sizes
problem_sizes = []
tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
tensor_D_refs = []
for i in range(problem_count):
if self.dtype_A == np.int8:
if i == 0:
problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 32)
else:
problem_size = cutlass_bindings.gemm.GemmCoord(
16 * np.random.randint(0, 64) + 48,
16 * np.random.randint(0, 64) + 48,
16 * np.random.randint(0, 64) + 48,
)
else:
if i == 0:
problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 8)
else:
problem_size = cutlass_bindings.gemm.GemmCoord(
8 * np.random.randint(0, 64) + 24,
8 * np.random.randint(0, 64) + 24,
8 * np.random.randint(0, 64) + 24,
)
tensor_As.append(
self.uniform_init(
size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A
)
)
tensor_Bs.append(
self.uniform_init(
size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B
)
)
tensor_Cs.append(
self.uniform_init(
size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C
)
)
tensor_Ds.append(
np.zeros(
shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D
)
)
tensor_D_refs.append(
np.ones(
shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D
)
)
problem_sizes.append(problem_size)
arguments = GemmGroupedArguments(
operation=self.operation,
problem_sizes=problem_sizes,
A=tensor_As,
B=tensor_Bs,
C=tensor_Cs,
D=tensor_Ds,
output_op=self.operation.epilogue_type(alpha, beta),
)
self.operation.run(arguments)
arguments.sync()
#
# Reference check
#
alpha = self.compute_type(alpha).value()
beta = self.compute_type(beta).value()
init_acc = self.accumulator_type(0).value()
for idx, problem_size in enumerate(problem_sizes):
if self.operation.switched:
tensor_ref_A = getTensorRef(
tensor_As[idx],
problem_size,
"a",
transpose(self.operation.B.layout),
)
tensor_ref_B = getTensorRef(
tensor_Bs[idx],
problem_size,
"b",
transpose(self.operation.A.layout),
)
tensor_ref_C = getTensorRef(
tensor_Cs[idx],
problem_size,
"c",
transpose(self.operation.C.layout),
)
tensor_ref_D_ref = getTensorRef(
tensor_D_refs[idx],
problem_size,
"d",
transpose(self.operation.C.layout),
)
else:
tensor_ref_A = getTensorRef(
tensor_As[idx], problem_size, "a", self.operation.A.layout
)
tensor_ref_B = getTensorRef(
tensor_Bs[idx], problem_size, "b", self.operation.B.layout
)
tensor_ref_C = getTensorRef(
tensor_Cs[idx], problem_size, "c", self.operation.C.layout
)
tensor_ref_D_ref = getTensorRef(
tensor_D_refs[idx], problem_size, "d", self.operation.C.layout
)
tensor_view_D_ref = getTensorView(
tensor_D_refs[idx], problem_size, "d", self.operation.C.layout
)
cutlass_bindings.test.gemm.host.gemm(
problem_size,
alpha,
tensor_ref_A,
tensor_ref_B,
beta,
tensor_ref_C,
tensor_ref_D_ref,
init_acc,
)
tensor_view_D = getTensorView(
tensor_Ds[idx], problem_size, "d", self.operation.C.layout
)
passed = cutlass_bindings.test.gemm.host.equals(
tensor_view_D, tensor_view_D_ref
)
try:
assert passed
except AssertionError:
self.print_problem_size(problem_size)
del arguments
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released after current run"
% get_allocated_size()
)
return passed
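For context on the `uniform_init` helper above: applying `np.ceil` to draws from `uniform(scope_min - 0.5, scope_max - 0.5)` yields whole-number values in `[scope_min, scope_max]`, which are exactly representable even in half precision, so the reference check can compare results exactly. A standalone NumPy sketch of that property (illustrative only, not part of the testbed):

```python
import numpy as np

def uniform_int_valued(size, scope_min, scope_max, dtype=np.float32):
    # Ceil of uniform(scope_min - 0.5, scope_max - 0.5) lands on whole numbers
    # in [scope_min, scope_max]; such small integers are exactly representable
    # even in fp16/bf16, keeping reference comparisons bitwise-stable.
    return np.ceil(
        np.random.uniform(low=scope_min - 0.5, high=scope_max - 0.5, size=size)
    ).astype(dtype)

np.random.seed(0)
vals = uniform_int_valued((1000,), -4, 4)
assert vals.min() >= -4 and vals.max() <= 4
assert np.all(vals == np.round(vals))  # every entry is integer-valued
```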


@ -1,765 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
#################################################################################################
import os
import re
import subprocess
from time import sleep
from bfloat16 import bfloat16
from cuda import cuda, cudart
import cutlass_bindings
import numpy as np
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import (
DataTypeSize,
DataTypeSizeBytes,
MathOperation,
ShortDataTypeNames,
)
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.backend.test.profiler import GpuTimer
from cutlass.backend.utils.datatypes import to_cutlass
from cutlass.backend.utils.software import SubstituteTemplate
def transpose(layout):
if layout == cutlass_bindings.RowMajor:
return cutlass_bindings.ColumnMajor
elif layout == cutlass_bindings.ColumnMajor:
return cutlass_bindings.RowMajor
elif layout == cutlass_bindings.ColumnMajorInterleaved32:
return cutlass_bindings.RowMajorInterleaved32
elif layout == cutlass_bindings.RowMajorInterleaved32:
return cutlass_bindings.ColumnMajorInterleaved32
def getTensorRef(
tensor: np.ndarray,
problem_size: cutlass_bindings.gemm.GemmCoord,
operand: str,
layout: cutlass_bindings.layout,
batch_offset: int = 0,
):
ptr = tensor.__array_interface__["data"][0]
if operand == "a":
tensor_coord = problem_size.mk()
batch_stride = problem_size.m() * problem_size.k()
elif operand == "b":
tensor_coord = problem_size.kn()
batch_stride = problem_size.k() * problem_size.n()
elif operand in ["c", "d"]:
tensor_coord = problem_size.mn()
batch_stride = problem_size.m() * problem_size.n()
else:
raise ValueError("Unknown operand: " + operand)
elt_size = DataTypeSizeBytes[to_cutlass(tensor.dtype)]
ptr += batch_offset * batch_stride * elt_size
if layout == cutlass_bindings.RowMajor:
layout = cutlass_bindings.RowMajor.packed(tensor_coord)
layout_tag = "RowMajor"
elif layout == cutlass_bindings.ColumnMajor:
layout = cutlass_bindings.ColumnMajor.packed(tensor_coord)
layout_tag = "ColumnMajor"
elif layout == cutlass_bindings.ColumnMajorInterleaved32:
layout = cutlass_bindings.ColumnMajorInterleaved32.packed(tensor_coord)
layout_tag = "ColumnMajorInterleaved32"
elif layout == cutlass_bindings.RowMajorInterleaved32:
layout = cutlass_bindings.RowMajorInterleaved32.packed(tensor_coord)
layout_tag = "RowMajorInterleaved32"
else:
raise ValueError("unsupported layout")
if tensor.dtype == np.float32:
ref_name = "TensorRefF32" + layout_tag
elif tensor.dtype == np.float64:
ref_name = "TensorRefF64" + layout_tag
elif tensor.dtype == np.float16:
ref_name = "TensorRefF16" + layout_tag
elif tensor.dtype == bfloat16:
ref_name = "TensorRefBF16" + layout_tag
elif tensor.dtype == np.int8:
ref_name = "TensorRefS8" + layout_tag
elif tensor.dtype == np.int32:
ref_name = "TensorRefS32" + layout_tag
else:
raise ValueError("unsupported datatype %s" % ShortDataTypeNames[tensor.dtype])
return getattr(cutlass_bindings, ref_name)(ptr, layout)
def getTensorView(
tensor: np.ndarray,
problem_size: cutlass_bindings.gemm.GemmCoord,
operand: str,
layout: cutlass_bindings.layout,
batch_offset: int = 0,
):
tensor_ref = getTensorRef(tensor, problem_size, operand, layout, batch_offset)
if operand == "a":
tensor_coord = problem_size.mk()
elif operand == "b":
tensor_coord = problem_size.kn()
elif operand in ["c", "d"]:
tensor_coord = problem_size.mn()
else:
raise ValueError("Unknown operand: " + operand)
if layout == cutlass_bindings.RowMajor:
layout_tag = "RowMajor"
elif layout == cutlass_bindings.ColumnMajor:
layout_tag = "ColumnMajor"
elif layout == cutlass_bindings.ColumnMajorInterleaved32:
layout_tag = "ColumnMajorInterleaved32"
elif layout == cutlass_bindings.RowMajorInterleaved32:
layout_tag = "RowMajorInterleaved32"
else:
raise ValueError("unsupported layout")
if tensor.dtype == np.float32:
ref_name = "TensorViewF32" + layout_tag
elif tensor.dtype == np.float64:
ref_name = "TensorViewF64" + layout_tag
elif tensor.dtype == np.float16:
ref_name = "TensorViewF16" + layout_tag
elif tensor.dtype == bfloat16:
ref_name = "TensorViewBF16" + layout_tag
elif tensor.dtype == np.int32:
ref_name = "TensorViewS32" + layout_tag
elif tensor.dtype == np.int8:
ref_name = "TensorViewS8" + layout_tag
else:
raise ValueError("unsupported datatype %s" % str(tensor.dtype))
return getattr(cutlass_bindings, ref_name)(tensor_ref, tensor_coord)
class GemmUniversalLauncher:
def __init__(
self,
operation: "GemmOperationUniversal",
seed: int = 2080,
interleaved=False,
verification=True,
profiling=False,
warmup_iterations=500,
iterations=500,
compiler_mode: str = "nvcc",
**kwargs,
) -> None:
# create the reduction kernel
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.epilogue_functor.element_epilogue,
epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment,
)
self.math_operation = operation.tile_description.math_instruction.math_operation
#: verify the output result
self.verification = verification
#: profile the kernel's runtime
self.profiling = profiling
self.timer = GpuTimer()
self.warmup_iterations = warmup_iterations
self.iterations = iterations
if "sleep" in kwargs.keys():
self.sleep_time = kwargs["sleep"]
else:
self.sleep_time = 0
#
# Compile the operator
#
if compiler_mode == "nvcc":
compiler.nvcc()
elif compiler_mode == "nvrtc":
compiler.nvrtc()
else:
raise Exception(f"Unexpected compiler string {compiler_mode}")
op_list = [operation]
if operation.arch < 90:
# Split K via Python is currently only supported for pre-SM90 kernels
op_list.append(self.reduction_operation)
compiler.add_module(op_list, bypass_cache=True)
self.operation = operation
self.dtype_A = GemmUniversalLauncher.numpy_type(operation.A.element)
self.dtype_B = GemmUniversalLauncher.numpy_type(operation.B.element)
self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element)
self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element)
accumulator_size = DataTypeSize[
operation.tile_description.math_instruction.element_accumulator
]
element_size = DataTypeSize[operation.A.element]
if element_size == 1:
self.scope_max = 1
self.scope_min = 0
elif element_size <= 8:
self.scope_max = 1
self.scope_min = -1
elif element_size == 16:
self.scope_max = 4
self.scope_min = -4
else:
self.scope_max = 8
self.scope_min = -8
#: seed
self.seed: int = seed
#: whether the layout is interleaved
self.interleaved = interleaved
#: compute type
self.compute_type = operation.epilogue_functor.element_epilogue
self.accumulator_type = (
operation.tile_description.math_instruction.element_accumulator
)
def print_problem_size(self, p, mode, batch_count):
if mode == cutlass_bindings.gemm.Mode.Gemm:
mode = "Gemm"
elif mode == cutlass_bindings.gemm.Mode.Batched:
mode = "GemmBatched"
elif mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
mode = "GemmSplitKParallel"
problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % (
p.m(),
p.n(),
p.k(),
batch_count,
mode,
)
print(problem_size)
@staticmethod
def numpy_type(type):
if type == cutlass_bindings.float64:
return np.float64
elif type == cutlass_bindings.float32:
return np.float32
elif type == cutlass_bindings.float16:
return np.float16
elif type == cutlass_bindings.bfloat16:
return bfloat16
elif type == cutlass_bindings.int32:
return np.int32
elif type == cutlass_bindings.int8:
return np.int8
else:
raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])
def uniform_init(self, size, dtype):
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=self.scope_min - 1, high=self.scope_max + 1, size=size
).astype(dtype)
def reorder_tensor_B(self, tensor_B, problem_size):
reordered_tensor_B = np.empty_like(tensor_B)
tensor_ref_B = getTensorRef(
tensor_B, problem_size, "b", self.operation.B.layout
)
reordered_tensor_ref_B = getTensorRef(
reordered_tensor_B, problem_size, "b", self.operation.B.layout
)
cutlass_bindings.gemm.host.reorder_column(
tensor_ref_B, reordered_tensor_ref_B, problem_size
)
return reordered_tensor_B
def host_reference(self, problem_size, batch_count, tensor_A, tensor_B, tensor_C, alpha, beta):
tensor_D_ref = np.ones_like(tensor_C)
alpha = self.numpy_type(self.compute_type)(alpha)
beta = self.numpy_type(self.compute_type)(beta)
init_acc = 0
alpha = self.compute_type(alpha).value()
beta = self.compute_type(beta).value()
init_acc = self.accumulator_type(init_acc).value()
for i in range(batch_count):
if self.operation.switched:
tensor_ref_A = getTensorRef(
tensor_A,
problem_size,
"a",
transpose(self.operation.B.layout),
batch_offset=i,
)
tensor_ref_B = getTensorRef(
tensor_B,
problem_size,
"b",
transpose(self.operation.A.layout),
batch_offset=i,
)
tensor_ref_C = getTensorRef(
tensor_C,
problem_size,
"c",
transpose(self.operation.C.layout),
batch_offset=i,
)
tensor_ref_D_ref = getTensorRef(
tensor_D_ref,
problem_size,
"d",
transpose(self.operation.C.layout),
batch_offset=i,
)
else:
tensor_ref_A = getTensorRef(
tensor_A, problem_size, "a", self.operation.A.layout, batch_offset=i
)
tensor_ref_B = getTensorRef(
tensor_B, problem_size, "b", self.operation.B.layout, batch_offset=i
)
tensor_ref_C = getTensorRef(
tensor_C, problem_size, "c", self.operation.C.layout, batch_offset=i
)
tensor_ref_D_ref = getTensorRef(
tensor_D_ref,
problem_size,
"d",
self.operation.C.layout,
batch_offset=i,
)
if self.math_operation in [MathOperation.multiply_add_saturate]:
cutlass_bindings.test.gemm.host.gemm_saturate(
problem_size,
alpha,
tensor_ref_A,
tensor_ref_B,
beta,
tensor_ref_C,
tensor_ref_D_ref,
init_acc,
)
else:
cutlass_bindings.test.gemm.host.gemm(
problem_size,
alpha,
tensor_ref_A,
tensor_ref_B,
beta,
tensor_ref_C,
tensor_ref_D_ref,
init_acc,
)
return tensor_D_ref
def equal(self, tensor_D, tensor_D_ref, problem_size, batch_count):
for i in range(batch_count):
tensor_view_D = getTensorView(
tensor_D, problem_size, "d", self.operation.C.layout, batch_offset=i
)
tensor_view_D_ref = getTensorView(
tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i
)
if not cutlass_bindings.test.gemm.host.equals(
tensor_view_D, tensor_view_D_ref
):
return False
return True
def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0):
m = problem_size.m()
n = problem_size.n()
k = problem_size.k()
bytes = (
(DataTypeSize[self.operation.A.element] * m // 8) * k
+ (DataTypeSize[self.operation.B.element] * n // 8) * k
+ (DataTypeSize[self.operation.C.element] * m // 8) * n
)
if beta != 0:
bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
bytes *= batch_count
return bytes
def flops(self, problem_size, batch_count=1):
m = problem_size.m()
n = problem_size.n()
k = problem_size.k()
flops_ = (m * n * k) * 2 * batch_count
return flops_
def run_cutlass_profiler(
self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0
):
cutlass_path = os.getenv("CUTLASS_PATH")
assert (
cutlass_path is not None
), "Environment variable 'CUTLASS_PATH' is not defined."
values = {
"profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
"kernel_name": self.operation.procedural_name(),
"verification_providers": "device",
"provider": "cutlass",
"m": str(problem_size.m()),
"n": str(problem_size.n()),
"k": str(problem_size.k()),
"split_k_slices": str(batch_count),
"alpha": str(alpha),
"beta": str(beta),
"warmup": str(self.warmup_iterations),
"profile": str(self.iterations),
}
cmd_template = (
"${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}"
" --providers=${provider} --m=${m} --n=${n} --k=${k}"
)
cmd = SubstituteTemplate(cmd_template, values)
result = subprocess.getoutput(cmd)
m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
runtime = float(m.group("runtime"))
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
bytes = int(m.group("bytes"))
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
flops = int(m.group("flops"))
# check if the problem size matches
assert bytes == self.bytes(problem_size, alpha=alpha, beta=beta)
assert flops == self.flops(problem_size)
return runtime
def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released in previous run"
% get_allocated_size()
)
np.random.seed(self.seed)
# Assign an actual batch count in cases where we are not running in batched mode.
# This is to differentiate between the number of split K slices and the batch count,
# which are overloaded within the single `batch_count` variable.
true_batch_count = (
batch_count if mode == cutlass_bindings.gemm.Mode.Batched else 1
)
tensor_A = self.uniform_init(
size=(problem_size.m() * problem_size.k() * true_batch_count,),
dtype=self.dtype_A,
)
tensor_B = self.uniform_init(
size=(problem_size.n() * problem_size.k() * true_batch_count,),
dtype=self.dtype_B,
)
tensor_C = self.uniform_init(
size=(problem_size.m() * problem_size.n() * true_batch_count,),
dtype=self.dtype_C,
)
tensor_D = np.zeros(
shape=(problem_size.m() * problem_size.n() * true_batch_count,),
dtype=self.dtype_D,
)
#
# Launch kernel
#
arguments = GemmArguments(
operation=self.operation,
problem_size=problem_size,
A=tensor_A,
B=tensor_B,
C=tensor_C,
D=tensor_D,
output_op=self.operation.epilogue_type(alpha, beta),
gemm_mode=mode,
split_k_slices=split_k_slices,
batch=batch_count,
)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
reduction_arguments = ReductionArguments(
self.reduction_operation,
problem_size=[problem_size.m(), problem_size.n()],
partitions=split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op=self.reduction_operation.epilogue_type(alpha, beta),
)
self.operation.run(arguments)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
self.reduction_operation.run(reduction_arguments)
passed = True
if self.verification:
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
reduction_arguments.sync()
else:
arguments.sync()
tensor_D_ref = self.host_reference(
problem_size,
true_batch_count,
tensor_A,
tensor_B,
tensor_C,
alpha,
beta,
)
passed = self.equal(tensor_D, tensor_D_ref, problem_size, true_batch_count)
try:
assert passed
except AssertionError:
self.print_problem_size(problem_size, mode, batch_count)
if self.profiling:
sleep(self.sleep_time)
for _ in range(self.warmup_iterations):
self.operation.run(arguments)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
self.reduction_operation.run(reduction_arguments)
self.timer.start()
for _ in range(self.iterations):
self.operation.run(arguments)
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
self.reduction_operation.run(reduction_arguments)
self.timer.stop_and_wait()
runtime = self.timer.duration(self.iterations)
# free memory and clear buffers
del arguments
if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
del reduction_arguments
assert get_allocated_size() == 0, (
"%d byte of pool memory is not released after current run"
% get_allocated_size()
)
if self.profiling:
return runtime
return passed
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
passed = True
minimum_operand_element_size = min(
DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]
)
opcode_class = operation.tile_description.math_instruction.opcode_class
if opcode_class == cutlass_bindings.OpClass.Simt:
alignment = 1
else:
alignment = 128 // minimum_operand_element_size
# int8_t gemm alignment constraints
if (
opcode_class == cutlass_bindings.OpClass.Simt
and operation.A.element == cutlass_bindings.int8
and operation.A.layout == cutlass_bindings.ColumnMajor
):
alignment_m = 4
else:
alignment_m = alignment
if (
opcode_class == cutlass_bindings.OpClass.Simt
and operation.B.element == cutlass_bindings.int8
and operation.A.layout == cutlass_bindings.RowMajor
):
alignment_n = 4
else:
alignment_n = alignment
if (
opcode_class == cutlass_bindings.OpClass.Simt
and operation.A.element == cutlass_bindings.int8
and operation.B.element == cutlass_bindings.int8
and (
operation.A.layout == cutlass_bindings.RowMajor
or operation.B.layout == cutlass_bindings.ColumnMajor
)
):
alignment_k = 4
else:
alignment_k = alignment
threadblock_k = operation.tile_description.threadblock_shape[2]
if testcase == "interleaved":
if operation.A.layout in [
cutlass_bindings.ColumnMajorInterleaved32,
cutlass_bindings.RowMajorInterleaved32,
]:
interleavedk = 32
else:
raise ValueError("Unknown layout")
# Split K mode via Python is currently only supported pre-SM90, and when stream K is not used.
# Stream K enables split-k functionality with mode `Gemm` and a non-unit batch count.
supports_split_k = operation.arch < 90 and not isinstance(
operation.swizzling_functor, cutlass_bindings.ThreadblockSwizzleStreamK
)
if testcase == "interleaved":
modes = [
cutlass_bindings.gemm.Mode.Gemm,
]
problem_size_m = [interleavedk, 512 + interleavedk]
problem_size_n = [interleavedk, 512 + interleavedk]
problem_size_k = [
interleavedk,
threadblock_k * operation.tile_description.stages + interleavedk,
]
problem_alpha = [1.0]
problem_beta = [0.0]
batch_counts = [
1,
]
elif testcase == "multistage":
modes = [
cutlass_bindings.gemm.Mode.Gemm,
]
problem_size_m = [16, 528]
problem_size_n = [16, 528]
problem_size_k = [
threadblock_k,
threadblock_k * operation.tile_description.stages
+ operation.tile_description.math_instruction.instruction_shape[2],
]
problem_alpha = [1.0]
problem_beta = [0.0]
batch_counts = [
1,
]
else: # universal
modes = [cutlass_bindings.gemm.Mode.Gemm]
batch_counts = [1, 2, 3, 5, 7]
if supports_split_k:
modes.append(cutlass_bindings.gemm.Mode.GemmSplitKParallel)
problem_size_m = [alignment_m, 512 - 3 * alignment_m]
problem_size_n = [alignment_n, 512 - 2 * alignment_n]
if operation.tile_description.stages is None:
stages_for_k_calc = 7
else:
stages_for_k_calc = operation.tile_description.stages
problem_size_k = [
alignment_k,
threadblock_k * stages_for_k_calc - alignment_k,
threadblock_k * stages_for_k_calc * 3 - alignment_k,
]
problem_alpha = [1.0]
problem_beta = [2.0]
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"), compiler_mode=compilation_mode)
for mode in modes:
for m in problem_size_m:
for n in problem_size_n:
for k in problem_size_k:
for batch_count in batch_counts:
for alpha in problem_alpha:
for beta in problem_beta:
# skip very small K problems
if testcase == "universal":
if k // batch_count < 2 * threadblock_k:
continue
problem_size = cutlass_bindings.gemm.GemmCoord(m, n, k)
if supports_split_k:
split_k_slices = batch_count
else:
split_k_slices = 1
overridden_mode = mode
if (
mode == cutlass_bindings.gemm.Mode.Gemm
and batch_count > 1
):
overridden_mode = cutlass_bindings.gemm.Mode.Batched
passed = testbed.run(
overridden_mode,
problem_size,
batch_count,
split_k_slices,
alpha,
beta,
)
(err,) = cudart.cudaDeviceSynchronize()
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError("CUDA Error %s" % str(err))
if not passed:
return False
return passed
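The `bytes` and `flops` accounting used by the launcher above can be checked in isolation: a GEMM reads A (m×k) and B (k×n), reads C (m×n) only when beta is nonzero, writes D (m×n), and performs 2·m·n·k floating-point operations. A standalone sketch, assuming fp16 operands (2 bytes per element):

```python
def gemm_bytes(m, n, k, elt_bytes=2, beta=0.0, batch_count=1):
    # A (m*k) and B (k*n) are read, D (m*n) is written; C (m*n) is also
    # read when beta is nonzero.
    total = elt_bytes * (m * k + k * n + m * n)
    if beta != 0:
        total += elt_bytes * m * n
    return total * batch_count

def gemm_flops(m, n, k, batch_count=1):
    # One multiply and one add per (m, n, k) triple.
    return 2 * m * n * k * batch_count

assert gemm_flops(128, 128, 64) == 2097152
assert gemm_bytes(128, 128, 64, elt_bytes=2, beta=1.0) == 98304
```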


@ -1,305 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
#################################################################################################
import cutlass
import cutlass_bindings
from cutlass import EpilogueScheduleSuffixes, KernelScheduleSuffixes
from cutlass.utils.datatypes import binding_opclass, binding_type
from cutlass.backend import library
from cutlass.backend.test.gemm_testbed import test_all_gemm
from cutlass.backend.utils.software import SubstituteTemplate
class Layout:
"""
Utility class to map transpose and non-transpose terminology to row- and column-major terminology
"""
T = cutlass_bindings.RowMajor
N = cutlass_bindings.ColumnMajor
class LayoutCombination:
"""
Utility class defining all combinations of row- and column-major layouts for operands to a GEMM
"""
NNN = (Layout.N, Layout.N, Layout.N)
NNT = (Layout.N, Layout.N, Layout.T)
NTN = (Layout.N, Layout.T, Layout.N)
NTT = (Layout.N, Layout.T, Layout.T)
TNN = (Layout.T, Layout.N, Layout.N)
TNT = (Layout.T, Layout.N, Layout.T)
TTN = (Layout.T, Layout.T, Layout.N)
TTT = (Layout.T, Layout.T, Layout.T)
def get_name(
layouts,
alignments,
element_output,
element_accumulator,
element_epilogue,
cluster_shape,
threadblock_shape,
stages,
element_a,
element_b,
arch,
opclass,
kernel_schedule=None,
epilogue_schedule=None,
suffix="",
):
"""
Generates a procedural name for a test case.
:param layouts: indexable container of layouts of A, B, and C operands
:param alignments: indexable container of alignments of A, B, and C operands
:param element_output: data type of the output element
:param element_accumulator: data type used in accumulation
:param element_epilogue: data type used in computing the epilogue
:param cluster_shape: indexable container of dimensions of threadblock cluster to be launched
:param threadblock_shape: indexable container of dimensions of threadblock tiles
:param stages: number of pipeline stages to use in the kernel
:type stages: int
:param element_a: data type of operand A
:param element_b: data type of operand B
:param arch: compute capability of kernel being generated
:type arch: int
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
:type opclass: cutlass_bindings.OpClass
:param kernel_schedule: kernel_schedule type
:type kernel_schedule: cutlass.KernelScheduleType
:param epilogue_schedule: epilogue_schedule type
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param suffix: additional string to add to the suffix of the name
:type suffix: str
:return: str
"""
name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"
return SubstituteTemplate(
name_format,
{
"arch": str(arch),
"eA": library.DataTypeNames[binding_type(element_a)],
"eB": library.DataTypeNames[binding_type(element_b)],
"eC": library.DataTypeNames[binding_type(element_output)],
"lA": library.ShortLayoutTypeNames[layouts[0]],
"lB": library.ShortLayoutTypeNames[layouts[1]],
"lC": library.ShortLayoutTypeNames[layouts[2]],
"opclass": library.OpcodeClassNames[binding_opclass(opclass)],
"acc": library.DataTypeNames[binding_type(element_accumulator)],
"cM": str(cluster_shape[0]),
"cN": str(cluster_shape[1]),
"cK": str(cluster_shape[2]),
"tbM": str(threadblock_shape[0]),
"tbN": str(threadblock_shape[1]),
"tbK": str(threadblock_shape[2]),
"stages": str(stages) if stages is not None else "auto",
"aA": str(alignments[0]),
"aB": str(alignments[1]),
"aC": str(alignments[2]),
"k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
"e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule],
"suffix": "" if suffix is None else suffix,
},
)
def get_name_conv2d(
arch,
conv_kind,
element,
element_accumulator,
element_output,
opclass,
threadblock_shape,
warp_count,
instruction_shape,
stages,
iterator_algorithm,
swizzle,
split_k_mode,
split_k_slices,
activation
):
"""
Generates a procedural name for a test case for conv2d
:param arch: compute capability of kernel being generated
:type arch: int
:param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad)
:type conv_kind: str
:param iterator_algorithm: the iterator algorithm applied
:type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
:param element_a: data type of operand A
:param element_b: data type of operand B
:param element_c: data type of operand C
:param element_accumulator: data type used in accumulation
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
:type opclass: cutlass_bindings.OpClass
:param threadblock_shape: indexable container of dimensions of threadblock tiles
:param stages: number of pipeline stages to use in the kernel
:type stages: int
:param stride_support: stride support of dgrad
:param alignment: int
:type alignment: int
:return: str
"""
if iterator_algorithm is None:
iterator_algorithm = "AUTO"
if swizzle is None:
swizzle = 1
name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}"
return SubstituteTemplate(
name_format,
{
"arch": str(arch),
"conv_kind": conv_kind,
"iter_alg": iterator_algorithm,
"eA": library.DataTypeNames[binding_type(element)],
"eB": library.DataTypeNames[binding_type(element)],
"eC": library.DataTypeNames[binding_type(element_output)],
"opclass": opclass,
"acc": library.DataTypeNames[binding_type(element_accumulator)],
"tbM": str(threadblock_shape[0]),
"tbN": str(threadblock_shape[1]),
"tbK": str(threadblock_shape[2]),
"wM": str(threadblock_shape[0] // warp_count[0]),
"wN": str(threadblock_shape[1] // warp_count[1]),
"wK": str(threadblock_shape[2] // warp_count[2]),
"IM": str(instruction_shape[0]),
"IN": str(instruction_shape[1]),
"IK": str(instruction_shape[2]),
"stages": str(stages),
"swizzle": str(swizzle),
"split_k_mode": split_k_mode,
"split_k_slices": str(split_k_slices),
"activation": activation
}
)
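The name generators above expand `${key}` placeholders through `SubstituteTemplate` (imported from `cutlass.backend.utils.software`). A minimal, illustrative re-implementation of that substitution scheme:

```python
import re

def substitute_template(template: str, values: dict) -> str:
    """Replace each ${key} placeholder in `template` with values[key]."""
    return re.sub(r"\$\{(\w+)\}", lambda m: values[m.group(1)], template)

name = substitute_template(
    "test_SM${arch}_Device_Conv2d_${conv_kind}",
    {"arch": "90", "conv_kind": "fprop"},
)
```

This sketch only models the `${key}` syntax used by the name formats above; the real helper may handle additional cases.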
def add_test_gemm(
cls=None,
cc=None,
element=None,
layouts=None,
alignments=None,
element_output=None,
element_accumulator=None,
cluster_shape=None,
threadblock_shape=None,
warp_count=None,
stages=None,
opclass=None,
swizzle=None,
kernel_schedule=None,
epilogue_schedule=None,
compilation_modes=['nvcc', 'nvrtc']):
"""
Create a test-running function with the given specification and set it as a method of ``cls``.
:param cls: class to which the generated method will be added
:type cls: type
:param cc: compute capability to compile for
:type cc: int
:param element: data type of A and B operands
:type element: cutlass.DataType
:param layouts: layouts of A, B, and C operands
:type layouts: list or tuple
:param alignments: alignments of A, B, and C operands
:type alignments: list or tuple
:param element_output: data type of the output element
:type element_output: cutlass.DataType
:param element_accumulator: data type used in accumulation
:type element_accumulator: cutlass.DataType
:param cluster_shape: dimensions of clusters
:type cluster_shape: list or tuple
:param threadblock_shape: dimensions of threadblock tiles
:type threadblock_shape: list or tuple
:param warp_count: warps to be launched per threadblock dimension
:type warp_count: list or tuple
:param stages: number of pipeline stages to use in the kernel
:type stages: int
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
:type opclass: cutlass.OpClass
:param swizzle: threadblock swizzling functor
:param kernel_schedule: kernel schedule to use
:type kernel_schedule: cutlass.KernelScheduleType
:param epilogue_schedule: epilogue schedule to use
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param compilation_modes: list of compilers to use in testing the kernel (options: 'nvrtc', 'nvcc')
:type compilation_modes: list
"""
for compilation_mode in compilation_modes:
# Bind the loop variable via a default argument so each generated test
# runs with its own compilation mode rather than the loop's final value
def run(self, compilation_mode=compilation_mode):
"""
Dynamically-generated function that constructs a GEMM operation and verifies it against
multiple test cases.
"""
element_A = element
element_B = element
layout_A, layout_B, layout_C = layouts
alignment_A, alignment_B, alignment_C = alignments
plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
element_C=element_output, element_D=element_output,
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
element_accumulator=element_accumulator,
kernel_cc=cc)
plan.opclass = opclass
if swizzle is not None:
plan.swizzling_functor = swizzle
td = plan.tile_descriptions()[0]
td.threadblock_shape = threadblock_shape
td.stages = stages
if warp_count is not None:
td.warp_count = warp_count
td.cluster_shape = cluster_shape
op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))
element_epilogue = element_accumulator
name = get_name(
layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')
setattr(cls, name, run)
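The pattern used here, generating one test method per configuration and attaching it to a `unittest` class with `setattr`, can be sketched in isolation (class and method names below are hypothetical). Note the default argument, which pins the loop variable at definition time:

```python
import unittest

class GemmTests(unittest.TestCase):
    pass

for mode in ("nvcc", "nvrtc"):
    def run(self, mode=mode):  # default argument binds `mode` now, not at call time
        self.assertIn(mode, ("nvcc", "nvrtc"))
    setattr(GemmTests, f"test_gemm_{mode}", run)
```

Without the `mode=mode` default, every generated method would see the final value of the loop variable when it is eventually called.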


@@ -32,7 +32,6 @@
from cutlass.backend.utils.datatypes import *
from cutlass.backend.utils.device import check_cuda_errors, device_cc
from cutlass.backend.utils.reference_model import ReferenceModule
from cutlass.backend.utils.software import (
CheckPackages,
SubstituteTemplate,


@@ -34,8 +34,9 @@
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""
import cutlass_bindings
from cuda import cuda
from cutlass import DataType
from cutlass.backend.utils.software import CheckPackages
numpy_available = CheckPackages().check_numpy()
@@ -43,16 +44,16 @@ if numpy_available:
import numpy as np
numpy_to_cutlass_dict = {
np.float16: cutlass_bindings.float16,
np.float32: cutlass_bindings.float32,
np.float64: cutlass_bindings.float64,
np.int8: cutlass_bindings.int8,
np.int32: cutlass_bindings.int32,
np.dtype('float16'): cutlass_bindings.float16,
np.dtype('float32'): cutlass_bindings.float32,
np.dtype('float64'): cutlass_bindings.float64,
np.dtype('int8'): cutlass_bindings.int8,
np.dtype('int32'): cutlass_bindings.int32,
np.float16: DataType.f16,
np.float32: DataType.f32,
np.float64: DataType.f64,
np.int8: DataType.s8,
np.int32: DataType.s32,
np.dtype('float16'): DataType.f16,
np.dtype('float32'): DataType.f32,
np.dtype('float64'): DataType.f64,
np.dtype('int8'): DataType.s8,
np.dtype('int32'): DataType.s32,
}
@@ -67,9 +68,9 @@ if cupy_available:
import cupy as cp
cupy_to_cutlass_dict = {
cp.float16: cutlass_bindings.float16,
cp.float32: cutlass_bindings.float32,
cp.float64: cutlass_bindings.float64,
cp.float16: DataType.f16,
cp.float32: DataType.f32,
cp.float64: DataType.f64,
}
@@ -84,12 +85,12 @@ if torch_available:
import torch
torch_to_cutlass_dict = {
torch.half: cutlass_bindings.float16,
torch.float16: cutlass_bindings.float16,
torch.float: cutlass_bindings.float32,
torch.float32: cutlass_bindings.float32,
torch.double: cutlass_bindings.float64,
torch.float64: cutlass_bindings.float64,
torch.half: DataType.f16,
torch.float16: DataType.f16,
torch.float: DataType.f32,
torch.float32: DataType.f32,
torch.double: DataType.f64,
torch.float64: DataType.f64,
}
@@ -102,7 +103,7 @@ try:
import bfloat16
bfloat16_available = True
numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = cutlass_bindings.bfloat16
numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = DataType.bf16
except ImportError:
bfloat16_available = False
@@ -110,7 +111,7 @@ except ImportError:
def bfloat16_to_cutlass(inp):
if bfloat16_available:
if inp == bfloat16.bfloat16:
return cutlass_bindings.bfloat16
return DataType.bf16
def to_cutlass(inp):
@@ -127,3 +128,29 @@ def to_cutlass(inp):
raise Exception(
"No available conversion from type {} to a CUTLASS type.".format(inp)
)
def to_device_ptr(tensor) -> cuda.CUdeviceptr:
"""
Converts a tensor to a CUdeviceptr
:param tensor: tensor to convert
:type tensor: np.ndarray | torch.Tensor | cp.ndarray | int
:return: device pointer
:rtype: cuda.CUdeviceptr
"""
if isinstance(tensor, np.ndarray):
ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
elif torch_available and isinstance(tensor, torch.Tensor):
ptr = cuda.CUdeviceptr(tensor.data_ptr())
elif cupy_available and isinstance(tensor, cp.ndarray):
ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
elif isinstance(tensor, cuda.CUdeviceptr):
ptr = tensor
elif isinstance(tensor, int):
ptr = cuda.CUdeviceptr(tensor)
else:
raise NotImplementedError(tensor)
return ptr
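The NumPy branch of `to_device_ptr` reads the raw address out of `__array_interface__`. For a host-resident array, that is the same address NumPy reports via `ndarray.ctypes.data`, which this small sketch (not using the CUDA bindings) demonstrates:

```python
import numpy as np

a = np.arange(4, dtype=np.float32)
# The first element of the "data" tuple is the buffer address as a Python int
ptr = a.__array_interface__["data"][0]
assert ptr == a.ctypes.data  # same underlying address
```

In `to_device_ptr` this integer is then wrapped in a `cuda.CUdeviceptr`; for device-resident tensors (PyTorch, CuPy) the equivalent accessors are `data_ptr()` and `data.ptr`.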


@@ -1,317 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from typing import Union
from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np
from cutlass.backend.library import TensorDescription
from cutlass.backend.utils.software import CheckPackages
torch_available = CheckPackages().check_torch()
if torch_available:
import torch
class ReferenceModule:
def __init__(
self, A: TensorDescription, B: TensorDescription, C: TensorDescription
) -> None:
self.layout_A = A.layout
self.layout_B = B.layout
self.layout_C = C.layout
def run(
self,
A: np.ndarray,
B: np.ndarray,
C: np.ndarray,
problem_size: cutlass_bindings.gemm.GemmCoord,
alpha: float = 1.0,
beta: float = 0.0,
bias=False,
batch=1,
):
"""
Compute the reference result on CPU
Args:
A: dense operator with shape (M, K) in row-major and (K, M) in column-major
B: dense operator with shape (K, N) in row-major and (N, K) in column-major
C: dense operator with shape (M, N) in row-major and (N, M) in column-major
"""
M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
if isinstance(A, np.ndarray):
if self.layout_A == cutlass_bindings.RowMajor:
A_row = np.reshape(A, newshape=(batch, M, K))
else:
A_col = np.reshape(A, newshape=(batch, K, M))
A_row = np.transpose(A_col, axes=(0, 2, 1))
if self.layout_B == cutlass_bindings.RowMajor:
B_row = np.reshape(B, newshape=(batch, K, N))
else:
B_col = np.reshape(B, newshape=(batch, N, K))
B_row = np.transpose(B_col, axes=(0, 2, 1))
if self.layout_C == cutlass_bindings.RowMajor:
if bias:
C_row = np.reshape(C, newshape=(batch, 1, N))
else:
C_row = np.reshape(C, newshape=(batch, M, N))
else:
if bias:
C_row = np.reshape(C, newshape=(batch, M, 1))
else:
C_col = np.reshape(C, newshape=(batch, N, M))
C_row = np.transpose(C_col, axes=(0, 2, 1))
if A_row.dtype == bfloat16:
# numpy's einsum doesn't support bfloat16
out_row = (
np.einsum(
"bik,bkj->bij",
A_row.astype(np.float32),
B_row.astype(np.float32),
)
* alpha
+ C_row * beta
)
out_row = out_row.astype(C_row.dtype)
else:
out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta
if self.layout_C == cutlass_bindings.ColumnMajor:
out = np.transpose(out_row, axes=(0, 2, 1))
else:
out = out_row
return out.ravel()
elif isinstance(A, torch.Tensor):
if self.layout_A == cutlass_bindings.RowMajor:
A_row = A.view((M, K))
else:
A_col = A.view((K, M))
A_row = torch.permute(A_col, (1, 0))
if self.layout_B == cutlass_bindings.RowMajor:
B_row = B.view((K, N))
else:
B_col = B.view((N, K))
B_row = torch.permute(B_col, (1, 0))
if self.layout_C == cutlass_bindings.RowMajor:
C_row = C.view((M, N))
else:
C_col = C.view((N, M))
C_row = torch.permute(C_col, (1, 0))
out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta
if self.layout_C == cutlass_bindings.ColumnMajor:
out = torch.permute(out_row, (1, 0))
else:
out = out_row
return torch.flatten(out)
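The column-major handling above can be illustrated on its own: a column-major (K, M) buffer is reshaped and transposed into row-major (M, K) before the batched-einsum reference product. A small NumPy sketch (shapes are arbitrary example values):

```python
import numpy as np

M, N, K, batch = 2, 3, 4, 1
# A stored column-major: the flat buffer reshapes to (batch, K, M),
# and a transpose of the last two axes recovers row-major (batch, M, K)
A_col = np.arange(batch * K * M, dtype=np.float32).reshape(batch, K, M)
A_row = np.transpose(A_col, (0, 2, 1))
B_row = np.ones((batch, K, N), dtype=np.float32)
# Batched reference GEMM, matching the einsum used in ReferenceModule.run
out = np.einsum("bik,bkj->bij", A_row, B_row)
```

The same `"bik,bkj->bij"` contraction is what the reference module scales by `alpha` and combines with `beta * C`.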
#####################################################################################################
# Conv2d
#####################################################################################################
if torch_available:
import torch
class Conv2dReferenceModule:
def __init__(
self,
A: TensorDescription,
B: TensorDescription,
C: TensorDescription,
kind: cutlass_bindings.conv.Operator,
) -> None:
self.layout_A = A.layout
self.layout_B = B.layout
self.layout_C = C.layout
self.kind = kind
def run(
self,
A: Union[np.ndarray, torch.Tensor],
B: Union[np.ndarray, torch.Tensor],
C: Union[np.ndarray, torch.Tensor],
problem_size,
alpha=1.0,
beta=0.0,
bias=False,
) -> np.ndarray:
"""
Compute the reference result on CPU
"""
n = problem_size.N
h = problem_size.H
w = problem_size.W
c = problem_size.C
k = problem_size.K
r = problem_size.R
s = problem_size.S
p = problem_size.P
q = problem_size.Q
stride_h = problem_size.stride_h
stride_w = problem_size.stride_w
pad_h = problem_size.pad_h
pad_w = problem_size.pad_w
dilation_h = problem_size.dilation_h
dilation_w = problem_size.dilation_w
groups = problem_size.groups
if isinstance(A, np.ndarray):
# the pytorch activation layout is NCHW
# weight layout is Cout Cin Kh Kw (also NCHW)
if self.layout_A == cutlass_bindings.TensorNHWC:
A_nhwc = np.reshape(A, newshape=(n, h, w, c))
A_torch_nhwc = torch.from_numpy(A_nhwc).to("cuda")
A_torch_nchw = torch.permute(A_torch_nhwc, (0, 3, 1, 2))
if self.layout_B == cutlass_bindings.TensorNHWC:
B_nhwc = np.reshape(B, newshape=(k, r, s, c))
B_torch_nhwc = torch.from_numpy(B_nhwc).to("cuda")
B_torch_nchw = torch.permute(B_torch_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass_bindings.TensorNHWC:
C_nhwc = np.reshape(C, newshape=(n, p, q, k))
C_torch_nhwc = torch.from_numpy(C_nhwc).to("cuda")
C_torch_nchw = torch.permute(C_torch_nhwc, (0, 3, 1, 2))
elif isinstance(A, torch.Tensor):
if self.kind == cutlass_bindings.conv.Operator.wgrad:
if self.layout_A == cutlass_bindings.TensorNHWC:
A_nhwc = A.view((n, p, q, k))
A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))
if self.layout_B == cutlass_bindings.TensorNHWC:
B_nhwc = B.view((n, h, w, c))
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass_bindings.TensorNHWC:
if bias:
C_nhwc = C.view((1, 1, 1, c))
else:
C_nhwc = C.view((k, r, s, c))
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
elif self.kind == cutlass_bindings.conv.Operator.dgrad:
if self.layout_A == cutlass_bindings.TensorNHWC:
A_nhwc = A.view((n, p, q, k))
A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))
if self.layout_B == cutlass_bindings.TensorNHWC:
B_nhwc = B.view((k, r, s, c))
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass_bindings.TensorNHWC:
if bias:
C_nhwc = C.view((1, 1, 1, c))
else:
C_nhwc = C.view((n, h, w, c))
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
else:
if self.layout_A == cutlass_bindings.TensorNHWC:
A_nhwc = A.view((n, h, w, c))
A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))
if self.layout_B == cutlass_bindings.TensorNHWC:
B_nhwc = B.view((k, r, s, c))
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass_bindings.TensorNHWC:
if bias:
C_nhwc = C.view((1, 1, 1, k))
else:
C_nhwc = C.view((n, p, q, k))
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
if self.kind == cutlass_bindings.conv.Operator.fprop:
D_torch_nchw = (
alpha
* torch.nn.functional.conv2d(
A_torch_nchw,
B_torch_nchw,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(dilation_h, dilation_w),
groups=groups,
)
+ beta * C_torch_nchw
)
elif self.kind == cutlass_bindings.conv.Operator.dgrad:
D_torch_nchw = (
alpha
* torch.nn.grad.conv2d_input(
(n, c, h, w),
B_torch_nchw,
A_torch_nchw,
padding=(pad_h, pad_w),
stride=(stride_h, stride_w),
).to(torch.float32)
+ beta * C_torch_nchw
)
elif self.kind == cutlass_bindings.conv.Operator.wgrad:
D_torch_nchw = (
alpha
* torch.nn.grad.conv2d_weight(
B_torch_nchw,
(k, c, r, s),
A_torch_nchw,
padding=(pad_h, pad_w),
stride=(stride_h, stride_w),
).to(torch.float32)
+ beta * C_torch_nchw
)
if self.layout_C == cutlass_bindings.TensorNHWC:
if isinstance(A, np.ndarray):
D_torch_out = (
torch.permute(D_torch_nchw, (0, 2, 3, 1)).detach().cpu().numpy()
)
elif isinstance(A, torch.Tensor):
D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1))
return D_torch_out.flatten()
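The conv2d reference repeatedly permutes NHWC buffers to NCHW (the layout `torch.nn.functional.conv2d` expects) and back. The axis permutations themselves are just `(0, 3, 1, 2)` and its inverse `(0, 2, 3, 1)`, shown here with NumPy for clarity:

```python
import numpy as np

n, h, w, c = 1, 5, 7, 3
x_nhwc = np.zeros((n, h, w, c))
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))  # NHWC -> NCHW
x_back = np.transpose(x_nchw, (0, 2, 3, 1))  # NCHW -> NHWC (inverse)
```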


@@ -1,3 +1,6 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -85,10 +88,7 @@ def SubstituteTemplate(template, values):
return text
# this._device_sm_count = None
def device_sm_count():
# Query the number of SMs, if needed
# if this._device_sm_count is None:
from cuda import cuda
_device = 0


@@ -1,75 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief In-memory compiled artifact cache
*/
#include <pybind11/pybind11.h>
#include <string>
#include <unordered_map>
namespace py = pybind11;
namespace cutlass {
struct CompileCache {
public:
CompileCache() = default;
~CompileCache() = default;
using Cache = std::unordered_map<std::string, py::object>;
/// Return the cached compiled kernel, or py::none() if it has not been compiled
py::object at(const std::string &kernel) {
auto item = cache_.find(kernel);
if (item != cache_.end()) {
return item->second;
}
return py::none();
}
/// Insert a newly compiled kernel for the given configuration
void insert(const std::string &kernel, const py::object &compiled_kernel){
cache_.emplace(kernel, compiled_kernel);
}
int64_t size() const { return cache_.size(); }
/// Clear the cache
void clear() { cache_.clear(); }
private:
Cache cache_;
};
} // namespace cutlass
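Viewed from Python, the bound `CompileCache` behaves like a dictionary with a `None`-on-miss accessor. A pure-Python sketch of the same interface (illustrative only, not the binding itself):

```python
class CompileCache:
    """Dict-backed cache keyed by a kernel's configuration string."""

    def __init__(self):
        self._cache = {}

    def at(self, kernel):
        # Mirrors the C++ at(): returns the compiled kernel, or None on a miss
        return self._cache.get(kernel)

    def insert(self, kernel, compiled_kernel):
        self._cache[kernel] = compiled_kernel

    def size(self):
        return len(self._cache)

    def clear(self):
        self._cache.clear()
```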


@@ -1,182 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief binding CUTLASS C++ APIs to Python
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "builtin_types.h"
#include "device_launch_parameters.h"
#include "stddef.h"
#include "cutlass/cutlass.h"
#include "include/conv/convolution.h"
#include "include/gemm/gemm.h"
#include "include/types.h"
#include "include/layout/layout.h"
#include "include/tensor_coord.h"
#include "include/arch.h"
#include "include/tensor_ref_view.h"
#include "include/swizzling.h"
#include "test/conv/convolution.h"
#include "test/gemm/gemm.h"
// Data Types
#include "library.h"
// compiler
#include "compiler.h"
namespace py = pybind11;
PYBIND11_MODULE(cutlass_bindings, m) {
// module doc
m.doc() = "CUTLASS C++ binding";
//
// Bind data type
//
bind_cutlass_types(m);
//
// Bind layout
//
bind_layout(m);
//
// Bind tensor coord
//
bind_tensor_coord(m);
//
// Bind tensor ref
//
bind_tensor_refs_and_views(m);
//
// Bind opcode
//
bind_opcode(m);
//
// Bind convolution
//
py::module_ conv_submodule = m.def_submodule("conv");
bind_convolution(conv_submodule);
//
// Bind gemm
//
py::module_ gemm_submodule = m.def_submodule("gemm");
bind_gemm(gemm_submodule);
//
// Bind swizzling
//
bind_threadblock_swizzle(m);
//
// Bind test units
//
py::module_ test = m.def_submodule("test");
py::module_ test_conv = test.def_submodule("conv");
bind_convolution_test(test_conv);
py::module_ test_gemm = test.def_submodule("gemm");
bind_gemm_test(test_gemm);
// data types
py::enum_<cutlass::DataType>(m, "dtype")
.value("b1", cutlass::DataType::kB1)
.value("u2", cutlass::DataType::kU2)
.value("u4", cutlass::DataType::kU4)
.value("u8", cutlass::DataType::kU8)
.value("u16", cutlass::DataType::kU16)
.value("u32", cutlass::DataType::kU32)
.value("u64", cutlass::DataType::kU64)
.value("s2", cutlass::DataType::kS2)
.value("s4", cutlass::DataType::kS4)
.value("s16", cutlass::DataType::kS16)
.value("s64", cutlass::DataType::kS64)
.value("cf16", cutlass::DataType::kCF16)
.value("cbf16", cutlass::DataType::kCBF16)
.value("cf32", cutlass::DataType::kCF32)
.value("ctf32", cutlass::DataType::kCTF32)
.value("cf64", cutlass::DataType::kCF64)
.value("cs2", cutlass::DataType::kCS2)
.value("cs4", cutlass::DataType::kCS4)
.value("cs8", cutlass::DataType::kCS8)
.value("cs16", cutlass::DataType::kCS16)
.value("cs32", cutlass::DataType::kCS32)
.value("cs64", cutlass::DataType::kCS64)
.value("cu2", cutlass::DataType::kCU2)
.value("cu4", cutlass::DataType::kCU4)
.value("cu8", cutlass::DataType::kCU8)
.value("cu16", cutlass::DataType::kCU16)
.value("cu32", cutlass::DataType::kCU32)
.value("cu64", cutlass::DataType::kCU64)
.value("invalid", cutlass::DataType::kInvalid);
// layout types
py::enum_<cutlass::LayoutType>(m, "layout")
.value("ColumnMajorInterleaved2", cutlass::LayoutType::kColumnMajorInterleaved2)
.value("RowMajorInterleaved2", cutlass::LayoutType::kRowMajorInterleaved2)
.value("ColumnMajorInterleaved64", cutlass::LayoutType::kColumnMajorInterleaved64)
.value("RowMajorInterleaved64", cutlass::LayoutType::kRowMajorInterleaved64)
.value("TensorNDHWC", cutlass::LayoutType::kTensorNDHWC)
.value("TensorNCHW", cutlass::LayoutType::kTensorNCHW)
.value("TensorNGHWC", cutlass::LayoutType::kTensorNGHWC)
.value("TensorNC64HW64", cutlass::LayoutType::kTensorNC64HW64)
.value("TensorC64RSK64", cutlass::LayoutType::kTensorC64RSK64);
// transform types
py::enum_<cutlass::ComplexTransform>(m, "complex_transform")
.value("none", cutlass::ComplexTransform::kNone)
.value("conj", cutlass::ComplexTransform::kConjugate);
//
// Compiler
//
py::class_<cutlass::CompileCache>(m, "CompileCache")
.def(py::init<>())
.def("at", &cutlass::CompileCache::at)
.def("insert", &cutlass::CompileCache::insert)
.def("size", &cutlass::CompileCache::size)
.def("clear", &cutlass::CompileCache::clear);
}


@@ -1,59 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind opcode classes to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/arch/mma.h"
namespace py = pybind11;
namespace cutlass {
enum class OpcodeClass {
kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp
};
}
void bind_opcode(py::module &m) {
py::enum_<cutlass::OpcodeClass>(m, "OpClass",
R"pbdoc(classification of math operators)pbdoc")
.value("Simt", cutlass::OpcodeClass::kSimt,
R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc")
.value("TensorOp", cutlass::OpcodeClass::kTensorOp,
R"pbdoc(Tag classifying operators as Tensor Core operations)pbdoc")
.value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp,
R"pbdoc(Tag classifying operators as WMMA Tensor Core operations)pbdoc")
.value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp,
R"pbdoc(Tag classifying operators as sparseTensor Core operations)pbdoc");
}


@@ -1,102 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind Convolution problem sizes to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/conv/conv2d_problem_size.h"
namespace py = pybind11;
void bind_conv_problem_size(py::module &m) {
//
// Conv2d Problem Size:
// include/cutlass/conv/conv2d_problem_size.h
//
py::class_<cutlass::conv::Conv2dProblemSize>(m, "Conv2dProblemSize")
// constructors
.def(py::init<int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, cutlass::conv::Mode, int, int>())
.def(py::init<cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::MatrixCoord, cutlass::MatrixCoord, cutlass::conv::Mode, int, int>())
// attribute accessors
.def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N)
.def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H)
.def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W)
.def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C)
.def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P)
.def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q)
.def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K)
.def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R)
.def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S)
.def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h)
.def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w)
.def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h)
.def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w)
.def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h)
.def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w)
.def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode)
.def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices)
.def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups)
// functions
.def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices)
.def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent)
.def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent)
.def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent)
.def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size)
.def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size)
.def("output_size", &cutlass::conv::Conv2dProblemSize::output_size);
// Get tensor size
m.def("implicit_gemm_tensor_a_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_a_size));
m.def("implicit_gemm_tensor_b_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_b_size));
m.def("implicit_gemm_tensor_c_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_c_size));
// Get tensor extent
m.def("implicit_gemm_tensor_a_extent",
py::overload_cast<
cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
>(&cutlass::conv::implicit_gemm_tensor_a_extent));
m.def("implicit_gemm_tensor_b_extent",
py::overload_cast<
cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
>(&cutlass::conv::implicit_gemm_tensor_b_extent));
m.def("implicit_gemm_tensor_c_extent",
py::overload_cast<
cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
>(&cutlass::conv::implicit_gemm_tensor_c_extent));
m.def("implicit_gemm_problem_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize &>(&cutlass::conv::implicit_gemm_problem_size));
}
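The `output_extent` accessor and `implicit_gemm_problem_size` together describe how a conv2d problem maps onto a GEMM. The arithmetic can be sketched in pure Python (a sketch only: the output-extent computation uses the standard cross-correlation size formula, and the GEMM mapping shown is the fprop case, where M = N*P*Q, N = K, K = R*S*C; function names are illustrative, not the CUTLASS API):

```python
def conv2d_output_extent(h, w, r, s, pad_h, pad_w,
                         stride_h, stride_w, dil_h, dil_w):
    # Standard cross-correlation output size formula for P (height) and Q (width)
    p = (h + 2 * pad_h - dil_h * (r - 1) - 1) // stride_h + 1
    q = (w + 2 * pad_w - dil_w * (s - 1) - 1) // stride_w + 1
    return p, q

def implicit_gemm_fprop_problem_size(n, p, q, k, r, s, c):
    # Fprop as implicit GEMM: one GEMM row per output pixel (N*P*Q),
    # one GEMM column per filter (K), reduction over the filter window (R*S*C)
    return n * p * q, k, r * s * c
```

For example, an 8x8 input with a 3x3 filter, padding 1, stride 1, dilation 1 keeps an 8x8 output extent.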


@ -1,91 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Bind convolution-related enum types to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "conv_problem_size.h"
#include "host.h"
#include "cutlass/conv/convolution.h"
namespace py = pybind11;
void bind_convolution(py::module &m) {
//
// Enumerate types
// cutlass/include/cutlass/conv/convolution.h
//
/// Convolutional operator
py::enum_<cutlass::conv::Operator>(m, "Operator", R"pbdoc(Convolutional operator)pbdoc")
.value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation")
.value("dgrad", cutlass::conv::Operator::kDgrad, "Activation grad")
.value("wgrad", cutlass::conv::Operator::kWgrad, "Weight grad");
/// Distinguishes convolution from cross correlation
py::enum_<cutlass::conv::Mode>(m, "Mode")
.value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation)
.value("convolution", cutlass::conv::Mode::kConvolution);
/// Selects among several implementation variants trading off performance with simplicity
py::enum_<cutlass::conv::IteratorAlgorithm>(m, "IteratorAlgorithm",
R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc")
.value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(functionally correct in all cases but lower performance)pbdoc")
.value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc")
.value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc")
.value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc");
/// Distinguishes among partial specializations that accelerate certain problems where convolution
/// stride is unit.
py::enum_<cutlass::conv::StrideSupport>(m, "StrideSupport",
R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution
stride is unit.)pbdoc")
.value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc")
.value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc");
/// Identifies split-K mode
py::enum_<cutlass::conv::SplitKMode>(m, "SplitKMode")
.value("None", cutlass::conv::SplitKMode::kNone)
.value("Serial", cutlass::conv::SplitKMode::kSerial)
.value("Parallel", cutlass::conv::SplitKMode::kParallel);
// Conv problem sizes
bind_conv_problem_size(m);
//
// host helper functions
//
py::module_ host_submodule = m.def_submodule("host");
bind_conv_host_helper(host_submodule);
}
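The `Mode` enum bound above distinguishes cross-correlation from true convolution; the only mathematical difference is whether the filter is flipped before it is slid across the input. A minimal 1-D illustration (pure Python, illustrative names, not the CUTLASS API):

```python
def correlate1d(x, w):
    # kCrossCorrelation: slide the filter over the input as-is
    return [sum(x[i + j] * w[j] for j in range(len(w)))
            for i in range(len(x) - len(w) + 1)]

def convolve1d(x, w):
    # kConvolution: identical, except the filter is reversed first
    return correlate1d(x, w[::-1])
```

Most deep-learning "convolutions" are actually cross-correlations in this sense.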


@ -1,54 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Bind conv host helpers to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
void bind_conv_host_helper(py::module &m) {
/// reorder operand B for interleaved layout
m.def("reorder_convK", [](
cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> dest,
cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> src,
cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize & problem_size) {
cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size);
cutlass::reorder_convK<32>(dest, src, implicit_problem_size);
});
}


@ -1,222 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A generic wrapper around an epilogue visitor operation
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "epilogue_visitor_op/visitor_op_linear_combination.h"
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
#include "epilogue_visitor_op/visitor_op_accumulator.h"
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
#include "epilogue_visitor_op/visitor_op_unary.h"
#include "epilogue_visitor_op/visitor_op_binary.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic Epilogue Visitor.
template <
typename OutputOp_
>
class EpilogueVisitorGeneric {
public:
using OutputOp = OutputOp_;
using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
using ElementOutput = typename OutputOp::ElementOutput;
using OutputTileIterator = typename OutputOp::OutputTileIterator;
static int const kIterations = OutputTileIterator::kIterations;
///
/// End Epilogue Tree
///
/// An additional SMEM buffer is not required in the broadcast epilogue visitor
struct SharedStorage {
typename OutputOp::SharedStorage output_smem;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
public:
/// Argument structure
struct Arguments {
typename OutputOp::Arguments output_op_args;
//
// Methods
//
Arguments() { }
Arguments(
typename OutputOp::Arguments output_op_args
):
output_op_args(output_op_args)
{
}
};
struct Params {
typename OutputOp::Params output_op_params;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
output_op_params(args.output_op_args)
{
}
};
private:
OutputOp output_op;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueVisitorGeneric(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord threadblock_offset,
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
MatrixCoord problem_size
):
output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
{ }
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
output_op.set_batch_index(batch_idx);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
output_op.begin_epilogue();
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
output_op.begin_step(step_idx);
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
output_op.begin_row(row_idx);
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum) {
output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
output_op.end_row(row_idx);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
output_op.end_step(step_idx);
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
output_op.end_epilogue();
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
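The callback protocol the epilogue drives a visitor through (begin_epilogue, then per-step and per-row hooks around each `visit`, then end_epilogue) can be sketched with a toy recording visitor. This is a pure-Python sketch with hypothetical names, only meant to show the call ordering, not the CUTLASS implementation:

```python
class RecordingVisitor:
    """Toy stand-in for an epilogue visitor: records the callback order."""
    def __init__(self):
        self.trace = []
    def begin_epilogue(self):      self.trace.append("begin_epilogue")
    def begin_step(self, s):       self.trace.append(f"begin_step({s})")
    def begin_row(self, r):        self.trace.append(f"begin_row({r})")
    def visit(self, it, r, c, f, accum):
        self.trace.append("visit")
        return accum                      # pass the accumulator through
    def end_row(self, r):          self.trace.append(f"end_row({r})")
    def end_step(self, s):         self.trace.append(f"end_step({s})")
    def end_epilogue(self):        self.trace.append("end_epilogue")

def run_epilogue(visitor, steps, rows, cols):
    # Drive the visitor in the same nesting order the epilogue uses
    visitor.begin_epilogue()
    for s in range(steps):
        visitor.begin_step(s)
        for r in range(rows):
            visitor.begin_row(r)
            for c in range(cols):
                visitor.visit(0, r, c, 0, 1.0)
            visitor.end_row(r)
        visitor.end_step(s)
    visitor.end_epilogue()
```

`EpilogueVisitorGeneric` simply forwards each of these hooks to its single `OutputOp`.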


@ -1,84 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Binary ops used by epilogue visitors
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Elementwise vector addition
template <typename T, int N>
struct VectorAdd {
struct Arguments {
int tmp;
CUTLASS_HOST_DEVICE
Arguments():tmp(0){ }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
struct Params {
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
CUTLASS_HOST_DEVICE
VectorAdd(
Params const &params
) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
cutlass::plus<Array<T, N>> add_op;
return add_op(lhs, rhs);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -1,233 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Unary ops used by epilogue visitors
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Scalar multiplication
template <typename T, int N>
struct Mult {
struct Arguments {
T alpha;
CUTLASS_HOST_DEVICE
Arguments():alpha(T(1.0)){ }
CUTLASS_HOST_DEVICE
Arguments(T alpha): alpha(alpha) { }
};
struct Params {
T alpha; ///< scales accumulators
CUTLASS_HOST_DEVICE
Params():alpha(T(1.0)){ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): alpha(args.alpha) { }
};
T alpha_;
CUTLASS_HOST_DEVICE
Mult(
Params const &params
):
alpha_(params.alpha)
{ }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &source) const {
cutlass::multiplies<Array<T, N>> multiply_op;
return multiply_op(source, alpha_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return alpha_ != T(0);
}
};
/// ReLU
template <typename T, int N>
struct ReLUVisitor {
struct Arguments {
T threshold;
CUTLASS_HOST_DEVICE
Arguments():threshold(T(0.0)) { }
CUTLASS_HOST_DEVICE
Arguments(T threshold): threshold(threshold) { }
};
struct Params {
T threshold;
CUTLASS_HOST_DEVICE
Params():threshold(T(0.0)) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): threshold(args.threshold) { }
};
T threshold_;
CUTLASS_HOST_DEVICE
ReLUVisitor(Params const &params):
threshold_(params.threshold) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
maximum<Array<T, N>> mx;
return mx(frag, threshold_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/// leakyReLU
template <typename T, int N>
struct LeakyReLUVisitor {
struct Arguments {
T leaky_alpha;
CUTLASS_HOST_DEVICE
Arguments():leaky_alpha(T(0.0)) { }
CUTLASS_HOST_DEVICE
Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
};
struct Params {
T leaky_alpha;
CUTLASS_HOST_DEVICE
Params():leaky_alpha(T(0.0)) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
};
T leaky_alpha_;
CUTLASS_HOST_DEVICE
LeakyReLUVisitor(Params const &params):
leaky_alpha_(params.leaky_alpha) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
return leaky_op(frag, leaky_alpha_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/// Tanh
template <typename T, int N>
struct TanhVisitor {
/// Argument
struct Arguments {
// a placeholder argument to ensure correctness of ctypes
int tmp;
CUTLASS_HOST_DEVICE
Arguments(): tmp(0) { };
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { };
};
/// Param
struct Params {
CUTLASS_HOST_DEVICE
Params(){ };
Params(Arguments const &args) { }
};
/// Constructor
CUTLASS_HOST_DEVICE
TanhVisitor(Params const &params) { }
// scalar operator
CUTLASS_HOST_DEVICE
T tanh_op(T const &scalar) const {
return fast_tanh(scalar);
}
/// vector operator
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
Array<T, N> y;
CUTLASS_PRAGMA_UNROLL
for (int i=0; i < N; ++i) {
y[i] = tanh_op(frag[i]);
}
return y;
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
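The unary visitors above are plain elementwise maps over a fragment, each paired with a `guard()` that tells the epilogue whether the op can be skipped. Their math can be sketched in pure Python (illustrative helper names, not the CUTLASS API; note CUTLASS's `TanhVisitor` uses `fast_tanh` on device, approximated here with `math.tanh`):

```python
import math

def relu(frag, threshold=0.0):
    # ReLUVisitor: elementwise maximum against the threshold
    return [max(x, threshold) for x in frag]

def leaky_relu(frag, alpha):
    # LeakyReLUVisitor: x for positive inputs, alpha * x otherwise
    return [x if x > 0 else alpha * x for x in frag]

def tanh_visitor(frag):
    # TanhVisitor: elementwise tanh applied across the fragment
    return [math.tanh(x) for x in frag]
```

The `Mult` visitor's `guard()` returning `alpha != 0` lets the epilogue skip work when scaling by zero; the activation visitors always return `True` since they must run.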


@ -1,148 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Epilogue visitor op that returns the accumulator
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following Computation
///
/// ElementAccumulator accum;
/// return accum;
///
/// It can only appear as a leaf node of the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
int kElementsPerAccess_ ///< Number of elements computed per operation
>
class VisitorOpAccumulator{
public:
using ElementAccumulator = ElementAccumulator_;
static int const kElementsPerAccess = kElementsPerAccess_;
/// Fragment type for Accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type returned by this visitor
using VisitAccessType = AccumulatorAccessType;
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
    // Note: ctypes has trouble with empty argument structs, so a placeholder member is included
int tmp;
CUTLASS_HOST_DEVICE
Arguments() { }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
/// Parameter structure
struct Params {
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpAccumulator(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) { }
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) { }
CUTLASS_DEVICE
void begin_row(int row_idx) { }
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
return accum;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -1,245 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Epilogue visitor op that combines two child visitors with a binary op
*/
#pragma once
#include "cutlass/cutlass.h"
#include "binary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute alpha;
/// ElementCompute beta;
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
/// return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename VisitorA_, ///< Child node A
typename VisitorB_, ///< Child node B
template<typename T, int N> typename BinaryOp_
>
class VisitorOpBinary{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using VisitorA = VisitorA_;
using VisitorB = VisitorB_;
/// Fragment type returned from VisitorA.visit
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
using ElementA = typename VisitAccessTypeA::Element;
/// Fragment type returned from VisitorB.visit
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
using ElementB = typename VisitAccessTypeB::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess does not match Visitor A");
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess does not match Visitor B");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename VisitorA::SharedStorage storage_a;
typename VisitorB::SharedStorage storage_b;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
typename BinaryOp::Arguments binary_arg;
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():binary_arg() { }
CUTLASS_HOST_DEVICE
Arguments(
typename BinaryOp::Arguments binary_arg,
typename VisitorA::Arguments visitor_a_arg,
typename VisitorB::Arguments visitor_b_arg
):
binary_arg(binary_arg),
visitor_a_arg(visitor_a_arg),
visitor_b_arg(visitor_b_arg)
{ }
};
/// Parameter structure
struct Params {
typename BinaryOp::Params binary_param;
typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a
typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
binary_param(args.binary_arg),
visitor_a_param(args.visitor_a_arg),
visitor_b_param(args.visitor_b_arg)
{ }
};
private:
//
// Data members
//
BinaryOp binary_op;
VisitorA visitor_a_op;
VisitorB visitor_b_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpBinary(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
binary_op(params.binary_param),
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void begin_epilogue() {
visitor_a_op.begin_epilogue();
visitor_b_op.begin_epilogue();
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
visitor_a_op.set_batch_index(batch_idx);
visitor_b_op.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_a_op.begin_step(step_idx);
visitor_b_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_a_op.begin_row(row_idx);
visitor_b_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor A and visitor B
VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
/// Type conversion
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
return binary_op(
source_converter_A(result_A),
source_converter_B(result_B)
);
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_a_op.end_row(row_idx);
visitor_b_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_a_op.end_step(step_idx);
visitor_b_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_a_op.end_epilogue();
visitor_b_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Epilogue visitor op that broadcasts a vector to all columns
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementVector T[i][j] <- device-memory Td[i]
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementFragment_, ///< Data type used to cache vector in register
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpColumnBroadcast {
public:
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementAccumulator = ElementAccumulator_;
using ElementVector = typename InputTileIterator::Element;
using ElementFragment = ElementFragment_;
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
/// Thread map used by input tile iterators
using ThreadMap = typename InputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementFragment, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
// /// Number of iterations (accesses) the threadblock takes to reduce a row
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
};
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): broadcast_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementVector *broadcast_ptr,
int64_t batch_stride
):
broadcast_ptr(broadcast_ptr),
batch_stride(batch_stride) { }
};
/// Param structure
struct Params {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params(): broadcast_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
broadcast_ptr(args.broadcast_ptr),
batch_stride(args.batch_stride) { }
};
private:
ElementVector *broadcast_ptr;
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
MatrixCoord threadblock_offset_;
int thread_idx_;
MatrixCoord problem_size;
int thread_start_row_;
int state_[3];
int thread_offset_row_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpColumnBroadcast(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
broadcast_ptr(params.broadcast_ptr),
threadblock_offset_(threadblock_offset),
thread_idx_(thread_idx),
problem_size(problem_size),
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
batch_stride_(params.batch_stride)
{
state_[0] = state_[1] = state_[2] = 0;
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
broadcast_ptr += batch_idx * batch_stride_;
}
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) {}
CUTLASS_DEVICE
void begin_row(int row_idx) {}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
// Compute this thread's row offset and load the scalar to broadcast
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));
broadcast_fragment.fill(broadcast_data);
return broadcast_fragment;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) {
// Advance thread_start_row_ the same way the output tile iterator's operator++ does
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
}
}
}
}
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Epilogue visitor op that reduces over the columns of a CTA tile
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
/// device memory <- ElementReduction(R[j])
///
template <
typename ThreadblockShape_, /// Threadblock shape
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementReduction_, ///< Data type of the output reduction in device memory
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
typename OutputTileIterator_, ///< Tile Iterator type
typename Visitor_ ///< preceding visitor op
>
class VisitorOpColumnReduction {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementReductionAccumulator = ElementReductionAccumulator_;
using ElementReduction = ElementReduction_;
using OutputTileIterator = OutputTileIterator_;
using ThreadblockShape = ThreadblockShape_;
using Visitor = Visitor_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
using ElementOutput = typename OutputTileIterator::Element;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of reduction
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Used for the reduction
struct ReductionDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
static int const kThreadRows = kThreadCount / kThreadsPerRow;
/// Number of iterations (accesses) the threadblock takes to reduce a row
static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);
using StorageShape = MatrixShape<
kThreadRows,
ThreadblockShape::kN
>;
};
using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;
/// Shared storage
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Argument structure
struct Arguments {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Arguments(): reduction_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementReduction *reduction_ptr,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
reduction_ptr(reduction_ptr),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Params visitor_param; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Params(): reduction_ptr(nullptr), batch_stride(0) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
reduction_ptr(args.reduction_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory
ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction
Visitor visitor_; ///< visitor
int thread_idx_;
MatrixCoord threadblock_offset;
MatrixCoord problem_size_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpColumnReduction(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor,
thread_idx, threadblock_offset, problem_size),
reduction_smem_ptr_(shared_storage.reduction.data()),
reduction_output_ptr_(params.reduction_ptr),
thread_idx_(thread_idx),
threadblock_offset(threadblock_offset),
problem_size_(problem_size),
batch_stride_(params.batch_stride)
{ }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
reduction_output_ptr_ += batch_idx * batch_stride_;
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
// clear the reduction fragment
reduction_fragment.clear();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
ReductionOp reduction_op;
ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast<ReductionAccumulatorAccessType*>(&reduction_fragment);
reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result));
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
//
// Store the partially reduced value to SMEM
//
// Guard against uses of the existing SMEM tile
__syncthreads();
using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;
//
// Determine a compact thread arrangement to store to SMEM
//
MatrixCoord thread_offset(
thread_idx_ / ReductionDetail::kThreadsPerRow,
(thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
);
//
// Each thread stores its fragment to SMEM
//
AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
&reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
&reduction_fragment
);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;
aligned_reduction_ptr[col_idx] = frag_ptr[column];
}
__syncthreads();
//
// Now, threads are assigned several columns of the output. They fetch all rows from
// the compacted SMEM tile and perform a reduction.
//
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;
ReductionOpScalar reduction_op;
ElementReductionAccumulator reduction_element = ElementReductionAccumulator();
int output_column_idx = threadblock_offset.column() + column_idx;
if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
if (row) {
auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
reduction_element = reduction_op(reduction_element, frag);
}
else {
reduction_element = reduction_smem_ptr_[column_idx];
}
}
// Store
reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Epilogue visitor op that computes a linear combination of two child visitors
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute alpha;
/// ElementCompute beta;
///  ElementCompute C = alpha * ElementCompute(Visitor_A) + beta * ElementCompute(Visitor_B);
/// Return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename VisitorA_, ///< Child node A
typename VisitorB_ ///< Child node B
>
class VisitorOpLinearCombination {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using VisitorA = VisitorA_;
using VisitorB = VisitorB_;
/// Fragment type returned from VisitorA.visit
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
using ElementA = typename VisitAccessTypeA::Element;
/// Fragment type returned from VisitorB.visit
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
using ElementB = typename VisitAccessTypeB::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Combination Op
using CombinationOp = cutlass::plus<VisitAccessType>;
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename VisitorA::SharedStorage storage_a;
typename VisitorB::SharedStorage storage_b;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():
alpha(ElementCompute(1)),
beta(ElementCompute(0))
{ }
CUTLASS_HOST_DEVICE
Arguments(
ElementCompute alpha,
ElementCompute beta,
typename VisitorA::Arguments visitor_a_arg,
typename VisitorB::Arguments visitor_b_arg
):
alpha(alpha),
beta(beta),
visitor_a_arg(visitor_a_arg),
visitor_b_arg(visitor_b_arg)
{ }
};
/// Parameter structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a
typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
alpha(args.alpha),
beta(args.beta),
visitor_a_param(args.visitor_a_arg),
visitor_b_param(args.visitor_b_arg)
{ }
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
VisitorA visitor_a_op;
VisitorB visitor_b_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpLinearCombination(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
alpha_(params.alpha),
beta_(params.beta),
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void begin_epilogue() {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor A and visitor B
VisitAccessTypeA result_A;
VisitAccessTypeB result_B;
if (alpha_ != ElementCompute(0)) {
result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
// Fill the result A with zeros
result_A.clear();
}
if (beta_ != ElementCompute(0)) {
result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
// Fill the result B with zeros
result_B.clear();
}
/// Type conversion
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
CombinationOp combination_op;
cutlass::multiplies<VisitAccessType> multiply_op;
return combination_op(
multiply_op(alpha_, source_converter_A(result_A)),
multiply_op(beta_, source_converter_B(result_B))
);
}
CUTLASS_DEVICE
void end_row(int row_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Epilogue visitor op that broadcasts a vector to all rows
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementVector T[i][j] <- device-memory Td[j]
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementFragment_, ///< Data type used to cache vector in register
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpRowBroadcast {
public:
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementAccumulator = ElementAccumulator_;
using ElementVector = typename InputTileIterator::Element;
using ElementFragment = ElementFragment_;
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
/// Thread map used by input tile iterators
using ThreadMap = typename InputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementFragment,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
// /// Number of iterations (accesses) the threadblock takes to reduce a row
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
};
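// Worked example (assumed ThreadMap values, for illustration only): with
// Iterations::kColumn = 2, kElementsPerAccess = 8, Iterations::kCount = 8,
// kThreads = 256, and InputTileIterator::Shape::kN = 128:
//   kColumnsPerThread = 2 * 8    = 16
//   kRowsPerThread    = 8 / 2    = 4
//   kThreadsPerRow    = 128 / 16 = 8
//   kThreadRows       = 256 / 8  = 32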
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Methods
CUTLASS_HOST_DEVICE
Arguments():
broadcast_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementVector *broadcast_ptr,
int64_t batch_stride
):
broadcast_ptr(broadcast_ptr),
batch_stride(batch_stride) { }
};
/// Param structure
struct Params {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
broadcast_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
broadcast_ptr(args.broadcast_ptr),
batch_stride(args.batch_stride) { }
};
private:
ElementVector *broadcast_ptr;
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
MatrixCoord threadblock_offset_;
int thread_idx_;
MatrixCoord problem_size;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpRowBroadcast(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
threadblock_offset_(threadblock_offset),
thread_idx_(thread_idx),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
broadcast_ptr += batch_idx * batch_stride_;
}
CUTLASS_DEVICE
void begin_epilogue() {
// load broadcast fragment
load_broadcast_fragment_();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {}
CUTLASS_DEVICE
void begin_row(int row_idx) {}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
VisitAccessType* broadcast_fragment_ = reinterpret_cast<VisitAccessType*>(&broadcast_fragment);
return broadcast_fragment_[column_idx];
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
private:
CUTLASS_DEVICE
void load_broadcast_fragment_() {
broadcast_fragment.clear();
// If no pointer is supplied, leave the fragment zeroed and avoid memory accesses
if (!broadcast_ptr) {
return;
}
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
broadcast_ptr += thread_initial_column;
NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;
AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
AccessType loaded;
loaded.clear();
if (thread_column_idx < problem_size.column()) {
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
}
AccessFragmentType cvt = converter(loaded);
frag_ptr[j] = cvt;
thread_column_idx += ThreadMap::Delta::kColumn;
broadcast_ptr += ThreadMap::Delta::kColumn;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,319 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor op that reduces over the rows of a CTA tile
*/
#pragma once
#include "cutlass/cutlass.h"
#include <cstdio>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementReductionAccumulator R[i] = \sum_j ElementReductionAccumulator(T[i][j])
/// device memory <- ElementReduction(R[i])
///
template <
typename ThreadblockShape_, /// Threadblock shape
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementReduction_, ///< Data type of the output reduction in device memory
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
typename OutputTileIterator_, ///< Tile Iterator type
typename Visitor_ ///< preceding visitor op
>
class VisitorOpRowReduction {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementReductionAccumulator = ElementReductionAccumulator_;
using ElementReduction = ElementReduction_;
using OutputTileIterator = OutputTileIterator_;
using ThreadblockShape = ThreadblockShape_;
using Visitor = Visitor_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
using ElementOutput = typename OutputTileIterator::Element;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of reduction
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Used for the reduction
struct ReductionDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
/// Half number of threads per row used for cross-thread reduction
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
static int const kThreadRows = kThreadCount / kThreadsPerRow;
};
/// Shared storage
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Arguments(): reduction_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementReduction *reduction_ptr,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
reduction_ptr(reduction_ptr),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Params visitor_param; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Params(): reduction_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
reduction_ptr(args.reduction_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
ElementReductionAccumulator reduction_accum;
Visitor visitor_; ///< visitor
int thread_idx_;
MatrixCoord threadblock_offset;
MatrixCoord problem_size_;
int thread_start_row_; ///< Starting row index of this thread
int state_[3]; ///< Tracks the row iterator position (row, group, cluster)
int thread_offset_row_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpRowReduction(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor,
thread_idx, threadblock_offset, problem_size),
reduction_output_ptr_(params.reduction_ptr),
thread_idx_(thread_idx),
threadblock_offset(threadblock_offset),
problem_size_(problem_size),
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
batch_stride_(params.batch_stride)
{
state_[0] = state_[1] = state_[2] = 0;
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
reduction_output_ptr_ += batch_idx * batch_stride_;
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
reduction_accum = ElementReductionAccumulator(0);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
ReductionOpScalar reduction_op;
ElementReductionAccumulator reduction_accum_ = reduction(result);
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
CUTLASS_PRAGMA_UNROLL
for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
}
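// Butterfly-reduction sketch (assumed kThreadsPerRow == 4; lanes hold
// partial sums {1, 2, 3, 4}):
//   i = 2: each lane adds the value from lane (lane ^ 2) -> {4, 6, 4, 6}
//   i = 1: each lane adds the value from lane (lane ^ 1) -> {10, 10, 10, 10}
// After log2(kThreadsPerRow) steps, every lane in the row group holds the
// full partial-row sum.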
reduction_accum = reduction_op(reduction_accum, reduction_accum_);
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();
ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;
arch::global_store<ElementReduction, sizeof(ElementReduction)>(
output_converter(reduction_accum),
(void *)curr_ptr_reduction,
is_write_thread);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
// Advance the row iterator state (equivalent to operator++ on the tile iterator)
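// Example (assumed ThreadMap values: Count::kRow = 2, Shape::kRow = 8,
// Shape::kGroup = 2): each step advances thread_start_row_ by 8; after
// Count::kRow steps, the start row additionally skips
// (kGroup - 1) * 8 * 2 = 16 rows to reach the next group, and likewise
// up the group/cluster hierarchy.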
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
}
}
}
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
}
private:
CUTLASS_DEVICE
ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) {
ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);
ReductionOpScalar reduction_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
sum_ = reduction_op(sum_, result[i]);
}
return sum_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,188 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor op that reads an input tensor from device memory
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementInput C <- device memory
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename InputTileIterator_ ///< Tile iterator type to read the tensor
>
class VisitorOpTensorInput {
public:
using ElementAccumulator = ElementAccumulator_;
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementInput = typename InputTileIterator::Element;
using VisitAccessType = Array<ElementInput, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementInput *input_ptr; ///< Pointer to the input tensor in device memory
int ldt; ///< Leading dimension of the input tensor operand
int64_t batch_stride; ///< batch stride for batched GEMM
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): input_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementInput *input_ptr,
int ldt, int64_t batch_stride
):
input_ptr(input_ptr),
ldt(ldt),
batch_stride(batch_stride)
{ }
};
/// Param structure
struct Params {
typename InputTileIterator::Params params_input;
ElementInput *input_ptr;
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
input_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
params_input(args.ldt),
input_ptr(args.input_ptr),
batch_stride(args.batch_stride)
{ }
};
private:
InputTileIterator iterator_T_;
typename InputTileIterator::Fragment fragment_T_;
MatrixCoord problem_size;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpTensorInput(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
iterator_T_(
InputTileIterator(
params.params_input,
params.input_ptr,
problem_size,
thread_idx,
threadblock_offset
)
),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
}
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_T_.clear();
iterator_T_.load(fragment_T_);
++iterator_T_;
}
CUTLASS_DEVICE
void begin_row(int row_idx) { }
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
return source;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,240 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor op that writes the output tensor to device memory
*/
#pragma once
#include "cutlass/cutlass.h"
#include <cstdio>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementOutput T = ElementOutput(Visitor)
/// T-> device memory
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename OutputTileIterator_, ///< Tile iterator type to write the tensor
typename Visitor_ ///< Child visitor that produces the output tensor
>
class VisitorOpTensorOutput {
public:
using ElementAccumulator = ElementAccumulator_;
using OutputTileIterator = OutputTileIterator_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using Visitor = Visitor_;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of output
using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory
int ldt; ///< Leading dimension of the output tensor operand
int64_t batch_stride; ///< batch stride
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): output_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementOutput *output_ptr,
int ldt,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
output_ptr(output_ptr),
ldt(ldt),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
typename OutputTileIterator::Params params_output;
ElementOutput *output_ptr;
int64_t batch_stride;
typename Visitor::Params visitor_param;
/// Method
CUTLASS_HOST_DEVICE
Params():
output_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
params_output(args.ldt),
output_ptr(args.output_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
OutputTileIterator iterator_T_;
typename OutputTileIterator::Fragment fragment_T_;
MatrixCoord problem_size;
Visitor visitor_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpTensorOutput(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
iterator_T_(
OutputTileIterator(
params.params_output,
params.output_ptr,
problem_size,
thread_idx,
threadblock_offset
)
),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_T_.clear();
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
// Get result from the child visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
// Column guard
MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
bool column_guard = (thread_offset_.column() < problem_size.column());
if (column_guard) {
NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
output = output_converter(result);
}
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
iterator_T_.store(fragment_T_);
++iterator_T_;
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,226 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor op that applies an elementwise unary operation
*/
#pragma once
#include "cutlass/cutlass.h"
#include "unary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute alpha;
/// ElementCompute beta;
/// ElementCompute C = UnaryOp(ElementCompute(Visitor))
/// Return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename Visitor_, ///< Child node
template<typename T, int N> typename UnaryOp_
>
class VisitorOpUnary {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using Visitor = Visitor_;
/// Fragment type returned from Visitor.visit
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisit = typename VisitAccessTypeVisitor::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Combination Op
using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
typename UnaryOp::Arguments unary_arg;
typename Visitor::Arguments visitor_arg; ///< Argument type for visitor
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():unary_arg() { }
CUTLASS_HOST_DEVICE
Arguments(
typename UnaryOp::Arguments unary_arg,
typename Visitor::Arguments visitor_arg
):
unary_arg(unary_arg),
visitor_arg(visitor_arg)
{ }
};
/// Parameter structure
struct Params {
typename UnaryOp::Params unary_param;
typename Visitor::Params visitor_param; ///< Argument type for visitor
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():unary_param() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
unary_param(args.unary_arg),
visitor_param(args.visitor_arg)
{ }
};
private:
//
// Data members
//
UnaryOp unary_op;
Visitor visitor_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpUnary(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
unary_op(params.unary_param),
visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
visitor_op.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
if (unary_op.guard()) visitor_op.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
if (unary_op.guard()) visitor_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
if (unary_op.guard()) visitor_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get the result from the child visitor
VisitAccessTypeVisitor result;
if (unary_op.guard()){
result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
result.clear();
}
/// Convert the visitor result to the compute type
NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;
return unary_op(source_converter(result));
}
CUTLASS_DEVICE
void end_row(int row_idx) {
if (unary_op.guard()) visitor_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
if (unary_op.guard()) visitor_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
if (unary_op.guard()) visitor_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor type used for partial computation of a layernorm operation
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
+ lightweight full reduction kernel (ApplyFinalReduction)
+ GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ThreadblockShape_,
int ThreadCount,
typename OutputTileIterator_,
typename AccumulatorTile_,
typename ElementAccumulator_,
typename ElementVariance_,
typename ElementMean_,
typename ElementLayernormCompute_,
typename ElementwiseFunctor_,
bool IsShiftedVariance_ = false
>
class EpilogueVisitorLayerNorm {
public:
using ElementVariance = ElementVariance_;
using ElementMean = ElementMean_;
using ElementLayernormCompute = ElementLayernormCompute_;
using AccumulatorTile = AccumulatorTile_;
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;
static int const kThreads = OutputTileIterator::ThreadMap::kThreads;
static bool const kIsShiftedVariance = IsShiftedVariance_;
using ElementOutput = typename OutputTileIterator::Element;
static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;
/// Array type used in Shift-K Layernorm
static int const kRowAccessCount = kIterations * kRowIterations;
using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;
// Conducts manual transpose externally (already supported) for column major
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static int const kThreadsInColumn = kThreads / kThreadsPerRow;
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
ElementVariance *ptr_Variance;
ElementMean *ptr_Mean;
ElementOutput *ptr_Shifted_K;
MatrixCoord extent;
//
// Methods
//
Arguments():
ptr_Variance(nullptr),
ptr_Mean(nullptr),
ptr_Shifted_K(nullptr)
{
}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
ElementVariance *ptr_Variance,
ElementMean *ptr_Mean_,
ElementOutput *ptr_Shifted_K_ = nullptr,
MatrixCoord extent = MatrixCoord(0, 0)
):
elementwise(elementwise_),
ptr_Variance(ptr_Variance),
ptr_Mean(ptr_Mean_),
ptr_Shifted_K(ptr_Shifted_K_),
extent(extent)
{
}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
ElementVariance *ptr_Variance;
ElementMean *ptr_Mean;
ElementOutput *ptr_Shifted_K;
MatrixCoord extent;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_Variance(nullptr),
ptr_Mean(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
elementwise(args.elementwise),
ptr_Variance(args.ptr_Variance),
ptr_Mean(args.ptr_Mean),
ptr_Shifted_K(args.ptr_Shifted_K),
extent(args.extent)
{
}
};
/// Shared storage
struct SharedStorage {
};
private:
Params const & params_;
SharedStorage & shared_storage_;
MatrixCoord extent_;
ElementwiseFunctor elementwise_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator alpha_;
ElementAccumulator beta_;
ConvertedShiftFragment shift_k_frag_;
ElementLayernormCompute accum_sum_square_;
ElementLayernormCompute accum_sum_element_;
int thread_idx_;
MatrixCoord thread_offset_;
gemm::GemmCoord threadblock_tile_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorLayerNorm(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord threadblock_offset,
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator ///< Tile iterator for the source (C operand)
):
params_(params),
shared_storage_(shared_storage),
elementwise_(params.elementwise),
extent_(params.extent),
iterator_C_(source_iterator),
iterator_D_(destination_iterator),
threadblock_tile_offset_(threadblock_tile_offset),
thread_idx_(thread_idx)
{
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
// If shift-K feature is enabled, we load shift-k fragment
// at the very beginning of an epilogue
if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
shift_k_frag_.clear();
int thread_offset_row_base = iterator_D_.thread_start_row();
CUTLASS_PRAGMA_UNROLL
for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
CUTLASS_PRAGMA_UNROLL
for (int rid = 0; rid < kRowIterations; ++rid) {
int row_step_offset = rid * kDeltaRow;
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
bool is_load = (row_offset < extent_.row());
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
}
}
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
fragment_C_.clear();
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
// Set the accumulators to zero
accum_sum_element_ = ElementLayernormCompute(0);
accum_sum_square_ = ElementLayernormCompute(0);
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const &accum) {
using Mul = cutlass::multiplies<ElementLayernormCompute>;
using Minus = cutlass::minus<ElementLayernormCompute>;
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
Minus minus;
Mul mul;
Exp exponential;
LayernormFragment result;
thread_offset_ =
iterator_D_.thread_start() +
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
bool column_guard = (thread_offset_.column() < extent_.column());
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
result = source_converter(elementwise_(accum));
} else {
result = source_converter(elementwise_(accum, source_vector));
}
ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());
// Fragment is cleared for non-reachable columns so no need to check against column guard
ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);
// Square sum is different. Non-reachable columns should've been computed for shift-k
// Otherwise we will incorrectly have some extra k^2 added into square sum.
ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);
if (column_guard) {
accum_sum_square_tmp = (kIsShiftedVariance) ? \
square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
square_sum_accumulator_(result);
}
accum_sum_element_tmp *= inv_scalar;
accum_sum_square_tmp *= inv_scalar;
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
CUTLASS_PRAGMA_UNROLL
for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
}
accum_sum_element_ += accum_sum_element_tmp;
accum_sum_square_ += accum_sum_square_tmp;
// Convert to the output
NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;
ConvertVarianceOutput convert_variance_output;
ConvertMeanOutput convert_mean_output;
bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();
ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;
arch::global_store<ElementVariance, sizeof(ElementVariance)>(
convert_variance_output(accum_sum_square_),
(void *)curr_ptr_sum_square,
is_write_thread);
arch::global_store<ElementMean, sizeof(ElementMean)>(
convert_mean_output(accum_sum_element_),
(void *)curr_ptr_element_sum,
is_write_thread);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
}
private:
CUTLASS_DEVICE
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
ConvertShiftK convert_shift_k;
ElementOutput shift_k_val;
// Computes the address to load shift_k element
ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset;
// Conditionally loads from global memory
arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
// Converts data type to return
ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);
return converted_shift_k_val;
}
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
auto accum_ = accum[i];
sum_ += accum_ * accum_;
}
return sum_;
}
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
auto accum_ = accum[i] - shift_k_val;
sum_ += accum_ * accum_;
}
return sum_;
}
CUTLASS_DEVICE
ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
sum_ += accum[i];
}
return sum_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind gemm related enum types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/gemm.h"
#include "host.h"
namespace py = pybind11;
void bind_gemm(py::module &m) {
//
// Enumerate types
// cutlass/gemm/gemm.h
py::enum_<cutlass::gemm::GemmUniversalMode>(m, "Mode")
.value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial")
.value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel")
.value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM")
.value("Array", cutlass::gemm::GemmUniversalMode::kArray)
.value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid);
/// GemmCoord is a structure that specifies a location within the coordinate space of a GEMM problem
py::class_<cutlass::gemm::GemmCoord>(m, "GemmCoord")
.def(py::init<int, int, int>())
.def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m))
.def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n))
.def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k))
// get tensor coords
.def("mk",
[](const cutlass::gemm::GemmCoord & problem_size) {
return cutlass::MatrixCoord(problem_size.mk());
})
.def("kn",
[](const cutlass::gemm::GemmCoord & problem_size) {
return cutlass::MatrixCoord(problem_size.kn());
})
.def("mn",
[](const cutlass::gemm::GemmCoord & problem_size) {
return cutlass::MatrixCoord(problem_size.mn());
});
py::module_ host_submodule = m.def_submodule("host");
bind_gemm_host_helper(host_submodule);
}

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmUniversalwithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
//
// Structures
//
/// Argument structure
struct Arguments : UniversalArgumentsBase {
//
// Data members
//
typename EpilogueVisitor::Arguments epilogue_visitor;
void const * ptr_A;
void const * ptr_B;
void const * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
typename LayoutC::Stride stride_c;
typename LayoutC::Stride stride_d;
typename LayoutA::Stride::LongIndex lda;
typename LayoutB::Stride::LongIndex ldb;
typename LayoutC::Stride::LongIndex ldc;
typename LayoutC::Stride::LongIndex ldd;
int const * ptr_gather_A_indices;
int const * ptr_gather_B_indices;
int const * ptr_scatter_D_indices;
//
// Methods
//
Arguments():
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueVisitor::Arguments epilogue_visitor,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride stride_a,
typename LayoutB::Stride stride_b,
typename LayoutC::Stride stride_c,
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueVisitor::Arguments epilogue_visitor,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride::LongIndex lda,
typename LayoutB::Stride::LongIndex ldb,
typename LayoutC::Stride::LongIndex ldc,
typename LayoutC::Stride::LongIndex ldd,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.stride_a, args.stride_b);
std::swap(args.batch_stride_A, args.batch_stride_B);
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
return args;
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC> {
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
typename EpilogueVisitor::Params epilogue_visitor;
void * ptr_A;
void * ptr_B;
void * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;
int *semaphore;
//
// Methods
//
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
int device_sms,
int sm_occupancy
):
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
epilogue_visitor(args.epilogue_visitor),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_D(args.ptr_D),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)) {
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
epilogue_visitor = args.epilogue_visitor;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmUniversalwithEpilogueVisitor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");
static int const kAlignmentA = (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<LayoutB,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutB,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmUniversalwithEpilogueVisitor op;
op(params, shared_storage);
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
__syncthreads();
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A,
params.ptr_gather_A_indices);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B,
params.ptr_gather_B_indices);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Number of threadblock-scoped mainloop iterations over this CTA's K extent
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Epilogue
//
// EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
//
// Fetch pointers based on mode.
//
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// Tile iterator loading from source tensor.
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.visitor,
threadblock_offset,
threadblock_tile_offset,
thread_idx,
params.problem_size.mn()
);
if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
//
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
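The `can_implement()` alignment check above reduces to a simple rule: each operand's contiguous extent must be a multiple of its access width (32 or 64 elements for the interleaved layouts, the iterator's access size otherwise). A minimal Python sketch of that rule follows; the function and parameter names are illustrative, not CUTLASS API.

```python
# Illustrative sketch of the can_implement() alignment rule above: an operand
# is misaligned when the extent along its contiguous dimension is not a
# multiple of its alignment. Names here are hypothetical, not CUTLASS API.

def can_implement(m, n, k, align_a, align_b, align_c,
                  layout_a="RowMajor", layout_b="ColumnMajor",
                  layout_c="RowMajor"):
    # Contiguous extent checked for each (layout, operand) pair,
    # mirroring the if/else chains in the kernel.
    extent = {
        ("RowMajor", "A"): k, ("ColumnMajor", "A"): m,
        ("RowMajor", "B"): n, ("ColumnMajor", "B"): k,
        ("RowMajor", "C"): n, ("ColumnMajor", "C"): m,
    }
    return (extent[(layout_a, "A")] % align_a == 0
            and extent[(layout_b, "B")] % align_b == 0
            and extent[(layout_c, "C")] % align_c == 0)

print(can_implement(128, 128, 64, 8, 8, 8))  # True: all extents divisible by 8
print(can_implement(128, 128, 63, 8, 8, 8))  # False: K = 63 is not
```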
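The split-K bookkeeping in `operator()` above (`offset_k`, `problem_size_k`, `gemm_k_iterations`) can be sketched as plain arithmetic. This is an illustrative model under assumed names, not CUTLASS API: each K-slice covers `gemm_k_size` columns, the last slice absorbs the remainder, and the mainloop runs a ceiling-divided number of tile-K iterations.

```python
# Sketch (not CUTLASS API) of how the kernel above partitions the K dimension
# across threadblocks in split-K mode and sizes the mainloop.

def split_k_slice(problem_k, gemm_k_size, tile_k, slice_idx, num_slices):
    offset_k = slice_idx * gemm_k_size
    # All but the last slice stop at the end of their partition;
    # the last slice covers the remainder of K.
    problem_size_k = problem_k
    if slice_idx + 1 < num_slices:
        problem_size_k = (slice_idx + 1) * gemm_k_size
    # Ceiling division: threadblock-scoped MMA iterations for this slice.
    gemm_k_iterations = (problem_size_k - offset_k + tile_k - 1) // tile_k
    return offset_k, problem_size_k, gemm_k_iterations

# K = 100 split into 2 slices of gemm_k_size = 64, tile_k = 32
print(split_k_slice(100, 64, 32, 0, 2))  # (0, 64, 2)
print(split_k_slice(100, 64, 32, 1, 2))  # (64, 100, 2)
```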
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm host helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
void bind_gemm_host_helper(py::module &m) {
m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>);
m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>);
}
/* \file
\brief Bind CUTLASS layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "tensor.h"
#include "matrix.h"
namespace py = pybind11;
void bind_layout(py::module &m) {
bind_tensor_layout(m);
bind_matrix_layout(m);
}
/* \file
\brief Bind Matrix layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/layout/matrix.h"
namespace py = pybind11;
void bind_matrix_layout(py::module &m) {
//
// Matrix layouts
// cutlass/layout/matrix.h
//
py::class_<cutlass::layout::RowMajor>(m, "RowMajor", R"pbdoc(
Mapping function for row-major matrices.
)pbdoc")
.def_static("packed", &cutlass::layout::RowMajor::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::RowMajor & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
py::class_<cutlass::layout::ColumnMajor>(m, "ColumnMajor", R"pbdoc(
Mapping function for column-major matrices.
)pbdoc")
.def_static("packed", &cutlass::layout::ColumnMajor::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::ColumnMajor & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
py::class_<cutlass::layout::RowMajorInterleaved<32>>(m, "RowMajorInterleaved32",
R"pbdoc(Mapping function for interleaved matrices. The matrix is structured
as a row-major arrangement of fixed-size columns of width 32)pbdoc")
.def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::RowMajorInterleaved<32> & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
py::class_<cutlass::layout::ColumnMajorInterleaved<32>>(m, "ColumnMajorInterleaved32",
R"pbdoc(Mapping function for interleaved matrices. The matrix is structured
as a column-major arrangement of fixed-size rows of height 32)pbdoc")
.def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed tensor)pbdoc")
.def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> & layout){
return layout.stride().at(0);
}, R"pbdoc(Returns the stride of the layout)pbdoc");
}
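For the plain row- and column-major layouts bound above, the `.packed()`/`.stride()` pair reduce to a one-liner: a packed RowMajor matrix has stride equal to its column count, a packed ColumnMajor matrix equal to its row count. A trivial sketch (illustrative only, not the CUTLASS implementation):

```python
# Minimal model of the stride a packed layout exposes through the
# .packed()/.stride() bindings above. Hypothetical helper, not CUTLASS API.

def packed_stride(rows, cols, layout):
    # RowMajor advances by one full row (cols elements) per row index;
    # ColumnMajor advances by one full column (rows elements) per column index.
    return cols if layout == "RowMajor" else rows

print(packed_stride(4, 8, "RowMajor"))     # 8
print(packed_stride(4, 8, "ColumnMajor"))  # 4
```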
/* \file
\brief Bind Tensor layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/layout/tensor.h"
namespace py = pybind11;
void bind_tensor_layout(py::module &m) {
//
// Tensor layouts
// cutlass/include/cutlass/layout/tensor.h
//
/// Mapping function for 4-D NHWC tensors.
py::class_<cutlass::layout::TensorNHWC>(m, "TensorNHWC",
R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc")
.def_static("packed", &cutlass::layout::TensorNHWC::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed NHWC tensor)pbdoc")
.def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride),
R"pbdoc(Returns the stride of the layout)pbdoc");
/// Mapping function for 4-D NC/xHWx tensors.
py::class_<cutlass::layout::TensorNCxHWx<32>>(m, "TensorNC32HW32",
R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc")
.def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed tensor)pbdoc")
.def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride),
R"pbdoc(Returns the stride of the layout)pbdoc");
/// Mapping function for 4-D CxRSKx tensors.
py::class_<cutlass::layout::TensorCxRSKx<32>>(m, "TensorC32RSK32",
R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc")
.def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed,
py::arg("extent"),
R"pbdoc(Helper returning a layout for a tightly packed tensor)pbdoc")
.def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride),
R"pbdoc(Returns the stride of the layout)pbdoc");
}
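The `TensorNHWC` layout bound above follows standard NHWC addressing. As a hedged sketch (the standard convention, not lifted from the CUTLASS source), the packed offset of element `(n, h, w, c)` is `((n*H + h)*W + w)*C + c`:

```python
# Standard packed NHWC addressing, as a model of the TensorNHWC binding above.
# Hypothetical helper; argument names follow the usual N/H/W/C convention.

def nhwc_offset(n, h, w, c, H, W, C):
    # Channels are contiguous; W, H, N strides are C, W*C, H*W*C respectively.
    return ((n * H + h) * W + w) * C + c

print(nhwc_offset(0, 0, 0, 5, 4, 4, 8))  # 5
print(nhwc_offset(1, 0, 0, 0, 4, 4, 8))  # 128
```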
/* \file
\brief Bind threadblock swizzling to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include <cxxabi.h>
#include <cuda_runtime.h>
#include <cstdlib>
#include <memory>
#include <string>
namespace py = pybind11;
std::string demangle(const char* mangled_name) {
  int status = 0;
  // __cxa_demangle() returns a malloc()-allocated buffer, which must be
  // released with std::free rather than the unique_ptr default of delete.
  std::unique_ptr<char, decltype(&std::free)> ptr(
      __cxxabiv1::__cxa_demangle(mangled_name, nullptr, nullptr, &status),
      std::free);
  return (status == 0 && ptr) ? std::string(ptr.get()) : std::string(mangled_name);
}
template<typename T>
void bind_identity_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv3dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_swizzle(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: gemm(M, N, K)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_swizzle_streamk(py::module & m, std::string name, std::string doc) {
py::class_<T>(m, name.c_str(), doc.c_str())
.def(py::init<>())
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
template<typename T>
void bind_dgrad_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc")
.def(py::init<>())
.def("get_tiled_shape",
py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
&T::get_tiled_shape, py::const_
), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
R"pbdoc(Returns the shape of the problem in units of logical tiles
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`
)pbdoc")
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc");
}
void bind_threadblock_swizzle(py::module &m) {
py::class_<dim3>(m, "dim3",
R"pbdoc(CUDA dim3 type containing three integer components: x, y, and z)pbdoc")
.def(py::init<int, int, int>(),
py::arg("x"), py::arg("y"), py::arg("z"))
.def_readwrite("x", &dim3::x, R"pbdoc(x component)pbdoc")
.def_readwrite("y", &dim3::y, R"pbdoc(y component)pbdoc")
.def_readwrite("z", &dim3::z, R"pbdoc(z component)pbdoc");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>>(m, "IdentitySwizzle1");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>>(m, "IdentitySwizzle2");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>>(m, "IdentitySwizzle4");
bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>>(m, "IdentitySwizzle8");
bind_swizzle<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle>(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc");
bind_swizzle<cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc");
bind_swizzle_streamk<cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>(m, "ThreadblockSwizzleStreamK", R"pbdoc(Threadblock swizzling function using Stream K feature)pbdoc");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>>(m, "StridedDgradIdentitySwizzle1");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>>(m, "StridedDgradIdentitySwizzle4");
bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle>(m, "StridedDgradHorizontalSwizzle");
}
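For the identity swizzles bound above, `get_tiled_shape()` amounts to ceiling-dividing the GEMM extents by the threadblock tile, with the split-K slice count forming the grid's k extent. A hypothetical sketch (names assumed, not the CUTLASS API):

```python
# Illustrative model of get_tiled_shape() for an identity threadblock swizzle:
# ceil-divide M and N by the tile, pass split-K slices through as k.

def get_tiled_shape(m, n, k, tile_m, tile_n, split_k_slices):
    ceil_div = lambda a, b: (a + b - 1) // b
    return (ceil_div(m, tile_m), ceil_div(n, tile_n), split_k_slices)

# 1024x512 problem with 128x128 tiles and 2 split-K slices
print(get_tiled_shape(1024, 512, 4096, 128, 128, 2))  # (8, 4, 2)
```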
/* \file
\brief Bind Tensor Coord to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/tensor_coord.h"
namespace py = pybind11;
void bind_tensor_coord(py::module &m) {
//
// Tensor Coords
// cutlass/include/cutlass/tensor_coord.h
//
/// Defines a canonical 4D coordinate used by tensor operations.
py::class_<cutlass::Tensor4DCoord>(m, "Tensor4DCoord",
R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
.def(py::init<int, int, int, int>(),
py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
.def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
py::arg("dim"),
R"pbdoc(Gets the index of a given Coord element)pbdoc")
.def("size", [](const cutlass::Tensor4DCoord & coord) {
return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);},
R"pbdoc(The size of the tensor coord)pbdoc");
py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
.def("at", py::overload_cast<int>(&cutlass::Coord<3>::at),
py::arg("dim"),
R"pbdoc(Gets the index of a given Coord element)pbdoc");
// Matrix Size
py::class_<cutlass::MatrixCoord>(m, "MatrixCoord",
R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc")
.def(py::init<int, int>(),
py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc")
.def("row", py::overload_cast<>(&cutlass::MatrixCoord::row),
R"pbdoc(Returns the row of the coordinate)pbdoc")
.def("column", py::overload_cast<>(&cutlass::MatrixCoord::column),
R"pbdoc(Returns the column of the coordinate)pbdoc");
}
/*! \file
\brief Bind TensorRef and View to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "types.h"
template<typename T, typename L, typename TF>
void bind_tensor_ref_view(py::module &m, std::string name) {
py::class_<cutlass::TensorRef<T, L>>(m, ("TensorRef" + name).c_str())
.def(py::init([](int64_t address, const L& layout_ ) {
T* ptr = reinterpret_cast<T*>(address);
return new cutlass::TensorRef<T, L>(ptr, layout_);
}))
.def("data", [](cutlass::TensorRef<T, L>& tensor_ref) {
T* ptr = tensor_ref.data();
return int64_t(ptr);
})
.def("layout", py::overload_cast<>(&cutlass::TensorRef<T, L>::layout));
m.def("get_tensor_ref", [](int64_t address, TF data, const L& layout_) {
T* ptr = reinterpret_cast<T*>(address);
cutlass::TensorRef<T, L> tensor_ref = cutlass::TensorRef<T, L>(ptr, layout_);
return tensor_ref;
});
py::class_<cutlass::TensorView<T, L>>(m, ("TensorView" + name).c_str())
.def(py::init<const cutlass::TensorRef<T, L>&, const typename L::TensorCoord &>());
}
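The `init` lambda above is the key pattern of these bindings: Python hands over a raw pointer as a plain `int64`, and the C++ side reinterprets it as a typed pointer. A minimal sketch of the caller's side (illustrative only; the `TensorRefF32RowMajor` usage in the comment assumes the module has been compiled and imported):

```python
# Sketch: how a Python caller obtains the integer address that the
# TensorRef constructor above expects. Any contiguous buffer works;
# here we use the stdlib array module as a host-side float32 buffer.
from array import array

data = array("f", [1.0, 2.0, 3.0, 4.0])
address, num_elements = data.buffer_info()  # address is a plain Python int

# The pybind11 init above would reinterpret this int as a float*:
#   tensor_ref = TensorRefF32RowMajor(address, layout)
assert isinstance(address, int)
assert num_elements == 4
```

Note that the binding takes ownership of nothing: the Python caller must keep `data` alive for as long as the `TensorRef` is used.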
void bind_tensor_refs_and_views(py::module &m) {
/// float
bind_tensor_ref_view<float, cutlass::layout::RowMajor, cutlass::float32>(m, "F32RowMajor");
bind_tensor_ref_view<float, cutlass::layout::ColumnMajor, cutlass::float32>(m, "F32ColumnMajor");
bind_tensor_ref_view<float, cutlass::layout::TensorNHWC, cutlass::float32>(m, "F32NHWC");
/// double
bind_tensor_ref_view<double, cutlass::layout::RowMajor, cutlass::float64>(m, "F64RowMajor");
bind_tensor_ref_view<double, cutlass::layout::ColumnMajor, cutlass::float64>(m, "F64ColumnMajor");
bind_tensor_ref_view<double, cutlass::layout::TensorNHWC, cutlass::float64>(m, "F64NHWC");
// half_t
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t>(m, "F16RowMajor");
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t>(m, "F16ColumnMajor");
bind_tensor_ref_view<cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::half_t>(m, "F16NHWC");
// bfloat16
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t>(m, "BF16RowMajor");
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t>(m, "BF16ColumnMajor");
bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::TensorNHWC, cutlass::bfloat16_t>(m, "BF16NHWC");
// int8_t
bind_tensor_ref_view<int8_t, cutlass::layout::RowMajorInterleaved<32>, cutlass::int8>(m, "S8RowMajorInterleaved32");
bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajorInterleaved<32>, cutlass::int8>(m, "S8ColumnMajorInterleaved32");
bind_tensor_ref_view<int8_t, cutlass::layout::RowMajor, cutlass::int8>(m, "S8RowMajor");
bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajor, cutlass::int8>(m, "S8ColumnMajor");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorNHWC, cutlass::int8>(m, "S8NHWC");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorNCxHWx<32>, cutlass::int8>(m, "S8NC32HW32");
bind_tensor_ref_view<int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::int8>(m, "S8C32RSK32");
// int32_t
bind_tensor_ref_view<int32_t, cutlass::layout::RowMajor, cutlass::int32>(m, "S32RowMajor");
bind_tensor_ref_view<int32_t, cutlass::layout::ColumnMajor, cutlass::int32>(m, "S32ColumnMajor");
bind_tensor_ref_view<int32_t, cutlass::layout::TensorNHWC, cutlass::int32>(m, "S32NHWC");
}
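Each call to `bind_tensor_ref_view` registers one `TensorRef` and one `TensorView` class per (element type, layout) pair, with the Python-visible name formed by concatenating the prefix and the suffix string passed in. A small sketch of the resulting namespace (suffixes taken from the calls above):

```python
# Sketch: the Python class names generated by bind_tensor_refs_and_views are
# "TensorRef"/"TensorView" plus the per-instantiation suffix string.
suffixes = [
    "F32RowMajor", "F32ColumnMajor", "F32NHWC",
    "F64RowMajor", "F64ColumnMajor", "F64NHWC",
    "F16RowMajor", "F16ColumnMajor", "F16NHWC",
    "BF16RowMajor", "BF16ColumnMajor", "BF16NHWC",
    "S8RowMajorInterleaved32", "S8ColumnMajorInterleaved32",
    "S8RowMajor", "S8ColumnMajor", "S8NHWC", "S8NC32HW32", "S8C32RSK32",
    "S32RowMajor", "S32ColumnMajor", "S32NHWC",
]
class_names = [p + s for s in suffixes for p in ("TensorRef", "TensorView")]
assert "TensorRefF16NHWC" in class_names
assert len(class_names) == 2 * len(suffixes)  # one ref and one view per pair
```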

View File

@ -1,146 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind CUTLASS types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/half.h"
namespace py = pybind11;
namespace cutlass {
/// 8-bit signed integer
struct alignas(1) int8 {
int8_t storage;
explicit int8(int x) {
storage = int8_t(x);
}
explicit int8(float x) {
storage = int8_t(x);
}
int8_t c_value(){return storage;}
};
/// 32-bit signed integer
struct alignas(4) int32 {
int storage;
explicit int32(int x) {
storage = x;
}
explicit int32(float x) {
storage = int(x);
}
int c_value(){return storage;}
};
/// IEEE single-precision floating-point type
struct alignas(4) float32 {
float storage;
explicit float32(float x) {
storage = x;
}
explicit float32(int x) {
storage = float(x);
}
float c_value(){return storage;}
};
/// IEEE double-precision floating-point type
struct alignas(8) float64 {
double storage;
explicit float64(float x) {
storage = double(x);
}
explicit float64(int x) {
storage = double(x);
}
double c_value(){return storage;}
};
}
void bind_cutlass_types(py::module &m) {
// s8
py::class_<cutlass::int8>(m, "int8")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::int8::storage)
.def("value", &cutlass::int8::c_value);
// s32
py::class_<cutlass::int32>(m, "int32")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::int32::storage)
.def("value", &cutlass::int32::c_value);
// f16
py::class_<cutlass::half_t>(m, "float16")
.def(py::init<float>())
.def(py::init<double>())
.def(py::init<int>())
.def(py::init<unsigned>())
.def_readwrite("storage", &cutlass::half_t::storage)
.def("value", [](const cutlass::half_t& value) {return float(value);});
// bf16
py::class_<cutlass::bfloat16_t>(m, "bfloat16")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::bfloat16_t::storage)
.def("value", [](const cutlass::bfloat16_t& value) {return float(value);});
// f32
py::class_<cutlass::float32>(m, "float32")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::float32::storage)
.def("value", &cutlass::float32::c_value);
// tf32
py::class_<cutlass::tfloat32_t>(m, "tfloat32")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::tfloat32_t::storage)
.def("value", [](const cutlass::tfloat32_t& value) {return float(value);});
// f64
py::class_<cutlass::float64>(m, "float64")
.def(py::init<float>())
.def(py::init<int>())
.def_readwrite("storage", &cutlass::float64::storage)
.def("value", &cutlass::float64::c_value);
}
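Each wrapper above follows the same shape: a raw `storage` field plus a `value()` accessor, with constructors that convert from wider Python numerics. A pure-Python model of the pattern for `int8` (not the actual binding; conversion from a wider int is modular, matching `int8_t(x)` in C++20):

```python
# Sketch (pure-Python model of the storage/value() pattern above).
class Int8:
    def __init__(self, x):
        # int() truncates floats toward zero, like the C++ int8_t(float) path;
        # the arithmetic wraps the result into [-128, 127] like int8_t(int).
        self.storage = (int(x) + 128) % 256 - 128

    def value(self):
        return self.storage

assert Int8(42).value() == 42
assert Int8(300).value() == 44    # 300 wraps modulo 256
assert Int8(-1.9).value() == -1   # float truncates toward zero first
```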

View File

@ -1,32 +0,0 @@
#include <cutlass/complex.h>
namespace cutlass {
/// Enum class for data types
enum class DataType {
kB1, kU2, kU4, kU8,
kU16, kU32, kU64, kS2,
kS4, kS8, kS16, kS32,
kS64, kF16, kBF16, kF32,
kTF32, kF64, kCF16, kCBF16,
kCF32, kCTF32, kCF64, kCS2,
kCS4, kCS8, kCS16, kCS32,
kCS64, kCU2, kCU4, kCU8,
kCU16, kCU32, kCU64, kInvalid
};
/// Enum class for layout types
enum class LayoutType {
kColumnMajor, kRowMajor,
kColumnMajorInterleaved2, kRowMajorInterleaved2,
kColumnMajorInterleaved32, kRowMajorInterleaved32,
kColumnMajorInterleaved64, kRowMajorInterleaved64,
kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC,
kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32,
kTensorC64RSK64
};
} // namespace cutlass
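The `DataType` names encode the element kind and bit width (`kS8` is a signed 8-bit integer, `kBF16` a 16-bit bfloat, and so on). A hedged Python mirror of a few entries and the widths their names imply (illustrative; `tf32` is stored in 32 bits despite its reduced mantissa):

```python
# Sketch: a few DataType entries and the element bit widths encoded in
# their names. The full enum above has many more members.
from enum import Enum, auto

class DataType(Enum):
    kB1 = auto(); kU4 = auto(); kS8 = auto(); kF16 = auto()
    kBF16 = auto(); kTF32 = auto(); kF32 = auto(); kF64 = auto()

bits = {DataType.kB1: 1, DataType.kU4: 4, DataType.kS8: 8,
        DataType.kF16: 16, DataType.kBF16: 16, DataType.kTF32: 32,
        DataType.kF32: 32, DataType.kF64: 64}

assert bits[DataType.kBF16] == 16
assert bits[DataType.kF64] == 64
```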

View File

@ -1,54 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind convolution problems to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "unit/conv/device/conv2d_problems.h"
#include "cutlass/conv/conv2d_problem_size.h"
namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(std::vector<cutlass::conv::Conv2dProblemSize>);
void bind_conv_problem_size_test(py::module &m) {
py::bind_vector<std::vector<cutlass::conv::Conv2dProblemSize>>(m, "Conv2dProblemVector")
.def("size", &std::vector<cutlass::conv::Conv2dProblemSize>::size);
// Get Conv2d problem sizes
py::class_<test::conv::device::TestbedConv2dProblemSizes>(m, "TestbedConv2dProblemSizes")
.def(py::init<int>())
.def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes);
}

View File

@ -1,49 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind convolution related types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "conv_problems.h"
#include "host.h"
namespace py = pybind11;
void bind_convolution_test(py::module &m) {
// Conv problem sizes
bind_conv_problem_size_test(m);
py::module_ host_submodule = m.def_submodule("host");
bind_conv_host_references(host_submodule);
}

View File

@ -1,181 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind Convolution host test helpers to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "unit/conv/device/cache_testbed_output.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/host/tensor_compare.h"
namespace py = pybind11;
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host(py::module &m) {
m.def("conv2d",
&cutlass::reference::host::Conv2d<
Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);
m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}
template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host_sat(py::module &m) {
// NOTE: currently identical to bind_conv2d_host; a truly saturating variant
// would pass a clamping converter (e.g. cutlass::NumericConverterClamp) to Conv2d.
m.def("conv2d",
&cutlass::reference::host::Conv2d<
Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);
m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nhwc(py::module &m) {
bind_conv2d_host<
Ta, cutlass::layout::TensorNHWC,
Tb, cutlass::layout::TensorNHWC,
Tc, cutlass::layout::TensorNHWC,
Tacc, Te>(m);
}
template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nc32hw32(py::module &m) {
bind_conv2d_host_sat<
Ta, cutlass::layout::TensorNCxHWx<32>,
Tb, cutlass::layout::TensorCxRSKx<32>,
Tc, cutlass::layout::TensorNCxHWx<32>,
Tacc, Te>(m);
}
template<typename T, typename Layout>
void bind_tensor_equals(py::module &m) {
m.def("equals", py::overload_cast<
const cutlass::TensorView<T, Layout>&, const cutlass::TensorView<T, Layout>&>(
&cutlass::reference::host::TensorEquals<T, Layout>
));
}
#define BIND_TENSOR_HASH(Element, Layout) { \
m.def("TensorHash", &test::conv::device::TensorHash<Element, Layout>, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \
}
void bind_conv_host_references(py::module &m) {
//
// Conv2d reference on host
// tools/util/include/cutlass/util/reference/host/convolution.h
/// double
bind_conv2d_host_nhwc<double, double, double, double, double>(m);
/// float
bind_conv2d_host_nhwc<float, float, float, float, float>(m);
/// half
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, float>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, cutlass::half_t>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, float>(m);
bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, float>(m);
/// bfloat16
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t>(m);
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, cutlass::bfloat16_t>(m);
bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);
/// s8
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);
//
// Compare whether two tensors are equal
//
/// double
bind_tensor_equals<double, cutlass::layout::TensorNHWC>(m);
/// float
bind_tensor_equals<float, cutlass::layout::TensorNHWC>(m);
/// half
bind_tensor_equals<cutlass::half_t, cutlass::layout::TensorNHWC>(m);
/// bfloat16
bind_tensor_equals<cutlass::bfloat16_t, cutlass::layout::TensorNHWC>(m);
/// s32
bind_tensor_equals<int32_t, cutlass::layout::TensorNHWC>(m);
bind_tensor_equals<int32_t, cutlass::layout::TensorNCxHWx<32>>(m);
/// s8
bind_tensor_equals<int8_t, cutlass::layout::TensorNHWC>(m);
bind_tensor_equals<int8_t, cutlass::layout::TensorNCxHWx<32>>(m);
/// Cache
py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
.def(py::init<>())
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>())
.def_readwrite("problem", &test::conv::device::CachedTestKey::problem);
py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
.def(py::init<>())
.def(py::init<uint32_t>())
.def_readwrite("D", &test::conv::device::CachedTestResult::D);
py::class_<test::conv::device::CachedTestResultListing>(m, "CachedTestResultListing")
.def(py::init<const std::string &>())
.def("find", &test::conv::device::CachedTestResultListing::find)
.def("append", &test::conv::device::CachedTestResultListing::append)
.def("write", &test::conv::device::CachedTestResultListing::write);
py::class_<test::conv::device::CRC32>(m, "CRC32")
.def(py::init<>());
BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC);
BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>);
}
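The `BIND_TENSOR_HASH` instantiations above expose a CRC32-based hash over a tensor view, which the cached-test machinery uses to key and compare results without storing full tensors. A hedged Python analog of the idea using the stdlib (illustrative, not the bound `TensorHash`):

```python
# Sketch: hashing a tensor's raw bytes with CRC32, as the cached-test
# machinery does in C++. Equal contents hash equal; any change perturbs it.
import zlib
from array import array

tensor = array("f", [0.0, 1.0, 2.0, 3.0])
digest = zlib.crc32(tensor.tobytes()) & 0xFFFFFFFF

assert digest == zlib.crc32(array("f", [0.0, 1.0, 2.0, 3.0]).tobytes())
assert digest != zlib.crc32(array("f", [0.0, 1.0, 2.0, 4.0]).tobytes())
```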

View File

@ -1,45 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind gemm test to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "host.h"
namespace py = pybind11;
void bind_gemm_test(py::module &m) {
py::module_ host_submodule = m.def_submodule("host");
bind_gemm_host_reference(host_submodule);
}

View File

@ -1,431 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Bind gemm test host functions to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/functional.h"
namespace py = pybind11;
template<
typename ElementA, typename LayoutA,
typename ElementB, typename LayoutB,
typename ElementC, typename LayoutC,
typename AccumulatorType, typename ComputeType,
typename InnerProductOp>
void bind_host_gemm_saturate(py::module &m) {
m.def("gemm_saturate", py::overload_cast<
cutlass::gemm::GemmCoord, ComputeType,
cutlass::TensorRef<ElementA, LayoutA>,
cutlass::TensorRef<ElementB, LayoutB>,
ComputeType,
cutlass::TensorRef<ElementC, LayoutC>,
cutlass::TensorRef<ElementC, LayoutC>,
AccumulatorType>(
&cutlass::reference::host::compute_gemm<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ComputeType,
AccumulatorType,
InnerProductOp,
cutlass::NumericConverterClamp<ElementC, AccumulatorType>>
));
}
template<
typename ElementA, typename LayoutA,
typename ElementB, typename LayoutB,
typename ElementC, typename LayoutC,
typename AccumulatorType, typename ComputeType,
typename InnerProductOp>
void bind_host_gemm(py::module &m) {
m.def("gemm", py::overload_cast<
cutlass::gemm::GemmCoord, ComputeType,
cutlass::TensorRef<ElementA, LayoutA>,
cutlass::TensorRef<ElementB, LayoutB>,
ComputeType,
cutlass::TensorRef<ElementC, LayoutC>,
cutlass::TensorRef<ElementC, LayoutC>,
AccumulatorType>(
&cutlass::reference::host::compute_gemm<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ComputeType,
AccumulatorType,
InnerProductOp,
cutlass::NumericConverter<ElementC, AccumulatorType>>
));
}
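The two bindings above differ only in the final converter template argument: `gemm` narrows the accumulator to the output type with `NumericConverter`, while `gemm_saturate` uses `NumericConverterClamp` to saturate at the output type's range. A pure-Python model of that distinction (a sketch, not the CUTLASS reference implementation):

```python
# Sketch: D = alpha * (A @ B) + beta * C, with the accumulator narrowed to the
# output either directly ("gemm") or clamped to a range ("gemm_saturate").
def gemm(alpha, A, B, beta, C, clamp_to=None):
    m, k, n = len(A), len(B), len(B[0])
    D = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = sum(A[i][p] * B[p][j] for p in range(k))  # multiply_add chain
            out = alpha * acc + beta * C[i][j]
            if clamp_to is not None:        # NumericConverterClamp analog
                lo, hi = clamp_to
                out = max(lo, min(hi, out))
            D[i][j] = out
    return D

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[0, 0], [0, 0]]
assert gemm(1, A, B, 0, C) == [[19, 22], [43, 50]]
# With an int8-like output range, large accumulators saturate:
assert gemm(10, A, B, 0, C, clamp_to=(-128, 127)) == [[127, 127], [127, 127]]
```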
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add(py::module &m) {
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate(py::module &m) {
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::RowMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::RowMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajor,
ElementB, cutlass::layout::ColumnMajor,
ElementC, cutlass::layout::ColumnMajor,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_interleaved(py::module &m) {
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
template<
typename ElementA, typename ElementB, typename ElementC,
typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) {
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::RowMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::RowMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::RowMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
bind_host_gemm_saturate<
ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
AccumulatorType, ComputeType,
cutlass::multiply_add<AccumulatorType>>(m);
}
#define BIND_TENSOR_EQUAL(Element, Layout) { \
m.def("equals", py::overload_cast< \
const cutlass::TensorView<Element, Layout>&, const cutlass::TensorView<Element, Layout>&>( \
&cutlass::reference::host::TensorEquals<Element, Layout>)); \
}
void bind_gemm_host_reference(py::module &m) {
/// double
bind_host_gemm_multiply_add<double, double, double, double, double>(m);
/// float
bind_host_gemm_multiply_add<float, float, float, float, float>(m);
/// half_t
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, float, float>(m);
/// bfloat16
bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);
/// s8
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
// float
BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor);
// double
BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor);
// half_t
BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor);
// bfloat16
BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor);
// int32_t
BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor);
// int8_t
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor);
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor);
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>);
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>);
}


@ -214,12 +214,12 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
cutlass::conv::SplitKMode mode;
if (split_k_mode == "serial") {
mode = cutlass::conv::SplitKMode::kSerial;
@ -228,7 +228,7 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
} else {
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
}
typename DeviceKernel::Arguments arguments{
*problem_size,
tensor_ref_A,
@ -238,18 +238,18 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
{alpha, beta},
mode
};
DeviceKernel implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return status;
}
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
if (status != cutlass::Status::kSuccess) {
return status;
@ -260,6 +260,6 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
//
status = implicit_gemm_op(stream);
return status;
}
"""


@ -81,12 +81,10 @@ The module can later be used in Python via:
import logging
import os
from cutlass import CUTLASS_PATH, logger, swizzle, ConvKind, ConvKindNames, DataType
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass.backend.conv2d_operation import Conv2dOperation
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.emit import common
@ -165,26 +163,26 @@ _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
// CUDA forward declarations
at::Tensor ${name}_kernel(
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1);
// C++ interface
at::Tensor ${name}(
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run",
py::overload_cast<
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(0, 0), py::arg("dilation") = std::make_tuple(1, 1),
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
@ -198,26 +196,26 @@ _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
// CUDA forward declarations
at::Tensor ${name}_kernel(
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1);
// C++ interface
at::Tensor ${name}(
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run",
py::overload_cast<
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(0, 0), py::arg("dilation") = std::make_tuple(1, 1),
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
@ -251,11 +249,11 @@ _PYTORCH_CONV2D_INCLUDES = """
"""
_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::kI8",
    DataType.s32: "torch::kI32",
}
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
@ -446,16 +444,16 @@ std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const s
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cutlass::Status status = ${name}_kernel_run(
&problem_size,
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
ptrC,
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
alpha, beta,
split_k_mode, stream, B.device().index());
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
@ -464,19 +462,19 @@ _PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S, P, Q;
N = A.size(0);
C_ = A.size(1);
H = A.size(2);
W = A.size(3);
K = B.size(0);
R = B.size(2);
S = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
@ -486,14 +484,14 @@ at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
P = problem_size.P;
Q = problem_size.Q;
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::zeros({N, K, P, Q}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
@ -503,7 +501,7 @@ at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S;
@ -511,11 +509,11 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::T
C_ = std::get<1>(input_size);
H = std::get<2>(input_size);
W = std::get<3>(input_size);
K = B.size(0);
R = B.size(2);
S = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
@ -525,11 +523,11 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::T
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::empty({N, C_, H, W}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
@ -539,19 +537,19 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::T
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S;
K = std::get<0>(weight_size);
C_ = std::get<1>(weight_size);
R = std::get<2>(weight_size);
S = std::get<3>(weight_size);
N = B.size(0);
H = B.size(2);
W = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
@ -561,11 +559,11 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::empty({K, C_, R, S}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
@ -726,7 +724,7 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
else:
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
else:
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
@ -837,9 +835,9 @@ def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str =
:type jit: bool
:param sourcedir: directory to which generated source files should be written
:type sourcedir: str
Note that when the conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
for H/W/R/S given the same P/Q.
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
@ -848,13 +846,13 @@ def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str =
os.makedirs(sourcedir)
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
extra_kw = {}
if op.conv_kind == ConvKind.Fprop:
    impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
    cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
elif op.conv_kind == ConvKind.Dgrad:
    impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
    cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
elif op.conv_kind == ConvKind.Wgrad:
    impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
    cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()


@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass.epilogue.epilogue import (
get_activations,
get_activation_epilogue,
gelu,
hardswish,
identity,
leaky_relu,
relu,
sigmoid,
silu,
tanh,
trace
)
from cutlass.epilogue.evt_ops import (
max,
multiply_add,
sum,
permute,
reshape
)


@ -45,6 +45,7 @@ code like the following for GEMM:
from cutlass.backend import epilogue
gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
@ -99,9 +100,59 @@ def get_activation_epilogue(
)
else:
return epilogue.LinearCombinationGeneric(
activation,
element_output,
elements_per_access,
element_accumulator,
element_compute,
)
"""
Frontend for EVT that generates an epilogue functor by tracing the input function
"""
from cutlass.backend.evt.frontend import PythonASTFrontend
def trace(fn, example_tensors, **kwargs):
"""
Trace `fn(**example_tensors)` and generate an epilogue visitor
:param fn: a Python callable to trace
:param example_tensors: example inputs for fn
:type example_tensors: dict
.. highlight:: python
.. code-block:: python
import cutlass.backend.evt
# Define epilogue function as Python callable
def example_fn(accum, C, alpha, beta, gamma):
D = ((accum + C) * alpha - gamma) / beta
return D
# Define the example tensors
example_inputs = {
"accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
"C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
"alpha": 1.5,
"beta": 0.5,
"gamma": 2.5,
"D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
}
# Generate the epilogue functor
epilogue_visitor = cutlass.epilogue.trace(example_fn, example_inputs)
"""
if callable(fn):
class EpilogueFunctor(PythonASTFrontend):
def __init__(self, **kwargs):
super().__init__(**kwargs)
setattr(EpilogueFunctor, "__call__", staticmethod(fn))
epilogue_functor = EpilogueFunctor(**kwargs)
epilogue_functor.trace(example_tensors)
return epilogue_functor
else:
raise NotImplementedError("Expect a callable Python function")


@ -1,6 +1,6 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -28,42 +28,52 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Collection of builtin functions used for host reference in EVT
"""
import numpy as np
from cutlass.backend.utils.software import CheckPackages
torch_available = CheckPackages().check_torch()
if torch_available:
    import torch
def multiply_add(x, y, z):
    return x * y + z
def sum(x, dim):
if isinstance(x, np.ndarray):
return x.sum(axis=tuple(dim))
elif torch_available and isinstance(x, torch.Tensor):
return torch.sum(x, dim)
def max(x, dim):
if isinstance(x, np.ndarray):
return x.max(axis=tuple(dim))
elif torch_available and isinstance(x, torch.Tensor):
return torch.amax(x, dim)
##############################################################################
# Layout manipulate nodes
##############################################################################
def permute(x, indices: tuple):
if isinstance(x, np.ndarray):
return np.transpose(x, axes=indices)
elif torch_available and isinstance(x, torch.Tensor):
return x.permute(*indices)
def reshape(x, new_shape: tuple):
if isinstance(x, np.ndarray):
return np.reshape(x, newshape=new_shape)
elif torch_available and isinstance(x, torch.Tensor):
return x.view(new_shape)


@ -35,25 +35,21 @@ Classes containing valid operations for a given compute capability and data type
"""
import logging
from cuda import __version__
# Imports from CUTLASS profiler generator and manifest scripts
import cutlass_library
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
import cutlass
from cutlass.utils.check import valid_stage_count
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]
# Strip any additional information from the CUDA version
_cuda_version = __version__.split("rc")[0]
class KernelsForDataType:
"""
@ -202,10 +198,10 @@ class ArchOptions:
# Identify the method within CUTLASS generator script that generates kernel
# descriptions for the target CC
generate_function_name = "GenerateSM" + str(kernel_cc)
if not hasattr(cutlass_library.generator, generate_function_name):
cutlass.logger.warning(f"No generator found for architecture {kernel_cc}")
return
generate_function = getattr(cutlass_library.generator, generate_function_name)
# Initialize a default manifest and populate it with valid kernel descriptions
# for the target CC
@ -213,8 +209,8 @@ class ArchOptions:
"--kernels=all",
f"--log-level={logging.getLevelName(cutlass.logger.level)}"
]
manifest_args = cutlass_library.generator.define_parser().parse_args(args)
manifest = cutlass_library.manifest.Manifest(manifest_args)
generate_function(manifest, _cuda_version)
if operation_kind not in manifest.operations:
@ -223,9 +219,15 @@ class ArchOptions:
cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
return
# Only one CC should be returned, given the setup above of calling only the generation scripts
# for a given CC
if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
"is sufficient for the architecture in question.")
# Iterate through the available operations for this operation kind and
# find available opclasses and data types
for name, op_list in manifest.operations[operation_kind][kernel_cc].items():
for op in op_list:
if operation_kind == cutlass.OperationKind.Gemm:
if op.gemm_kind not in gemm_kinds:
@ -235,15 +237,15 @@ class ArchOptions:
if mi.math_operation not in self.allowed_math_operations:
continue
if op.C.element == cutlass.DataType.void:
# The CUTLASS Python interface currently does not support void-C kernels
continue
datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
# Prune operations that don't fit in shared memory
td = td_from_profiler_op(op)
if not valid_stage_count(target_cc, kernel_cc, td)[0]:
continue
if mi.opcode_class not in self.operations_by_opclass:
@ -337,19 +339,19 @@ class ArchOptions:
[128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)
# Prune operations that don't fit in shared memory
if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td))[0]:
continue
new_kernels = KernelsForDataType(type_comb, layout_comb)
if operation_kind == cutlass.OperationKind.Gemm:
new_operation = cutlass_library.manifest.GemmOperation(
cutlass.GemmKind.Universal, td.minimum_compute_capability,
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
new_kernels.add(new_operation)
elif operation_kind == cutlass.OperationKind.Conv2d:
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
new_operation = cutlass_library.manifest.Conv2dOperation(
conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
group_mode=GroupMode.SingleGroup

@ -30,7 +30,7 @@
#
#################################################################################################
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm import Gemm
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase

@ -49,7 +49,7 @@
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Conv(A, B, C, D)
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:
@ -112,21 +112,29 @@
args.sync()
"""
import cutlass
from cutlass import epilogue
from cutlass import (
ConvKind,
ConvMode,
IteratorAlgorithm,
SplitKMode,
StrideSupport,
)
from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import Conv2DProblemSize, MatrixCoord
from cutlass.utils import check, datatypes
class Conv2d(OperationBase):
"""
Constructs a ``Conv2d`` object.
The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
along with the data type of output D and that used for accumulation, are bound to the ``Conv``
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
@ -142,7 +150,7 @@ class Conv2d(OperationBase):
Conv2d(kind="fprop", element=cutlass.DataType.f32)
# Explicitly specify the data types to use for A, B, C, and D.
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
@ -151,7 +159,7 @@ class Conv2d(OperationBase):
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
# those passed in via the generic ``element``
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32,
element=cutlass.DataType.f32)
@ -187,9 +195,9 @@ class Conv2d(OperationBase):
:type kernel_cc: int
"""
def __init__(
self, kind="fprop",
A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
element=None,
element_A=None, element_B=None, element_C=None, element_D=None,
element_accumulator=None,
cc: int = None, kernel_cc: int = None
@ -202,18 +210,18 @@ class Conv2d(OperationBase):
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self.specified_kernel_cc = 80
self._reset_options(80)
# The arch is used in testing
self.arch = self.current_cc
self.name = "conv2d" + kind
# The convolution kind. (concept: cutlass_library.library.ConvKind)
self.conv_kind = datatypes.getattr_enum(ConvKind, kind)
# The element types (concept: cutlass library types) of A, B, C, and D
elements = []
layouts = []
# Complete the data types based on user-provided arguments
for elt, tens, name in zip([element_A, element_B, element_C, element_D],
[A, B, C, D],
@ -222,27 +230,27 @@ class Conv2d(OperationBase):
raise Exception(f'Must not specify both element_{name} and tensor {name}')
if elt is None and tens is None and element is None:
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
elt_to_set = None
lay_to_set = None
if tens is not None:
elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
else:
elt_to_set = elt if elt is not None else element
assert elt_to_set is not None
# Currently we only support layout TensorNHWC
lay_to_set = cutlass.LayoutType.TensorNHWC
elements.append(datatypes.library_type(elt_to_set))
layouts.append(lay_to_set)
self._element_a, self._element_b, self._element_c, self._element_d = elements
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta
if element_accumulator is None:
self._element_accumulator = self._element_c
else:
@ -253,38 +261,38 @@ class Conv2d(OperationBase):
self.B = B
self.C = C
self.D = D
self.alpha = alpha
self.beta = beta
# We only specify the stride of the swizzling functor here
# The actual swizzling functor is determined in run based on conv_kind and stride
self._swizzling_stride = 1
# Arguments that will be set to default value in _reset_operations
# The default tile_description and op_class are fetched from manifest of cutlass library
self._tile_description = None
self.op_class = None
# The default identity epilogue will be created
self.epilogue_functor = None
self._reset_operations()
# Arguments that will be determined online based on arguments of "run"
# based on stride, input/output channels, alignment, and conv_kind
self._iterator_algorithm = None
self._stride_support = None
def _reset_operations(self, reset_epilogue: bool = True):
# Set the default op class
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
layout_comb = (self._layout_a, self._layout_b)
self.possible_op_classes = self.options.supporting_opclasses(
self._element_a, self._element_b, self._element_accumulator,
self._layout_a, self._layout_b
)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.TensorOp
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
@ -292,34 +300,34 @@ class Conv2d(OperationBase):
else:
raise Exception(f'No kernel configuration found for supported data type and layout '
f'combination {datatype_comb}x{layout_comb}')
if reset_epilogue:
self._reset_epilogue_functor_activation(epilogue.identity)
self.alignment_pref_A = min(
128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments))
self.alignment_pref_B = min(
128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments))
self.alignment_pref_C = min(
128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments))
#
# Tile description Related
#
@property
def tile_description(self) -> TileDescription:
"""
Returns the tile description
"""
return self._tile_description
@tile_description.setter
def tile_description(
self, td=None):
"""
Set the tile description
:param td: tile description
:type td: cutlass.backend.TileDescription, or a dict with keys
{
@ -340,7 +348,7 @@ class Conv2d(OperationBase):
if "cluster_shape" in td.keys():
if td["cluster_shape"] != [1, 1, 1]:
cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
td["cluster_shape"] = [1, 1, 1]
td = self._tile_description.clone_and_update(td)
valid, msg = self._valid_tile_description(td)
@ -348,7 +356,7 @@ class Conv2d(OperationBase):
self._tile_description = td
else:
raise Exception(msg)
def _valid_tile_description(self, td: TileDescription) -> tuple:
"""
Checks whether the provided tile description is valid for the given compute capability. At present,
@ -366,9 +374,7 @@ class Conv2d(OperationBase):
and the second element is a string providing an optional error message.
:rtype: tuple
"""
# Check stage count based on both the CC to which we are compiling (self.cc)
# and the CC from which we find kernels (self.current_cc)
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
if not valid:
return (valid, msg)
@ -393,11 +399,11 @@ class Conv2d(OperationBase):
description_str.append(str(td))
descriptions.append(td)
return descriptions
#
# Swizzling functor Related
#
@property
def swizzling_stride(self):
"""
@ -420,108 +426,110 @@ class Conv2d(OperationBase):
"""
Automatically propose the swizzling functor based on the stride
"""
if self.conv_kind == ConvKind.Dgrad:
if stride[0] != 1 or stride[1] != 1:
return getattr(cutlass.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
return getattr(cutlass.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
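The proposal above reduces to a small decision rule. The sketch below mirrors it using plain strings for the functor names (the names are illustrative stand-ins, not the `cutlass.swizzle` attributes themselves):

```python
# Illustrative sketch of the swizzle choice above: a strided-dgrad swizzle is
# only proposed for dgrad with a non-unit stride; everything else falls back
# to the identity swizzle at the configured swizzling stride.
def propose_swizzle_name(conv_kind: str, stride, swizzling_stride: int = 1) -> str:
    if conv_kind == "dgrad" and (stride[0] != 1 or stride[1] != 1):
        return f"StridedDgradIdentitySwizzle{swizzling_stride}"
    return f"IdentitySwizzle{swizzling_stride}"

print(propose_swizzle_name("dgrad", (2, 2)))  # "StridedDgradIdentitySwizzle1"
```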
#
# Iterator Algorithm Related
#
@property
def iterator_algorithm(self) -> IteratorAlgorithm:
"""
Returns the iterator algorithm
"""
return self._iterator_algorithm
@iterator_algorithm.setter
def iterator_algorithm(self, alg: str):
"""
Sets the iterator algorithm
:param alg: The iterator algorithm
:type alg: str, options: "analytic", "optimized", "few_channels", and "fixed_channels"
"""
iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)
# Check if the iterator algorithm is valid
if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
self._iterator_algorithm = iterator_alg
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
"""
Propose a valid iterator algorithm based on problem size and alignment
"""
if self.conv_kind == ConvKind.Fprop:
# Check whether the fixed channel is applicable
if problem_size.C == alignment_a:
return IteratorAlgorithm.FixedChannels
elif (problem_size.C % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32):
return IteratorAlgorithm.Optimized
else:
return IteratorAlgorithm.Analytic
elif self.conv_kind == ConvKind.Dgrad:
if (problem_size.K % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32 and
problem_size.C % alignment_b == 0):
return IteratorAlgorithm.Optimized
else:
return IteratorAlgorithm.Analytic
elif self.conv_kind == ConvKind.Wgrad:
if (problem_size.K % alignment_a == 0 and
problem_size.C % alignment_b == 0):
return IteratorAlgorithm.Optimized
else:
return IteratorAlgorithm.Analytic
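For fprop, the branch above can be summarized as a standalone decision function (a sketch under assumed NHWC shapes; the function name and string results are illustrative, not library API):

```python
# Sketch of the fprop branch above: "fixed_channels" when the channel count
# equals the A-operand alignment, "optimized" for aligned channels with small
# filters (R, S <= 32), and "analytic" as the conservative fallback.
def propose_fprop_iterator(C: int, R: int, S: int, alignment_a: int) -> str:
    if C == alignment_a:
        return "fixed_channels"
    if C % alignment_a == 0 and R <= 32 and S <= 32:
        return "optimized"
    return "analytic"

print(propose_fprop_iterator(64, 3, 3, 8))  # "optimized"
```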
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
"""
Validate whether the user-provided iterator algorithm works for the given problem size
"""
if self.conv_kind == ConvKind.Fprop:
if iterator_algorithm == IteratorAlgorithm.FixedChannels:
return problem_size.C == alignment_a
elif iterator_algorithm == IteratorAlgorithm.Optimized:
return (problem_size.C % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32)
elif iterator_algorithm == IteratorAlgorithm.FewChannels:
return problem_size.C % alignment_a == 0
elif self.conv_kind == ConvKind.Dgrad:
if iterator_algorithm == IteratorAlgorithm.Optimized:
return (problem_size.K % alignment_a == 0 and
problem_size.R <= 32 and problem_size.S <= 32 and
problem_size.C % alignment_b == 0)
elif self.conv_kind == ConvKind.Wgrad:
if iterator_algorithm == IteratorAlgorithm.Optimized:
return (problem_size.K % alignment_a == 0 and
problem_size.C % alignment_b == 0)
return True
#
# Stride Support Related
#
def _propose_stride_support(self, stride):
if self.conv_kind == ConvKind.Dgrad:
if stride[0] == 1 and stride[1] == 1:
return StrideSupport.Unity
return StrideSupport.Strided
#
# Construct and Compilation
#
def construct(
self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
iterator_algorithm: IteratorAlgorithm = None,
stride_support = None, swizzling_functor: cutlass.swizzle = None,
epilogue_functor=None) -> cutlass.backend.Conv2dOperation:
"""
Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current
@ -536,9 +544,9 @@ class Conv2d(OperationBase):
:param alignment_C: alignment of operand C
:type alignment_C: int
:param iterator_algorithm: the iterator algorithm used
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
:param stride_support: the stride support of dgrad
:type stride_support: cutlass_library.library.StrideSupport
:param swizzling_functor: the swizzling functor
:type swizzling_functor: cutlass.swizzle
:param epilogue_functor: the epilogue functor
@ -550,66 +558,55 @@ class Conv2d(OperationBase):
alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
if tile_description is None:
if self.tile_description is not None:
tile_description = self.tile_description
else:
min_alignment = min([alignment_A, alignment_B, alignment_C])
op = self.possible_operations.operations(min_alignment)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
valid, err_str = self._valid_tile_description(tile_description)
if not valid:
raise Exception(f"Invalid tile description. {err_str}")
self.tile_description = tile_description
if iterator_algorithm is None:
# If the iterator algorithm is already set
if self.iterator_algorithm is not None:
iterator_algorithm = self.iterator_algorithm
else:
# Otherwise, we conservatively use the analytic iterator for correctness
iterator_algorithm = IteratorAlgorithm.Analytic
if stride_support is None:
# If the stride support is already set
if self._stride_support is not None:
stride_support = self._stride_support
else:
# Otherwise, we assume strided
stride_support = StrideSupport.Strided
if swizzling_functor is None:
# If the swizzling functor is already set
swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))
if epilogue_functor is None:
if self.epilogue_functor is not None:
epilogue_functor = self.epilogue_functor
else:
epilogue_functor = self._create_epilogue_functor_activation(self._activation)
# Reset the alignment of the epilogue functor
epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)
operation = Conv2dOperation(
conv_kind=self.conv_kind,
iterator_algorithm=iterator_algorithm,
arch=self.current_cc,
tile_description=tile_description,
@ -618,13 +615,13 @@ class Conv2d(OperationBase):
epilogue_functor=epilogue_functor,
swizzling_functor=swizzling_functor,
)
return operation
def compile(self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
iterator_algorithm: IteratorAlgorithm = None,
stride_support = None, swizzling_functor: cutlass.swizzle = None,
epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation:
"""
Emits and compiles the kernel currently specified. If ``tile_description`` and any
@ -641,9 +638,9 @@ class Conv2d(OperationBase):
:param alignment_C: alignment of operand C
:type alignment_C: int
:param iterator_algorithm: the iterator algorithm used
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
:param stride_support: the stride support of dgrad
:type stride_support: cutlass_library.library.StrideSupport
:param swizzling_functor: the swizzling functor
:type swizzling_functor: cutlass.swizzle
:param epilogue_functor: the epilogue functor
@ -651,17 +648,17 @@ class Conv2d(OperationBase):
:return: operation that was compiled
:rtype: cutlass.backend.Conv2dOperation
"""
self.operation = self.construct(
tile_description, alignment_A, alignment_B, alignment_C,
iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
if print_module:
print(self.operation.rt_module.emit())
compiler.add_module([self.operation,])
return self.operation
#
# Run Related
#
@ -681,21 +678,19 @@ class Conv2d(OperationBase):
if dtype != ref_type:
raise Exception(f'Tensor {name} with type and layout {dtype} '
f'does not match the expected type of {ref_type}.')
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
if self.conv_kind == ConvKind.Fprop:
input = A
weight = B
output = C
output_tensor = "C"
elif self.conv_kind == ConvKind.Dgrad:
output = A
weight = B
input = C
output_tensor = "A"
elif self.conv_kind == ConvKind.Wgrad:
output = A
input = B
weight = C
@ -703,27 +698,27 @@ class Conv2d(OperationBase):
else:
raise Exception(f"Convolution kind {self.conv_kind} is not supported")
N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
_, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")
problem_size = Conv2DProblemSize(
N_, H_, W_, C_,
K_, R_, S_, C_,
padding[0], padding[1],
stride[0], stride[1],
dilation[0], dilation[1],
ConvMode.CrossCorrelation,
1, 1
)
if P_ != problem_size.P or Q_ != problem_size.Q:
raise Exception(
f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")
return problem_size
def run(self, A=None, B=None, C=None, D=None,
def run(self, A=None, B=None, C=None, D=None,
stride=(1, 1), padding=(0, 0), dilation=(1, 1),
alpha=None, beta=None,
split_k=("serial", 1), sync: bool = True,
@ -732,9 +727,9 @@ class Conv2d(OperationBase):
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
parameters provided in the call, or from those
passed in on the construction of this object -- one of the two must be specified.
By default, this call returns only once the kernel has completed. To launch the kernel
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
caller to synchronize the results of the kernel before attempting to access outputs
@ -754,7 +749,7 @@ class Conv2d(OperationBase):
:type sync: bool
:param print_module: whether to print the emitted C++ code
:type print_module: bool
:return: arguments passed in to the kernel
:rtype: cutlass.backend.Conv2dArguments
"""
@ -764,45 +759,49 @@ class Conv2d(OperationBase):
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
# handle the case when there is no C
if C is None:
if beta != 0:
raise Exception(f"With beta {beta} != 0, C has to be provided.")
else:
C = D
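The beta/C handling above amounts to the following rule; `resolve_c` is a hypothetical name for this sketch, not a method of the class:

```python
# Sketch of the check above: when no C operand is given, beta must be zero,
# and the output tensor D is substituted as the C operand.
def resolve_c(C, D, beta):
    if C is None:
        if beta != 0:
            raise Exception(f"With beta {beta} != 0, C has to be provided.")
        return D
    return C
```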
# Construct problem size based on input
# It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)
# Propose stride support based on input
stride_support = self._propose_stride_support(stride)
# Propose swizzling functor
swizzling_functor = self._propose_swizzling_functor(stride)
shape_a = datatypes.get_tensor_shape(A, op="CONV")
shape_b = datatypes.get_tensor_shape(B, op="CONV")
shape_c = datatypes.get_tensor_shape(C, op="CONV")
# Get the alignment
alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a)
alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b)
alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c)
alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)
# Propose iterator algorithm based on input
if self._iterator_algorithm is None:
# Propose a default iterator algorithm based on the problem size
iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
else:
if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
iterator_algorithm = self._iterator_algorithm
else:
raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")
epilogue_args = [alpha, beta]
if hasattr(self, "_activation_args"):
if isinstance(self._activation_args, list):
epilogue_args += self._activation_args
@ -813,43 +812,41 @@ class Conv2d(OperationBase):
epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
else:
epilogue_functor = self.epilogue_functor
# The alignment is determined by the iterator function (I believe)
self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)
# Create reduction operation for parallel split-k
if split_k[0] == "parallel" and split_k[1] > 1:
epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
self.reduction_operation = ReductionOperation(
shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
element_accumulator=self._element_accumulator,
element_compute=self._element_accumulator,
epilogue_functor=epilogue_functor_reduction,
count=alignment_c
)
if print_module:
print(self.reduction_operation.rt_module.emit())
compiler.add_module([self.reduction_operation,])
arguments = Conv2dArguments(
operation=self.operation, problem_size=problem_size,
A=A, B=B, C=C, D=D,
output_op=self.operation.epilogue_type(*epilogue_args),
split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
split_k_slices=split_k[1]
)
self.operation.run(arguments)
if split_k[0] == "parallel" and split_k[1] > 1:
implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
reduction_arguments = ReductionArguments(
self.reduction_operation,
problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
partitions=split_k[1],
workspace=arguments.ptr_D,
destination=D,
@ -857,7 +854,7 @@ class Conv2d(OperationBase):
output_op=self.reduction_operation.epilogue_type(*epilogue_args)
)
self.reduction_operation.run(reduction_arguments)
if sync:
if split_k[0] == "parallel" and split_k[1] > 1:
reduction_arguments.sync()
@ -865,23 +862,23 @@ class Conv2d(OperationBase):
arguments.sync()
return arguments
#
# Helper functions
#
@staticmethod
def output_size(input_size, weight_size, padding, stride, dilation):
problem_size = Conv2DProblemSize(
*input_size,
*weight_size,
padding[0], padding[1],
stride[0], stride[1],
dilation[0], dilation[1],
ConvMode.CrossCorrelation,
1, 1
)
return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
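`Conv2DProblemSize` derives the output extents P and Q from the standard cross-correlation output formula; a minimal sketch of that arithmetic (assumed to match the library's internal computation) is:

```python
# Standard cross-correlation output extent, as assumed to be computed by
# Conv2DProblemSize for P (from H) and Q (from W).
def conv_output_extent(extent: int, pad: int, dilation: int,
                       kernel: int, stride: int) -> int:
    return (extent + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

print(conv_output_extent(32, 1, 1, 3, 1))  # 32: "same" padding for a 3x3 filter
```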
#
# Easy to use interfaces for fprop, wgrad, and dgrad
@ -890,23 +887,23 @@ class Conv2d(OperationBase):
class Conv2dFprop(Conv2d):
def __init__(
self,
input=None, weight=None, C=None, output=None, alpha=1, beta=0,
element=None,
element_input=None, element_weight=None, element_C=None, element_output=None,
element_accumulator=None,
cc: int = None, kernel_cc: int = None):
A, B, D = input, weight, output
element_A, element_B, element_D = element_input, element_weight, element_output
super().__init__(
"fprop", A, B, C, D, alpha, beta, element,
element_A, element_B, element_C, element_D,
element_accumulator, cc, kernel_cc)
def run(
self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
sync: bool = True, print_module: bool = False) -> Conv2dArguments:
A, B, D = input, weight, output
return super().run(
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
@@ -915,20 +912,20 @@ class Conv2dFprop(Conv2d):
class Conv2dDgrad(Conv2d):
def __init__(
self,
grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
element=None,
element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
element_accumulator=None,
grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
element=None,
element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
element_accumulator=None,
cc: int = None, kernel_cc: int = None):
A, B, D = grad_output, weight, grad_input
element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
super().__init__(
"dgrad", A, B, C, D, alpha, beta, element,
element_A, element_B, element_C, element_D,
"dgrad", A, B, C, D, alpha, beta, element,
element_A, element_B, element_C, element_D,
element_accumulator, cc, kernel_cc)
def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
sync: bool = True, print_module: bool = False) -> Conv2dArguments:
#
A, B, D = grad_output, weight, grad_input
@@ -939,20 +936,20 @@ class Conv2dDgrad(Conv2d):
class Conv2dWgrad(Conv2d):
def __init__(
self,
grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
element=None,
element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
element_accumulator=None,
grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
element=None,
element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
element_accumulator=None,
cc: int = None, kernel_cc: int = None):
A, B, D = grad_output, input, grad_weight
element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
super().__init__(
"wgrad", A, B, C, D, alpha, beta, element,
element_A, element_B, element_C, element_D,
"wgrad", A, B, C, D, alpha, beta, element,
element_A, element_B, element_C, element_D,
element_accumulator, cc, kernel_cc)
def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
sync: bool = True, print_module: bool = False) -> Conv2dArguments:
#
A, B, D = grad_output, input, grad_weight


@@ -116,14 +116,18 @@
from math import prod
import cutlass_bindings
import cutlass
from cutlass import epilogue, swizzle
from cutlass import (
epilogue,
swizzle,
GemmUniversalMode,
)
from cutlass.backend import compiler
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes
@@ -245,7 +249,7 @@ class Gemm(OperationBase):
lay_to_set = lay if lay is not None else layout
elements.append(datatypes.library_type(elt_to_set))
layouts.append(datatypes.library_layout(lay_to_set))
layouts.append(lay_to_set)
self._element_a, self._element_b, self._element_c, self._element_d = elements
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
@@ -265,6 +269,7 @@ class Gemm(OperationBase):
self.epilogue_functor = None
self.op_class = None
self._tile_description = None
self._reset_operations()
@@ -311,6 +316,48 @@ class Gemm(OperationBase):
raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90')
self._swizzling_functor = swizzling_functor
#
# Tile description related
#
@property
def tile_description(self) -> TileDescription:
"""
Returns the tile description
"""
return self._tile_description
@tile_description.setter
def tile_description(
self, td=None):
"""
Set the tile description
:param td: tile description
:type td: cutlass.backend.TileDescription, or a dict with keys
{
"threadblock_shape": [int, int, int],
"warp_count": [int, int, int],
"stages": int,
"instruction_shape": [int, int, int] (optional),
"cluster_shape": [int, int, int] (optional)
}
"""
if td is None:
return
if isinstance(td, dict):
if self._tile_description is None:
alignment = list(self.possible_operations.kernels_by_alignment.keys())[0]
op = self.possible_operations.operations(alignment)[0]
self._tile_description = datatypes.td_from_profiler_op(op)
td = self._tile_description.clone_and_update(td)
valid, msg = self._valid_tile_description(td)
if valid:
self._tile_description = td
else:
raise Exception(msg)
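The dict branch of the setter fills in any unspecified fields from an existing description via `clone_and_update`. The merge semantics can be illustrated with plain dicts (`merge_td` below is a hypothetical stand-in, not the actual `TileDescription.clone_and_update`):

```python
def merge_td(base: dict, update: dict) -> dict:
    """Keys present in `update` override `base`; everything else carries over."""
    merged = dict(base)
    merged.update(update)
    return merged

base = {"threadblock_shape": [128, 128, 32], "warp_count": [2, 2, 1], "stages": 3}
td = merge_td(base, {"stages": 5})
print(td["stages"], td["threadblock_shape"])  # -> 5 [128, 128, 32]
```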
def _valid_tile_description(self, td: TileDescription) -> tuple:
"""
Checks whether the provided tile description is valid for the given compute capability. At present,
@@ -328,9 +375,7 @@ class Gemm(OperationBase):
and the second element is a string providing an optional error message.
:rtype: tuple
"""
# Check stage count based on the CC to which we are compiling (self.cc), rather
# than the CC from which we find kernels (self.current_cc)
valid, msg = check.valid_stage_count(self.cc, td, self._element_c, self._element_d)
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
if not valid:
return (valid, msg)
@@ -378,30 +423,21 @@ class Gemm(OperationBase):
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
tensor_A = TensorDescription(
datatypes.binding_type(self._element_a),
datatypes.binding_layout(self._layout_a),
alignment_A
)
tensor_B = TensorDescription(
datatypes.binding_type(self._element_b),
datatypes.binding_layout(self._layout_b),
alignment_B
)
tensor_C = TensorDescription(
datatypes.binding_type(self._element_c),
datatypes.binding_layout(self._layout_c),
alignment_C
)
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
if tile_description is None:
op = self.possible_operations.operations(alignment_A)[0]
tile_description = datatypes.td_from_profiler_op(op)
if self._tile_description is None:
op = self.possible_operations.operations(alignment_A)[0]
tile_description = datatypes.td_from_profiler_op(op)
else:
tile_description = self._tile_description
else:
valid, err_str = self._valid_tile_description(tile_description)
if not valid:
raise Exception(f"Invalid tile description. {err_str}")
self.tile_description = tile_description
self._tile_description = tile_description
operation = GemmOperationUniversal(
arch=self.current_cc,
@@ -473,21 +509,13 @@ class Gemm(OperationBase):
:return: tuple of batch count dimensions
:rtype: tuple
"""
A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple()
B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple()
C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple()
D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple()
A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1
if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]:
raise Exception(f"Batch count in D must be present in one of operands A, B, and C. "
f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}")
for batch_shape in [A_batch, B_batch, C_batch]:
if len(batch_shape) > 0 and batch_shape != D_batch:
raise Exception(f"Batch count for all other operands must either match that of D or be zero."
f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.")
return D_batch
if 1 not in [A_batch, B_batch]:
if A_batch != B_batch:
raise Exception(f"Invalid batch counts: A={A_batch}, B={B_batch}")
return max(A_batch, B_batch)
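The rewritten `_get_batch_count` collapses all leading dimensions into a single integer and lets one operand be unbatched (count 1). The rule can be exercised on raw shape tuples in isolation (a simplified sketch that ignores C and D):

```python
from math import prod

def batch_count(a_shape, b_shape):
    """Collapse leading dims to one batch count; A and B must agree
    unless one of them is unbatched (count 1)."""
    a = prod(a_shape[:-2]) if len(a_shape) > 2 else 1
    b = prod(b_shape[:-2]) if len(b_shape) > 2 else 1
    if 1 not in (a, b) and a != b:
        raise ValueError(f"Invalid batch counts: A={a}, B={b}")
    return max(a, b)

print(batch_count((4, 2, 128, 64), (64, 256)))  # batched A, unbatched B -> 8
```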
def _get_batch_stride(self, tensor) -> int:
"""
@@ -518,38 +546,38 @@ class Gemm(OperationBase):
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int)
:return: tuple containing the problem size (cutlass.shape.GemmCoord), the GEMM mode (cutlass.GemmUniversalMode), and the batch count (int)
:rtype: tuple
"""
M, K = A.shape[-2:]
N = B.shape[-1]
mode = cutlass_bindings.gemm.Mode.Gemm
mode = GemmUniversalMode.Gemm
batch_count = self._get_batch_count(A, B, C, D)
returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1
returned_batch_count = batch_count
# If we are running a batched GEMM in which there is a nonzero batch stride
# only for A, then we can fold the batched dimension of A into the M dimension
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
# and C are row major. A similar operation can be performed if only B has a nonzero
# batch dimension
if len(batch_count) > 0:
if batch_count > 1:
A_row = self._layout_a == cutlass.LayoutType.RowMajor
B_row = self._layout_b == cutlass.LayoutType.RowMajor
C_row = self._layout_c == cutlass.LayoutType.RowMajor
batched = lambda x : len(x.shape) == 2 + len(batch_count)
batched = lambda x : len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count
if batched(A) and not batched(B) and batched(C) and A_row and C_row:
M *= prod(batch_count)
M *= batch_count
returned_batch_count = 1
elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row:
N *= prod(batch_count)
N *= batch_count
returned_batch_count = 1
else:
mode = cutlass_bindings.gemm.Mode.Batched
mode = GemmUniversalMode.Batched
return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count
return GemmCoord(M, N, K), mode, returned_batch_count
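The folding decision in `_get_problem_args` can be sketched on shapes alone. This simplified stand-in covers only A's batching and the row-major conditions on A and C; the real method also handles the symmetric case where only B is batched:

```python
from math import prod

def fold_or_batch(a_shape, b_shape, a_row_major, c_row_major):
    """If only A carries the batch and A and C are row-major, fold the
    batch into M ((b, m, k) x (k, n) -> (m*b, k) x (k, n)); otherwise
    fall back to a batched GEMM."""
    m, k = a_shape[-2:]
    n = b_shape[-1]
    batch = prod(a_shape[:-2]) if len(a_shape) > 2 else 1
    if batch > 1 and len(b_shape) == 2 and a_row_major and c_row_major:
        return ("Gemm", m * batch, n, k, 1)
    mode = "Batched" if batch > 1 else "Gemm"
    return (mode, m, n, k, batch)

print(fold_or_batch((4, 16, 8), (8, 32), True, True))  # -> ('Gemm', 64, 32, 8, 1)
```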
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
"""
@@ -570,7 +598,7 @@ class Gemm(OperationBase):
f'layout of ({ref_type}, {ref_layout}).')
def run(self, A=None, B=None, C=None, D=None,
alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments:
alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:
"""
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -612,12 +640,12 @@ class Gemm(OperationBase):
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
alignment_C=alignment_c, print_module=print_module)
problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1:
if mode == GemmUniversalMode.Gemm or batch_count == 1:
kwargs = {'split_k_slices': 1}
else:
kwargs = {
@@ -630,10 +658,15 @@ class Gemm(OperationBase):
}
}
if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
output_op = self.operation.epilogue_type(visitor_args)
else:
output_op = self.operation.epilogue_type(alpha, beta)
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,
A=A, B=B, C=C, D=D,
output_op=self.operation.epilogue_type(alpha, beta),
output_op=output_op,
gemm_mode=mode,
**kwargs
)


@@ -51,19 +51,18 @@
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""
import cutlass_bindings
from cutlass import DataTypeSize
from cutlass.backend.gemm_operation import (
GemmGroupedArguments,
GemmOperationGrouped,
)
from cutlass.backend.library import (
DataTypeSize,
SchedulerMode,
TensorDescription,
TileDescription,
)
from cutlass.op.gemm import Gemm
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes
@@ -170,21 +169,9 @@ class GroupedGemm(Gemm):
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
tensor_A = TensorDescription(
datatypes.binding_type(self._element_a),
datatypes.binding_layout(self._layout_a),
alignment_A
)
tensor_B = TensorDescription(
datatypes.binding_type(self._element_b),
datatypes.binding_layout(self._layout_b),
alignment_B
)
tensor_C = TensorDescription(
datatypes.binding_type(self._element_c),
datatypes.binding_layout(self._layout_c),
alignment_C
)
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
if tile_description is None:
op = self.possible_operations.operations(alignment_A)[0]
@@ -244,7 +231,7 @@ class GroupedGemm(Gemm):
Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
problem_sizes.append(cutlass_bindings.gemm.GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")


@@ -38,11 +38,12 @@ from bisect import bisect_left
import cutlass
from cutlass import option_registry, epilogue
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.utils.device import device_cc
from cutlass.epilogue import get_activations
from cutlass.library_defaults import _generator_ccs
from cutlass.library_defaults import KernelsForDataType, _generator_ccs
from cutlass.swizzle import get_swizzling_functors
from cutlass.utils import datatypes
from cutlass.utils import datatypes, check
class OperationBase:
@@ -67,7 +68,7 @@ class OperationBase:
if self.options is None:
raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
# Default activation function: identity
self._activation = epilogue.identity
@@ -120,7 +121,7 @@ class OperationBase:
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
self.current_cc = cc
self.options = option_registry.options_for_cc(self.current_cc, self.operation_kind)
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
"""
Verifies the following properties:
@@ -153,7 +154,7 @@ class OperationBase:
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
)
return scalar
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
"""
Verifies the following properties:
@@ -183,11 +184,11 @@ class OperationBase:
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
return tensor
#
# Opcode Related
#
@property
def opclass(self) -> cutlass.OpcodeClass:
"""
@@ -197,7 +198,7 @@ class OperationBase:
:rtype: cutlass.OpcodeClass
"""
return self.op_class
@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
if isinstance(oc, str):
@@ -223,11 +224,11 @@ class OperationBase:
self.possible_operations = self.options.operations(
self.op_class, self._element_a, self._element_b,
self._element_accumulator, self._layout_a, self._layout_b)
#
# Epilogue
#
def _create_epilogue_functor_activation(self, activation):
"""
Returns the epilogue functor with given activation function
@@ -261,44 +262,47 @@ class OperationBase:
return epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
self._element_c,
elements_per_access,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
self._element_accumulator,
self._element_accumulator,
)
def _reset_epilogue_functor_activation(self, activation):
"""
Set the epilogue functor based on the provided activation function
"""
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
"""
Reset the alignment of the current epilogue functor based on alignment C
"""
if isinstance(epilogue_functor, EpilogueFunctorVisitor):
return epilogue_functor
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
# Identity epilogue does not have 'activation_functor'
activation = epilogue.identity
else:
activation = type(epilogue_functor.activation_functor)
activation = epilogue_functor.activation_functor
epilogue_functor = epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
self._element_c,
alignment,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
self._element_accumulator,
self._element_accumulator,
)
return epilogue_functor
@property
def activation(self):
"""
Returns the type of the current activation function used
"""
if hasattr(self.epilogue_functor, "activation_functor"):
return type(self.epilogue_functor.activation_functor)
return self.epilogue_functor.activation_functor
else:
return epilogue.identity
@@ -307,7 +311,7 @@ class OperationBase:
"""
Sets the type of the activation function to use
Activation can come with a set of arguments
:param act: type of activation function to use
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
@@ -325,4 +329,50 @@ class OperationBase:
act = getattr(cutlass.backend.epilogue, act)
self._reset_epilogue_functor_activation(act)
self._activation = act
@property
def epilogue_visitor(self):
"""
Returns the current epilogue visitor (stored as the epilogue functor)
"""
return self.epilogue_functor
@epilogue_visitor.setter
def epilogue_visitor(self, visitor):
"""
Sets the epilogue visitor, wrapping it in an EpilogueFunctorVisitor
"""
self.epilogue_functor = EpilogueFunctorVisitor(self.cc, visitor)
# The epilogue_functor may consume too much shared memory
# Reset the possible operations
if self.cc != 90:
# Shared memory is only a concern for the SM90 epilogue;
# on SM80, the epilogue and mainloop share the same shared memory
return
datatype_comb = self.possible_operations.datatype_comb
layout_comb = self.possible_operations.layout_comb
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
for operation in self.possible_operations.all_operations:
td = datatypes.td_from_profiler_op(operation)
# Filter invalid epilogue schedules
if td.epilogue_schedule not in [
cutlass.EpilogueScheduleType.TmaWarpSpecialized,
cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
continue
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
# Verify the maximum number of mainloop stages
mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
smem_capacity_bytes = cutlass.SharedMemPerCC[self.cc] << 10
mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
if mainloop_stages < 2:
# The mainloop requires at least two stages
continue
new_possible_operations.add(operation)
if len(new_possible_operations.all_operations) == 0:
raise RuntimeError(
"The epilogue consumes too much shared memory. "
"No valid tile description is found in the generator.")
self.possible_operations = new_possible_operations
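The stage-count filter above can be checked in isolation: given the shared-memory capacity for a compute capability, the epilogue's footprint, and the per-stage mainloop usage, a kernel is kept only if at least two mainloop stages still fit. The byte figures below are illustrative, not taken from any real kernel:

```python
def max_mainloop_stages(smem_capacity_kib, epilogue_smem_bytes, smem_per_stage_bytes):
    """Number of mainloop stages that still fit after the epilogue
    reserves its share of shared memory (must be >= 2 to be usable)."""
    capacity = smem_capacity_kib << 10  # KiB -> bytes, as in the setter
    return (capacity - epilogue_smem_bytes) // smem_per_stage_bytes

# Illustrative numbers: 228 KiB of shared memory, a 32 KiB epilogue,
# and 48 KiB per mainloop stage leave room for 4 stages.
stages = max_mainloop_stages(228, 32 << 10, 48 << 10)
print(stages, stages >= 2)  # -> 4 True
```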
