CUTLASS 3.2.1 (#1113)

* Updates for 3.2.1 release.
* Minor fix in gemm op profiler for raster order.
* Add scheduler mapping for raster order in the kernels.

This commit is contained in:

@@ -1,6 +1,12 @@
-# CUTLASS Python Interface
+# Python packages associated with CUTLASS
+
+This directory contains Python packages that are associated with CUTLASS:
+
+* `cutlass`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python
+* `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels
+
+## CUTLASS Python Interface

The CUTLASS Python interface enables one to compile and run CUTLASS operations from within Python.

```python
@@ -15,7 +21,7 @@ plan.run(A, B, C, D)
**NOTE:** The CUTLASS Python interface is currently an experimental release. The API may change in the future.
We welcome feedback from the community.

-## Overview
+### Overview
The CUTLASS Python interface aims to provide an easy-to-use interface for using CUTLASS via Python. Toward this goal,
the CUTLASS Python interface attempts to:

@@ -25,7 +31,7 @@ the CUTLASS Python interface attempts to:
* Reduce the occurrence of C++ compile-time errors in favor of descriptive Python exceptions
* Make it easy to export CUTLASS kernels to framework extensions (e.g., PyTorch CUDA extensions)

-### Non-goals
+#### Non-goals
The CUTLASS Python interface is not intended to:

**Select optimal kernel configurations.**
@@ -43,7 +49,7 @@ one of the CUTLASS emitters for automatically creating a framework extension for
The CUTLASS Python interface intends to enable one to use CUTLASS via Python. It can be used by frameworks for JIT compiling
Python to CUDA kernels, but does not set out to be such a framework.

-### Comparison to PyCUTLASS
+#### Comparison to PyCUTLASS
The CUTLASS Python interface builds atop CUTLASS's [PyCUTLASS](https://github.com/NVIDIA/cutlass/tree/v3.0.0/tools/library/scripts/pycutlass) library. PyCUTLASS enables
one to declare, compile, and run GEMMs, convolutions, and grouped GEMM operators with nearly the same configuration
space as CUTLASS's C++ interface. While this flexibility enables one to achieve similar levels of functionality
@@ -53,43 +59,14 @@ to operators -- similar to what one must do in specifying template parameters to
In contrast, the CUTLASS Python interface aims to provide a higher-level API for declaring, emitting, and compiling
kernels that does not require exhaustively defining template parameters.

-#### Transitioning from PyCUTLASS
-At present, existing PyCUTLASS functionality remains available via the CUTLASS Python interface. One can
-continue to use PyCUTLASS by replacing references to the PyCUTLASS `cutlass` module with `cutlass_bindings`
-and the PyCUTLASS `pycutlass` module with `cutlass.backend`.
-
-For example, the following code using PyCUTLASS:
-```python
-import pycutlass
-import cutlass
-
-math_inst = pycutlass.MathInstruction(
-    [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32,
-    cutlass.OpClass.Simt, pycutlass.MathOperation.multiply_add
-)
-```
-
-can work with the Python interface via:
-```python
-import cutlass.backend as pycutlass
-import cutlass_bindings
-
-math_inst = pycutlass.MathInstruction(
-    [1, 1, 1], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32,
-    cutlass_bindings.OpClass.Simt, pycutlass.MathOperation.multiply_add
-)
-```
-
-**NOTE:** backwards compatibility of `cutlass.backend` with `pycutlass` will not be maintained moving forward.
-
-## Current functionality
+### Current functionality
The CUTLASS Python interface currently supports the following operations:
* GEMMs
* GEMMs with fused elementwise epilogues (e.g., ReLU) (for pre-SM90 kernels)
* Stream K swizzling (for pre-SM90 kernels)
* Grouped GEMM (for pre-SM90 kernels)

-## Getting started
+### Getting started
We recommend using the CUTLASS Python interface via one of the Docker images located in the [docker](/python/docker) directory.

```bash
@@ -99,7 +76,7 @@ docker run --gpus all -it --rm cutlass-cuda12.1:latest

The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8.10 and 3.9.7.

-### Optional environment variables
+#### Optional environment variables
Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables:
* `CUTLASS_PATH`: the path to the cloned CUTLASS repository
* `CUDA_INSTALL_PATH`: the path to the installation of CUDA
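
As a sketch, the two variables above can be exported before installation; the paths shown are placeholders, not values taken from this commit:

```bash
# Placeholder paths -- point these at your actual CUTLASS checkout and CUDA toolkit.
export CUTLASS_PATH=/path/to/cutlass
export CUDA_INSTALL_PATH=/usr/local/cuda
```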
@@ -110,7 +87,7 @@ If these environment variables are not set, the installation process will infer

**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`.

-### Installation
+#### Installation
The CUTLASS Python interface can currently be installed via:
```bash
python setup.py develop --user
@@ -119,7 +96,7 @@ This will allow changes to the Python interface source to be reflected when usin

We plan to add support for installing via `python setup.py install` in a future release.

-## Examples
+### Examples
Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python).

To launch these notebooks from this directory, run:
@@ -127,7 +104,7 @@ To launch these notebooks from this directory, run:
jupyter-lab ../examples/python
```

-## Building documentation
+### Building documentation
The CUTLASS Python interface uses [Sphinx](https://www.sphinx-doc.org/en/master/) for documentation.

Building the documentation requires additional packages. These can be installed via:
@@ -147,6 +124,22 @@ make html
mv _build/* ../docs
```

+## CUTLASS library package
+[cutlass_library](/python/cutlass_library) contains utilities for enumerating and emitting CUTLASS C++ kernels.
+It is used by the CUTLASS CMake system to construct a library of kernels that can be profiled using the CUTLASS profiler.
+
+To install the `cutlass_library` package, run
+```bash
+python setup_library.py develop --user
+```
+
+Alternatively, `cutlass_library` will automatically be installed if you install the CUTLASS Python interface package.
+
+You can also use the [generator.py](/python/cutlass_library/generator.py) script directly without installing the module via:
+```bash
+python -m cutlass_library.generator
+```

# Copyright

Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

@@ -34,6 +34,8 @@ import logging
import os
import sys

+import cutlass_library

def _cutlass_path_from_dir() -> str:
    cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../')
@@ -62,19 +64,29 @@ CUTLASS_PATH = os.getenv("CUTLASS_PATH", _cutlass_path_from_dir())
CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc())
CACHE_FILE = "compiled_cache.db"

# Add the path to the CUTLASS profiler generation/manifest scripts to PYTHONPATH
sys.path.insert(0, os.path.join(CUTLASS_PATH, "tools/library/scripts/"))

# Import types/methods from the CUTLASS utility libraries for profiler generation/emission under
-from library import (
+from cutlass_library.library import (
    ArchitectureNames,
    ComplexTransform,
    ComplexTransformTag,
    ConvKind,
    ConvKindNames,
    ConvKindTag,
    ConvMode,
    DataType,
    DataTypeNames,
    DataTypeSize,
    DataTypeTag,
    EpilogueFunctor,
    EpilogueScheduleSuffixes,
    EpilogueScheduleTag,
    EpilogueScheduleType,
    GemmKind,
    GemmKindNames,
    GemmUniversalMode,
    IteratorAlgorithm,
    IteratorAlgorithmNames,
    IteratorAlgorithmTag,
    LayoutTag,
    LayoutType,
    KernelScheduleSuffixes,
@@ -82,15 +94,27 @@ from library import (
    KernelScheduleType,
    MathInstruction,
    MathOperation,
    MathOperationTag,
    OpcodeClass,
    OpcodeClassNames,
    OpcodeClassTag,
    OperationKind,
    SharedMemPerCC,
    ShortComplexLayoutNames,
    ShortDataTypeNames,
    ShortLayoutTypeNames,
    SplitKMode,
    StrideSupport,
    StrideSupportNames,
    StrideSupportTag,
    SwizzlingFunctor,
    SwizzlingFunctorTag,
    TensorDescription,
    TileDescription,
    TileSchedulerSuffixes,
    TileSchedulerTag,
-    TileSchedulerType
+    TileSchedulerType,
+    get_complex_from_real,
)

this = sys.modules[__name__]
@@ -112,7 +136,7 @@ from cutlass.backend.utils.device import device_cc

this.option_registry = OptionRegistry(device_cc())

-this.__version__ = '3.2.0'
+this.__version__ = '3.2.1'

from cutlass.backend import get_memory_pool
from cutlass.emit.pytorch import pytorch
@@ -120,5 +144,6 @@ from cutlass.op.gemm import Gemm
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
+from cutlass.backend.evt.ir.tensor import Tensor

get_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32)

@@ -1,6 +1,3 @@
-# module-wide variables
-import os
-
from cutlass.backend.arguments import *
from cutlass.backend.c_types import *
from cutlass.backend.compiler import ArtifactManager
@@ -11,9 +8,7 @@ from cutlass.backend.gemm_operation import *
from cutlass.backend.library import *
from cutlass.backend.memory_manager import PoolMemoryManager
from cutlass.backend.operation import *
from cutlass.backend.parser import *
from cutlass.backend.reduction_operation import *
from cutlass.backend.tensor_ref import *
from cutlass.backend.type_hint import *
from cutlass.backend.utils import *
from cutlass.backend.utils.device import device_cc

@@ -30,6 +30,7 @@
#
#################################################################################################

+from math import prod
from typing import Union

from cuda import cuda, cudart
@@ -67,39 +68,39 @@ class ArgumentBase:
        # by default, tensor_C is not bias
        self.bias = False

        # preprocessing input tensors
-        if isinstance(A, np.ndarray):
-            self.host_D = D
-            self.buffer_A = NumpyFrontend.argument(A, False)
-            self.buffer_B = NumpyFrontend.argument(B, False)
-            self.buffer_C = NumpyFrontend.argument(C, False)
-            self.buffer_D = NumpyFrontend.argument(D, True)
-            self.ptr_A = self.buffer_A.ptr
-            self.ptr_B = self.buffer_B.ptr
-            self.ptr_C = self.buffer_C.ptr
-            self.ptr_D = self.buffer_D.ptr
-            # number of elements in C
-            self.tensor_c_numel = C.size
-        elif torch_available and isinstance(A, torch.Tensor):
-            self.ptr_A = TorchFrontend.argument(A)
-            self.ptr_B = TorchFrontend.argument(B)
-            self.ptr_C = TorchFrontend.argument(C)
-            self.ptr_D = TorchFrontend.argument(D)
-            # number of elements in C
-            self.tensor_c_numel = C.numel()
-        elif isinstance(A, cuda.CUdeviceptr):
-            self.ptr_A = A
-            self.ptr_B = B
-            self.ptr_C = C
-            self.ptr_D = D
+        # RMM buffers used to track tensor lifetime
+        self.buffers = {}
+        # Host tensor to copy the computed result back
+        self.host_tensors = {}
-        elif cupy_available and isinstance(A, cp.ndarray):
-            self.ptr_A = CupyFrontend.argument(A)
-            self.ptr_B = CupyFrontend.argument(B)
-            self.ptr_C = CupyFrontend.argument(C)
-            self.ptr_D = CupyFrontend.argument(D)
-            # number of elements in C
-            self.tensor_c_numel = C.size
+        self.ptr_A = self.tensor_to_ptr(A, "A")
+        self.ptr_B = self.tensor_to_ptr(B, "B")
+        self.ptr_C = self.tensor_to_ptr(C, "C")
+        self.ptr_D = self.tensor_to_ptr(D, "D", True)
+        if C is not None:
+            if not isinstance(C, cuda.CUdeviceptr):
+                self.tensor_c_numel = prod(C.shape)

+    def tensor_to_ptr(self, tensor, name, is_output=False):
+        """
+        Converts the input tensor to a cuda.CUdeviceptr usable by cuda-python and remembers it.
+        For numpy.ndarray inputs, the host buffer is also retained for later synchronization.
+        """
+        if tensor is None:
+            return cuda.CUdeviceptr(0)
+        if isinstance(tensor, np.ndarray):
+            if is_output:
+                assert name
+            self.buffers[name] = NumpyFrontend.argument(tensor, is_output)
+            if is_output:
+                self.host_tensors[name] = tensor
+            return self.buffers[name].ptr
+        elif torch_available and isinstance(tensor, torch.Tensor):
+            return TorchFrontend.argument(tensor)
+        elif isinstance(tensor, cuda.CUdeviceptr):
+            return tensor
+        elif cupy_available and isinstance(tensor, cp.ndarray):
+            return CupyFrontend.argument(tensor)
+        else:
+            raise TypeError("Unsupported frontend: only numpy, torch, cupy, and cuda.CUdeviceptr inputs are supported")

@@ -109,11 +110,12 @@ class ArgumentBase:
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))

-        if hasattr(self, "host_D"):
+        for key in self.host_tensors.keys():
+            host_tensor = self.host_tensors[key]
            (err,) = cuda.cuMemcpyDtoH(
-                self.host_D,
-                self.ptr_D,
-                self.host_D.size * self.host_D.itemsize,
+                host_tensor,
+                self.buffers[key].ptr,
+                host_tensor.size * host_tensor.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

@@ -32,7 +32,6 @@

import ctypes

-import cutlass_bindings
from cutlass import (
    DataType,
    KernelScheduleType
@@ -47,9 +46,10 @@ class GemmCoord_(ctypes.Structure):
        ("k", ctypes.c_int)
    ]

-    def __init__(self, gemm_coord) -> None:
-        for field_name, _ in self._fields_:
-            setattr(self, field_name, getattr(gemm_coord, field_name)())
+    def __init__(self, m, n, k) -> None:
+        self.m = m
+        self.n = n
+        self.k = k
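
The new `__init__` takes plain integers rather than a coordinate object. A self-contained sketch of the same ctypes pattern, using a stand-alone copy of the structure rather than the real `GemmCoord_`:

```python
import ctypes

# Stand-alone copy of the GemmCoord_ pattern above: three c_int fields
# initialized directly from plain Python integers.
class GemmCoordSketch(ctypes.Structure):
    _fields_ = [
        ("m", ctypes.c_int),
        ("n", ctypes.c_int),
        ("k", ctypes.c_int),
    ]

    def __init__(self, m, n, k) -> None:
        self.m = m
        self.n = n
        self.k = k

coord = GemmCoordSketch(128, 256, 64)
```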


class GemmCoordBatched_(ctypes.Structure):
@@ -66,9 +66,10 @@ class GemmCoordBatched_(ctypes.Structure):
    ]

    def __init__(self, gemm_coord, batch_count) -> None:
-        for field_name, _ in self._fields_[:-1]:
-            setattr(self, field_name, getattr(gemm_coord, field_name)())
-        setattr(self, "batch_count", batch_count)
+        self.m = gemm_coord.m
+        self.n = gemm_coord.n
+        self.k = gemm_coord.k
+        self.batch_count = batch_count


class MatrixCoord_(ctypes.Structure):
@@ -98,14 +99,6 @@ class StrideBatched_(ctypes.Structure):
    ]


-dtype2ctype = {
-    cutlass_bindings.float16: ctypes.c_uint16,
-    cutlass_bindings.float32: ctypes.c_float,
-    cutlass_bindings.float64: ctypes.c_double,
-    cutlass_bindings.int32: ctypes.c_int32,
-}


class GenericMainloopArguments3x_(ctypes.Structure):
    """
    Structure representing the superset of possible mainloop arguments.
@@ -196,15 +189,28 @@ def get_mainloop_arguments_3x(

def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type
+    if hasattr(epilogue_functor, "visitor"):
+        class _EpilogueArguments(ctypes.Structure):
+            _fields_ = [
+                ("epilogue", _EpilogueOutputOpParams),
+                ("arg_C", epilogue_functor.arg_c_type),
+                ("arg_D", epilogue_functor.arg_d_type)
+            ]
-    class _EpilogueArguments(ctypes.Structure):
-        _fields_ = [
-            ("epilogue", _EpilogueOutputOpParams),
-            ("ptr_C", ctypes.c_void_p),
-            ("stride_C", StrideBatched_),
-            ("ptr_D", ctypes.c_void_p),
-            ("stride_D", StrideBatched_),
-        ]
+            def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None:
+                self.epilogue = output_op
+                self.arg_C = epilogue_functor.arg_c_type(ptr_c)
+                self.arg_D = epilogue_functor.arg_d_type(ptr_d)
+    else:
+
+        class _EpilogueArguments(ctypes.Structure):
+            _fields_ = [
+                ("epilogue", _EpilogueOutputOpParams),
+                ("ptr_C", ctypes.c_void_p),
+                ("stride_C", StrideBatched_),
+                ("ptr_D", ctypes.c_void_p),
+                ("stride_D", StrideBatched_),
+            ]

    class _HardwareInfo(ctypes.Structure):
        _fields_ = [
@@ -324,7 +330,7 @@ def get_gemm_grouped_arguments(epilogue_functor):
############################################################################################


-class Conv2DProblemSize(ctypes.Structure):
+class Conv2DProblemSize_(ctypes.Structure):
    _fields_ = [
        ("N", ctypes.c_int),
        ("H", ctypes.c_int),
@@ -382,11 +388,13 @@ def get_conv2d_arguments(epilogue_functor):

    class _Conv2dArguments(ctypes.Structure):
        _fields_ = [
-            ("problem_size", Conv2DProblemSize),
-            ("ref_A", TensorRef_),
-            ("ref_B", TensorRef_),
-            ("ref_C", TensorRef_),
-            ("ref_D", TensorRef_),
+            ("conv_kind", ctypes.c_int),
+            ("problem_size", Conv2DProblemSize_),
+            ("ptr_A", ctypes.c_void_p),
+            ("ptr_B", ctypes.c_void_p),
+            ("ptr_C", ctypes.c_void_p),
+            ("ptr_D", ctypes.c_void_p),
            ("tensor_C_numel", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("split_k_mode", ctypes.c_int)
        ]
@@ -414,3 +422,189 @@ def get_reduction_params(epilogue_functor):
    ]

    return _ReductionParams, _EpilogueOutputParams


+###########################################################################################
+# Epilogue Visitor Type Factory
+###########################################################################################
+
+class Empty(ctypes.Structure):
+    _fields_ = []
+
+    def __init__(self, *arg) -> None:
+        pass
+
+class EmptyByte(ctypes.Structure):
+    _fields_ = [
+        ("byte", ctypes.c_byte)
+    ]
+
+    def __init__(self, *arg) -> None:
+        pass
+
+class EBO:
+    def __init__(self, index: int, type) -> None:
+        self.index = index
+        self.type = type
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, EBO):
+            return self.index == other.index and self.type == other.type
+        return False
+
+    def __hash__(self) -> int:
+        return hash((self.index, self.type))
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __str__(self) -> str:
+        return f"<{self.index}, {self.type}>"
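
Because `EBO` defines `__eq__` and `__hash__` consistently over `(index, type)`, instances compare and deduplicate by value, which is what makes membership tests like `base in empty_bases` work below. A quick self-contained check, re-declaring a stripped-down copy of the class:

```python
# Minimal copy of the EBO value type above: equal (index, type) pairs hash
# equally, so set membership compares by value rather than object identity.
class EBO:
    def __init__(self, index, type) -> None:
        self.index = index
        self.type = type

    def __eq__(self, other):
        return isinstance(other, EBO) and (self.index, self.type) == (other.index, other.type)

    def __hash__(self):
        return hash((self.index, self.type))

bases = {EBO(0, 1), EBO(1, 0), EBO(0, 1)}  # the duplicate collapses
```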

+
+def tuple_factory_(input_tuple, dtype, constants=[0, 1]):
+    """
+    Factory function generating a ctypes representation of a cute::Tuple from the input tuple.
+
+    :param input_tuple: the input tuple
+    :type input_tuple: tuple
+    :param dtype: the ctypes data type for non-constant values
+    :param constants: the values that will be treated as constants
+    :type constants: list[int]
+
+    :return: ctype structure representing the cute::Tuple, and the empty base classes of the tuple
+    """
+
+    # The empty base classes of the current tuple
+    empty_bases = []
+    # The first non-empty base class
+    first_non_empty_base = None
+    # The ctype fields of the current tuple
+    ctype_fields = []
+
+    for idx, entry in enumerate(input_tuple):
+        # For nested tuples
+        if isinstance(entry, tuple):
+            sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
+            if ctypes.sizeof(sub_tuple_ctype) == 0:
+                # The empty tuple base class is also an empty EBO
+                empty_bases.append(EBO(idx, entry))
+            else:
+                if first_non_empty_base is None:
+                    first_non_empty_base = sub_empty_bases
+                ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
+        else:
+            if entry in constants:
+                empty_bases.append(EBO(idx, entry))
+                ctype_fields.append((f"entry_{idx}", Empty))
+            else:
+                ctype_fields.append((f"entry_{idx}", dtype))
+                if first_non_empty_base is None:
+                    first_non_empty_base = []
+
+    # Determine whether or not to add an additional byte for empty base classes
+    additional_byte = False
+    # Special case for constant tuple
+    if first_non_empty_base is None:
+        additional_byte = False
+    else:
+        for base in first_non_empty_base:
+            if base in empty_bases:
+                additional_byte = True
+                break
+
+    if additional_byte:
+        ctype_fields = [("empty_byte", EmptyByte), ] + ctype_fields
+
+    # Create the ctype tuple
+    class TupleType(ctypes.Structure):
+        _fields_ = ctype_fields
+
+        def __init__(self, args) -> None:
+            if additional_byte:
+                fields = self._fields_[1:]
+            else:
+                fields = self._fields_
+
+            assert len(fields) == len(args)
+            for field, arg in zip(fields, args):
+                name = field[0]
+                field_type = field[1]
+                setattr(self, name, field_type(arg))
+
+    return TupleType, empty_bases
+
+
+def tuple_factory(input_tuple, dtype: str, constants=[0, 1]):
+    """
+    Factory function generating a ctypes representation of a cute::Tuple from the input tuple.
+
+    :param input_tuple: the input tuple
+    :type input_tuple: tuple
+    :param dtype: the data type for non-constant values
+    :type dtype: str, "int32_t", "int", "int64_t"
+    :param constants: the values that will be treated as constants
+    :type constants: list[int]
+
+    :return: ctype structure representing the cute::Tuple
+    """
+    # Step 1: convert the dtype
+    if dtype == "int64_t":
+        dtype = ctypes.c_longlong
+    elif dtype in ["int", "int32_t"]:
+        dtype = ctypes.c_int32
+    else:
+        raise NotImplementedError(f"Type {dtype} is not supported")
+
+    tuple_type, _ = tuple_factory_(input_tuple, dtype, constants)
+
+    if ctypes.sizeof(tuple_type) == 0:
+        return EmptyByte
+    return tuple_type
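
The empty-base bookkeeping above exists because a fieldless ctypes structure occupies zero bytes, while the matching C++ cute::Tuple may need a one-byte placeholder for correct layout. A self-contained illustration of the size behavior, independent of the factory above:

```python
import ctypes

# A fieldless ctypes Structure occupies zero bytes, mirroring an
# empty-base-optimized tuple element on the C++ side.
class Empty(ctypes.Structure):
    _fields_ = []

class EmptyByte(ctypes.Structure):
    _fields_ = [("byte", ctypes.c_byte)]

# A struct holding only Empty fields is itself zero-sized; prepending
# a one-byte placeholder gives it a real size.
class ConstantsOnly(ctypes.Structure):
    _fields_ = [("entry_0", Empty), ("entry_1", Empty)]

class Padded(ctypes.Structure):
    _fields_ = [("empty_byte", EmptyByte), ("entry_0", Empty)]
```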

+
+def visitor_factory(node_types, node_names):
+    """
+    Creates the argument type of the epilogue visitor type.
+
+    :param node_types: list of argument types under ctypes
+    :param node_names: list of argument names under str
+
+    :return: tuple type in ctypes.Structure
+    """
+    ctypes_field = []
+    # A struct is used when the number of nodes is <= 4, because
+    # Sm90VisitorImplBase has specializations for up to 4 nodes in
+    # `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
+    if len(node_types) <= 4:
+        for idx, node_type in enumerate(node_types):
+            if ctypes.sizeof(node_type) == 0:
+                # Special case for an empty struct:
+                # a 1-byte placeholder is used for correct alignment
+                ctypes_field.append((node_names[idx], ctypes.c_byte))
+            else:
+                ctypes_field.append((node_names[idx], node_type))
+
+        class VisitorType(ctypes.Structure):
+            _fields_ = ctypes_field
+
+            def __init__(self, kwargs) -> None:
+                for field in self._fields_:
+                    fname, ftype = field
+                    if ftype != ctypes.c_byte:
+                        setattr(self, fname, ftype(kwargs))
+
+    # For cases with more than 4 nodes, a tuple is used
+    else:
+        for idx, node_type in enumerate(node_types):
+            ctypes_field.append((node_names[idx], node_type))
+
+        class VisitorType(ctypes.Structure):
+            _fields_ = ctypes_field
+
+            def __init__(self, kwargs) -> None:
+                for field in self._fields_:
+                    fname, ftype = field
+                    setattr(self, fname, ftype(kwargs))
+
+    return VisitorType

@@ -34,17 +34,16 @@ import ctypes
import json
import os
import sqlite3
+import subprocess
import tempfile

from cuda import cuda, nvrtc
-import cutlass_bindings

from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH, logger
from cutlass.backend.gemm_operation import GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.device import device_cc
from cutlass.backend.utils.software import SubstituteTemplate
-import subprocess

IncludeTemplate = r"""#include "${include}"
"""
@@ -157,8 +156,8 @@ class ArtifactManager:
            "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
        ]
        self.nvcc()
-        self.compiled_cache_device = cutlass_bindings.CompileCache()
-        self.compiled_cache_host = cutlass_bindings.CompileCache()
+        self.compiled_cache_device = {}
+        self.compiled_cache_host = {}

    def nvrtc(self):
        self.backend = "nvrtc"
@@ -197,7 +196,7 @@ class ArtifactManager:
            raise RuntimeError("Cuda Error: {}".format(err))

        err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
-        self.compiled_cache_device.insert(key, kernel)
+        self.compiled_cache_device[key] = kernel

        compiled_host_fns = {}
        host_lib = CDLLBin(host_binary)
@@ -222,7 +221,7 @@ class ArtifactManager:

        compiled_host_fns[attr] = func

-        self.compiled_cache_host.insert(key, compiled_host_fns)
+        self.compiled_cache_host[key] = compiled_host_fns
        return True
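
The switch from `cutlass_bindings.CompileCache` to plain dictionaries replaces the bindings' `.at()`/`.insert()` API with standard `dict.get()` and item assignment. A minimal sketch of the resulting pattern; the key and kernel values here are illustrative stand-ins, not the real compiled objects:

```python
# Illustrative dict-based compile cache, mirroring the pattern this diff
# introduces: dict.get() returns None on a miss, item assignment inserts.
compiled_cache_device = {}

def compile_kernel(key):
    # Stand-in for the actual NVRTC/NVCC compilation step.
    return f"compiled<{key}>"

key = "gemm_kernel_source" + "nvrtc"  # hypothetical cache key
kernel = compiled_cache_device.get(key)  # None on the first lookup
if kernel is None:
    kernel = compile_kernel(key)
    compiled_cache_device[key] = kernel
```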
|
||||
|
||||
def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
|
||||
@ -246,11 +245,10 @@ class ArtifactManager:
|
||||
)
|
||||
|
||||
for incl in includes_host:
|
||||
if "/device/" not in incl:
|
||||
source_buffer_host += SubstituteTemplate(
|
||||
IncludeTemplate,
|
||||
{"include": incl},
|
||||
)
|
||||
source_buffer_host += SubstituteTemplate(
|
||||
IncludeTemplate,
|
||||
{"include": incl},
|
||||
)
|
||||
|
||||
# 2. Operations
|
||||
for operation in operation_list:
|
||||
@ -382,16 +380,16 @@ class ArtifactManager:
|
||||
# step 1: get kernel string as key
|
||||
key = operation.rt_module.emit() + operation.procedural_name() + self.backend
|
||||
# step 1: check if the operation is in cache
|
||||
compiled_kernel = self.compiled_cache_device.at(key)
|
||||
compiled_kernel = self.compiled_cache_device.get(key)
|
||||
|
||||
if compiled_kernel is None and not bypass_cache:
|
||||
hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {}))
|
||||
if hit:
|
||||
compiled_kernel = self.compiled_cache_device.at(key)
|
||||
compiled_kernel = self.compiled_cache_device.get(key)
|
||||
assert compiled_kernel is not None
|
||||
if compiled_kernel is not None:
|
||||
operation.rt_module.kernel = compiled_kernel
|
||||
compiled_host_fns = self.compiled_cache_host.at(key)
|
||||
compiled_host_fns = self.compiled_cache_host.get(key)
|
||||
assert compiled_host_fns is not None
|
||||
for key in compiled_host_fns.keys():
|
||||
setattr(operation.rt_module, key, compiled_host_fns[key])
|
||||
@ -417,7 +415,7 @@ class ArtifactManager:
|
||||
bytes(str.encode(operation.name()))
|
||||
)
|
||||
operation_name.append(operation.name())
|
||||
self.compiled_cache_device.insert(key, operation.kernel)
|
||||
self.compiled_cache_device[key] = operation.kernel
|
||||
# get host functions
|
||||
compiled_host_fns = {}
|
||||
op_attr = []
|
||||
@ -456,7 +454,7 @@ class ArtifactManager:
|
||||
op_attr.append(suffix)
|
||||
|
||||
operation_attr.append(op_attr)
|
||||
self.compiled_cache_host.insert(key, compiled_host_fns)
|
||||
self.compiled_cache_host[key] = compiled_host_fns
|
||||
|
||||
for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr):
|
||||
self.insert_operation(
|
||||
|
||||
@ -29,19 +29,14 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
# from typeguard import typechecked
|
||||
|
||||
import ctypes
|
||||
from typing import Union
|
||||
|
||||
from cuda import cuda
|
||||
import cutlass_bindings
|
||||
import numpy as np
|
||||
|
||||
from cutlass.backend.arguments import ArgumentBase
|
||||
from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments
|
||||
from cutlass.backend.library import (
|
||||
EmissionType,
|
||||
from cutlass import (
|
||||
ConvKindNames,
|
||||
ConvKindTag,
|
||||
DataTypeNames,
|
||||
@ -50,29 +45,40 @@ from cutlass.backend.library import (
|
||||
IteratorAlgorithmNames,
|
||||
IteratorAlgorithmTag,
|
||||
LayoutTag,
|
||||
LayoutType,
|
||||
MathOperation,
|
||||
MathOperationTag,
|
||||
OpcodeClass,
|
||||
OpcodeClassNames,
|
||||
OpcodeClassTag,
|
||||
OperationKind,
|
||||
ShortDataTypeNames,
|
||||
ShortLayoutTypeNames,
|
||||
SplitKMode,
    StrideSupport,
    StrideSupportTag,
    SwizzlingFunctor,
    SwizzlingFunctorTag,
    get_complex_from_real,
)

from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import dim3_, get_conv2d_arguments
from cutlass.backend.library import (
    EmissionType,
    TensorDescription,
    TileDescription,
    get_complex_from_real,
)
from cutlass.backend.memory_manager import device_mem_alloc
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.tensor_ref import TensorRef
from cutlass.backend.utils.datatypes import to_device_ptr
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.shape import GemmCoord

if CheckPackages().check_torch():
    import torch


# @typechecked
class Conv2dArguments(ArgumentBase):
    """
    Argument wrapper for Conv2d. It encodes problem information and
@@ -81,7 +87,7 @@ class Conv2dArguments(ArgumentBase):
    :param operation: the Conv2d operation to take the argument
    :type operation: :class:`cutlass.backend.Conv2dOperation`
    :param problem_size: the Conv2d problem size
    :type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize`
    :type problem_size: :class:`cutlass.shape.Conv2dProblemSize`
    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param B: tensor B
@@ -90,135 +96,70 @@ class Conv2dArguments(ArgumentBase):
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
    :param split_k_mode: conv2d split K mode, defaults to cutlass_bindings.conv.SplitKMode.Serial
    :type split_k_mode: cutlass_bindings.conv.SplitKMode, optional
    :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
    :type split_k_mode: cutlass_library.library.SplitKMode, optional
    :param output_op: output operator, optional
    :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
    """

    def __init__(
        self,
        operation: "Conv2dOperation",
        problem_size: "cutlass_bindings.conv.Conv2dProblemSize",
        A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        split_k_mode: "cutlass_bindings.conv.SplitKMode" = cutlass_bindings.conv.SplitKMode.Serial,
        **kwargs,
    ) -> None:
    def __init__(self, operation, problem_size, A, B, C, D,
                 split_k_mode=SplitKMode.Serial, **kwargs) -> None:
        self.operation = operation
        #: convolution kind
        self.conv_kind: cutlass_bindings.conv.Operator = operation.conv_kind
        self.layout_A: cutlass_bindings.layout = operation.A.layout
        self.layout_B: cutlass_bindings.layout = operation.B.layout
        self.layout_C: cutlass_bindings.layout = operation.C.layout
        self.conv_kind = operation.conv_kind
        self.layout_A = operation.A.layout
        self.layout_B = operation.B.layout
        self.layout_C = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        if self.layout_C == cutlass_bindings.TensorNC32HW32:
            B = self.reorder_tensor_B(B, problem_size)
        if self.layout_C == LayoutType.TensorNC32HW32:
            raise Exception("Layout type TensorNC32HW32 is not currently supported")

        super().__init__(A, B, C, D, **kwargs)
        # preprocessing output ops

        if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
            self.split_k_mode = split_k_mode
            self.split_k_slices = kwargs["split_k_slices"]
        else:
            self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial
            self.split_k_mode = SplitKMode.Serial
            self.split_k_slices = 1

        if "output_op" in kwargs.keys() and self.split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
        if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel:
            self.output_op = kwargs["output_op"]
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        #: problem_size
        self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size
        self.problem_size = problem_size
        self.problem_size.split_k_slices = self.split_k_slices

        if hasattr(self, "tensor_c_numel"):
            c_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
                self.conv_kind, problem_size)
            if self.tensor_c_numel == c_coord.at(3) and self.tensor_c_numel < c_coord.size():
                self.bias = True

        #
        # initialize the argument
        #
        self.initialize()
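The bias check above treats tensor C as a per-channel bias when it holds exactly `c_coord.at(3)` elements (one per output channel) while being smaller than the full output tensor. A standalone sketch of that rule with plain tuples; `is_bias_tensor` and its arguments are illustrative names, not part of the CUTLASS API:

```python
import math

def is_bias_tensor(tensor_c_numel: int, c_extent: tuple) -> bool:
    # C is a per-channel bias when it holds exactly one value per output
    # channel (the last extent) but fewer than the full tensor's elements.
    full_size = math.prod(c_extent)
    return tensor_c_numel == c_extent[-1] and tensor_c_numel < full_size

# For an NHWC output extent of (2, 7, 7, 64): a 64-element C is a bias,
# a full 6272-element C is not.
print(is_bias_tensor(64, (2, 7, 7, 64)))    # True
print(is_bias_tensor(6272, (2, 7, 7, 64)))  # False
```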

    # @typechecked
    def reorder_tensor_B(self, tensor_B: "np.ndarray",
                         problem_size: "cutlass_bindings.conv.Conv2dProblemSize"):
        """
        Reorder tensor_B for interleaved layout

        :param tensor_B: input tensor B
        :type tensor_B: numpy.ndarray
        :param problem_size: Conv2d problem size
        :type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize`

        :return: reordered tensor B
        :rtype: numpy.ndarray
        """
        reordered_tensor_B = np.empty_like(tensor_B)
        tensor_ref_B = self.get_tensor_ref(
            tensor_B, self.element_B, self.layout_B, problem_size, "b")
        reordered_tensor_ref_B = self.get_tensor_ref(
            reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b")
        cutlass_bindings.conv.host.reorder_convK(
            reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size)

        return reordered_tensor_B

    def get_tensor_ref(
            self, tensor, dtype, tensor_layout, problem_size, operand):
        if operand == "a":
            tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
                self.conv_kind, problem_size)
        elif operand == "b":
            tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
                self.conv_kind, problem_size)
        elif operand in ["c", "d"]:
            tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
                self.conv_kind, problem_size)
        else:
            raise ValueError("unknown operand: " + operand)
        # Zero stride trick
        if operand == "c" and self.bias:
            tensor_coord = cutlass_bindings.Tensor4DCoord(0, 0, 0, 0)

        layout = tensor_layout.packed(tensor_coord)

        return TensorRef(tensor, dtype, layout).tensor_ref

    def get_arguments(self, semaphore):
        ref_A = TensorRef_(self.get_tensor_ref(
            self.ptr_A, self.element_A, self.layout_A, self.problem_size, "a"))
        ref_B = TensorRef_(self.get_tensor_ref(
            self.ptr_B, self.element_B, self.layout_B, self.problem_size, "b"))
        ref_C = TensorRef_(self.get_tensor_ref(
            self.ptr_C, self.element_C, self.layout_C, self.problem_size, "c"))
        ref_D = TensorRef_(self.get_tensor_ref(
            self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))
    def get_arguments(self):
        tc_numel = -1
        if hasattr(self, "tensor_c_numel"):
            tc_numel = self.tensor_c_numel

        self.c_arguments = self.operation.argument_type(
            Conv2DProblemSize(self.problem_size),
            ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode)

        self.semaphore = semaphore
            int(self.conv_kind),
            self.problem_size.ctype,
            int(to_device_ptr(self.ptr_A)),
            int(to_device_ptr(self.ptr_B)),
            int(to_device_ptr(self.ptr_C)),
            int(to_device_ptr(self.ptr_D)),
            tc_numel,
            self.output_op,
            int(self.split_k_mode)
        )

    def initialize(self):
        # Get launch configuration
        self.launch_config = self.operation.rt_module.plan(self)

        # Allocate and initialize device workspace
        device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
        self.get_arguments()

        # Allocate and initialize device workspace
        device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments)
        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
@@ -227,19 +168,16 @@ class Conv2dArguments(ArgumentBase):
        else:
            workspace_ptr = None

        # Get kernel params as a bytearray
        semaphore = 0
        if (workspace_ptr is not None
                and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel):
        self.semaphore = 0
        if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel:
            self.ptr_D = workspace_ptr
        elif (workspace_ptr is not None
                and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial):
            semaphore = workspace_ptr

        self.get_arguments(semaphore)
            # Reset arguments now that ptr_D has been updated
            self.get_arguments()
        elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial:
            self.semaphore = workspace_ptr

        params_ = self.operation.rt_module.get_args(
            ctypes.byref(self.c_arguments), ctypes.c_void_p(int(self.semaphore)))
            self.c_arguments, ctypes.c_void_p(int(self.semaphore)))
        self.host_workspace = bytearray(params_.contents)
        self.device_workspace = None

@@ -251,7 +189,6 @@ class Conv2dArguments(ArgumentBase):
        return super().sync()


# @typechecked
class Conv2dRT(ExecutableOperation):
    """
    Conv2dRT manages the CUTLASS runtime components
@@ -287,24 +224,104 @@ extern "C" {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  using ElementA = typename ${operation_name}_base::ElementA;
  using ElementB = typename ${operation_name}_base::ElementB;
  using ElementC = typename ${operation_name}_base::ElementC;
  using LayoutA = typename ${operation_name}_base::LayoutA;
  using LayoutB = typename ${operation_name}_base::LayoutB;
  using LayoutC = typename ${operation_name}_base::LayoutC;
  using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp;

  struct ${operation_name}_TemporaryArgs {
    int conv_kind;
    cutlass::conv::Conv2dProblemSize problem_size;
    ElementA* ptr_A;
    ElementB* ptr_B;
    ElementC* ptr_C;
    ElementC* ptr_D;
    int tensor_c_numel;
    typename EpilogueOutputOp::Params epilogue_params;
    int split_k_mode;
  };

  typename ${operation_name}${operation_suffix}::Arguments
  construct_arguments(${operation_name}_TemporaryArgs args) {
    cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
    auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
    auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
    auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
    auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);

    auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
    if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
      // C is interpreted as bias
      tc_C = {0, 0, 0, 0};
    }

    cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
    cutlass::TensorRef<ElementB, LayoutB> tref_B(args.ptr_B, LayoutB::packed(tc_B));
    cutlass::TensorRef<ElementC, LayoutC> tref_C(args.ptr_C, LayoutC::packed(tc_C));
    cutlass::TensorRef<ElementC, LayoutC> tref_D(args.ptr_D, LayoutC::packed(tc_D));

    return {
      args.problem_size,
      tref_A,
      tref_B,
      tref_C,
      tref_D,
      args.epilogue_params,
      static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
    };
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Arguments* arguments, int *semaphore=nullptr){
  char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) {
    auto arguments = construct_arguments(args);
    typename ${operation_name}${operation_suffix}::Params* params;
    params = new ${operation_name}${operation_suffix}::Params(*arguments, semaphore);
    params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore);

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
      output[i] = bytes[i];

    return output;
  }

  dim3 ${operation_name}_get_grid_shape(
    int conv_kind,
    cutlass::conv::Conv2dProblemSize problem_size,
    cutlass::gemm::GemmCoord tile_size,
    int split_k_slices
  ) {
    using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle;
    auto tiled_shape = Swizzle::get_tiled_shape(
      static_cast<cutlass::conv::Operator>(conv_kind),
      problem_size,
      tile_size,
      split_k_slices);

    return Swizzle::get_grid_shape(tiled_shape);
  }

  size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) {
    auto arguments = construct_arguments(args);

    // Temporarily define a device-level Conv2d so that we can call get_workspace_size
    using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>;
    return DeviceConv::get_workspace_size(arguments);
  }
}

"""

    def __init__(self, operation: "Conv2dOperation"):
        super().__init__(operation)
        self.extra_funcs = {
            "get_grid_shape": dim3_,
            "get_workspace_size": ctypes.c_uint64
        }
        self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
        self.conv_kind = operation.conv_kind
@@ -313,47 +330,27 @@ extern "C" {

        self.emitter = EmitConv2dInstance("_type")

        self.threads: int = operation.tile_description.num_threads
        self.threads = operation.tile_description.num_threads

        self.swizzle_functor = operation.swizzling_functor

    def emit(self):
        return self.emitter.emit(self.operation)

    def get_device_workspace_size(self, arguments: Conv2dArguments):
        workspace_bytes = 0

        launch_config = arguments.launch_config

        self.conv_kind = self.operation.conv_kind

        if arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
            problem_size = arguments.problem_size
            workspace_bytes = DataTypeSize[self.operation.C.element] \
                * launch_config.grid[2] * cutlass_bindings.conv.implicit_gemm_tensor_c_size(
                    self.conv_kind, problem_size
                ) // 8

        elif arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial and \
                arguments.split_k_slices > 1:
            workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4

        return workspace_bytes
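The removed `get_device_workspace_size` above follows a simple rule: parallel split-K stores one partial output tensor per grid-z slice, while serial split-K only needs a 4-byte semaphore per threadblock tile. A standalone sketch of that arithmetic; `split_k_workspace_bytes` and its arguments are illustrative names, not CUTLASS API:

```python
def split_k_workspace_bytes(mode: str, grid: tuple, tensor_c_size: int,
                            element_bits: int, split_k_slices: int) -> int:
    # Parallel split-K: one partial output tensor per grid-z slice.
    if mode == "parallel":
        return element_bits * grid[2] * tensor_c_size // 8
    # Serial split-K: a 4-byte semaphore per threadblock tile.
    if mode == "serial" and split_k_slices > 1:
        return grid[0] * grid[1] * 4
    return 0

# FP16 output (16 bits), grid (4, 2, 3), 6272-element C tensor, 3 slices:
print(split_k_workspace_bytes("parallel", (4, 2, 3), 6272, 16, 3))  # 37632
print(split_k_workspace_bytes("serial", (4, 2, 3), 6272, 16, 3))    # 32
```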

    # @typechecked
    def plan(self, arguments: Conv2dArguments):
        tile_size = cutlass_bindings.gemm.GemmCoord(
        tile_size = GemmCoord(
            self.operation.tile_description.threadblock_shape[0],
            self.operation.tile_description.threadblock_shape[1],
            self.operation.tile_description.threadblock_shape[2],
        )

        grid = self.swizzle_functor.get_grid_shape(
            self.swizzle_functor.get_tiled_shape(
                self.conv_kind, arguments.problem_size,
                tile_size, arguments.split_k_slices
            )
        grid = self.get_grid_shape(
            int(self.conv_kind),
            arguments.problem_size.ctype,
            tile_size.ctype,
            arguments.split_k_slices
        )

        return LaunchConfiguration(
            [grid.x, grid.y, grid.z], [self.threads, 1, 1],
            self.shared_memory_capacity)
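For the common identity swizzle, the grid returned by `get_grid_shape` is just the ceiling division of the implicit-GEMM problem extents by the threadblock tile, with split-K slices in z. A hedged sketch of that mapping; `grid_shape` is an illustrative helper, not the CUTLASS swizzle:

```python
def grid_shape(problem_mnk: tuple, tile_mnk: tuple, split_k_slices: int) -> tuple:
    # Ceiling-divide the GEMM M/N extents by the tile shape;
    # the z dimension carries the split-K slices.
    m, n, _ = problem_mnk
    tm, tn, _ = tile_mnk
    return ((m + tm - 1) // tm, (n + tn - 1) // tn, split_k_slices)

# A 1024x512 implicit GEMM with 128x128 tiles and 2 split-K slices:
print(grid_shape((1024, 512, 64), (128, 128, 32), 2))  # (8, 4, 2)
```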
@@ -364,7 +361,7 @@ extern "C" {
            attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            value=self.shared_memory_capacity)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
            raise RuntimeError(f"CUDA Error: {err}")


class Conv2dOperation:
@@ -372,11 +369,11 @@ class Conv2dOperation:
    CUTLASS Conv2d operation description.

    :param conv_kind: convolution operator
    :type conv_kind: :class:`cutlass_bindings.conv.Operator`
    :type conv_kind: :class:`cutlass_library.library.ConvKind`

    :param iterator_algorithm: Selects among several implementation
        variants trading off performance with simplicity
    :type iterator_algorithm: :class:`cutlass_bindings.conv.IteratorAlgorithm`
    :type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm`

    :param arch: GPU compute capability (sm_xx)
    :type arch: int
@@ -397,12 +394,11 @@ class Conv2dOperation:
    :type D: :class:`cutlass.backend.TensorDescription`

    :param element_epilogue: element type for computation in epilogue \
    :type element_epilogue: cutlass_bindings.int8 | cutlass_bindings.int32 | cutlass_bindings.float16 | \
        cutlass_bindings.bfloat16 | cutlass_bindings.float32 | cutlass_bindings.float64
    :type element_epilogue: cutlass_library.library.DataType

    :param stride_support: distinguish among partial specializations that \
        accelerate certain problems where convolution stride is unit \
    :type stride_support: :class:`cutlass_bindings.conv.StrideSupport`
    :type stride_support: :class:`cutlass_library.library.StrideSupport`

    :param epilogue_functor: convolution epilogue functor
    :type epilogue_functor: :class:`EpilogueFunctor`
@@ -411,8 +407,8 @@ class Conv2dOperation:
    """
    def __init__(
        self,
        conv_kind: cutlass_bindings.conv.Operator,
        iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm,
        conv_kind,
        iterator_algorithm,
        arch: int,
        tile_description: TileDescription,
        A: TensorDescription,
@@ -420,7 +416,7 @@ class Conv2dOperation:
        C: TensorDescription,
        stride_support,
        epilogue_functor,
        swizzling_functor=cutlass_bindings.IdentitySwizzle1,
        swizzling_functor=SwizzlingFunctor.Identity1,
        emission_type=EmissionType.Kernel,
        **kwargs
    ):
@@ -434,8 +430,8 @@ class Conv2dOperation:
        self.epilogue_functor = epilogue_functor
        self.iterator_algorithm = iterator_algorithm
        self.stride_support = stride_support
        self.swizzling_functor = swizzling_functor()
        self.swizzling_functor = swizzling_functor

        self.emission_type = emission_type

        self.rt_module: Conv2dRT = Conv2dRT(self)
@@ -458,7 +454,7 @@ class Conv2dOperation:
        )

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))
            raise RuntimeError(f"CUDA Error {err}")

        return err

@@ -470,8 +466,6 @@ class Conv2dOperation:
        """The full procedural name indicates architecture, extended name, tile size, and layout."""
        return self.configuration_name()

    #

    def configuration_name(self):
        """The full procedural name indicates architecture, extended name, tile size, and layout."""

@@ -503,7 +497,6 @@ class Conv2dOperation:
            },
        )

    #
    def extended_name(self):
        """Append data types if they differ from compute type."""
        if self.C.element != self.tile_description.math_instruction.element_accumulator and \
@@ -523,17 +516,15 @@ class Conv2dOperation:

        return extended_name

    #
    def layout_name(self):
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    #
    def core_name(self):
        """The basic operation kind is prefixed with a letter indicating the accumulation type."""

        intermediate_type = ""

        if self.tile_description.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp:
        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%dx%dx%d" % tuple(
                self.tile_description.math_instruction.instruction_shape)
            if self.tile_description.math_instruction.element_a != self.A.element and \
@@ -550,7 +541,6 @@ class Conv2dOperation:
            IteratorAlgorithmNames[self.iterator_algorithm]
        )

    #
    def is_complex(self):
        complex_operators = [
            MathOperation.multiply_add_complex,
@@ -558,7 +548,6 @@ class Conv2dOperation:
        ]
        return self.tile_description.math_instruction.math_operation in complex_operators

    #
    def accumulator_type(self):
        accum = self.tile_description.math_instruction.element_accumulator

@@ -570,16 +559,17 @@ class Conv2dOperation:
    def device_op(self):
        """
        Returns a new Conv2dOperation object that is constructed with emission type
        ``EmissionType.Device``.

        :return: operation ready for device-level code emission
        :rtype: Conv2dOperation
        """
        return Conv2dOperation(
            self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description,
            self.A, self.B, self.C, self.stride_support, self.epilogue_functor, type(self.swizzling_functor),
            self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor,
            emission_type=EmissionType.Device)


###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
@@ -594,17 +584,18 @@ class EmitConv2dInstance:
            "cutlass/cutlass.h",
            "cutlass/conv/kernel/default_conv2d_fprop.h",
            "cutlass/conv/kernel/default_conv2d_dgrad.h",
            "cutlass/conv/kernel/default_conv2d_wgrad.h"
            "cutlass/conv/kernel/default_conv2d_wgrad.h",
            "cutlass/conv/device/implicit_gemm_convolution.h"
        ]
        self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
  ${element_a},
  ${layout_a},
  ${element_b},
  ${layout_b},
  ${element_c},
  ${layout_c},
  ${element_accumulator},
  ${opcode_class},
@@ -631,11 +622,11 @@ struct ${operation_name}${operation_suffix}:
// Conv2d operation ${operation_name}

using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
  ${element_a},
  ${layout_a},
  ${element_b},
  ${layout_b},
  ${element_c},
  ${layout_c},
  ${element_accumulator},
  ${opcode_class},
@@ -689,7 +680,7 @@ using DeviceKernel =
            "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "swizzling_functor": operation.swizzling_functor.tag(),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
            "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
@@ -698,7 +689,7 @@ using DeviceKernel =
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
        }

        if operation.emission_type == EmissionType.Kernel:
            conv2d_template = self.template
        else:
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -30,7 +30,5 @@
#
################################################################################

from cutlass.backend.test.conv2d_testbed import *
from cutlass.backend.test.gemm_grouped_testbed import *
from cutlass.backend.test.gemm_testbed import *
from cutlass.backend.test.profiler import *
from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor
from cutlass.backend.evt.frontend import PythonASTFrontend
python/cutlass/backend/evt/backend/__init__.py (new file, 36 lines)
@@ -0,0 +1,36 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from cutlass.backend.evt.backend.sm80_emitter import Sm80Emitter
import cutlass.backend.evt.backend.sm80_nodes as sm80_nodes
from cutlass.backend.evt.backend.sm90_emitter import Sm90Emitter
import cutlass.backend.evt.backend.sm90_nodes as sm90_nodes
python/cutlass/backend/evt/backend/emitter_base.py (new file, 158 lines)
@@ -0,0 +1,158 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (BSD-3-Clause license text as above)
#
#################################################################################################

"""
Base class for Epilogue Visitor Emitter
"""

from cutlass import DataTypeTag
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR


class FusionCallbacks:
    def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
        """
        Emit the EVT fusion callbacks
        :param dag_ir: the DAG IR holding the epilogue visitor
        :param cc: compute capability
        :param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks.
            For Sm90, set emit_CD=False, as tensors C & D are hardcoded in the collective API
            so that their shared memory can be explicitly reused.
            For Sm89, set emit_CD=True, as they are treated as normal AuxLoad & AuxStore nodes.
        """
        self.dag_ir = dag_ir
        self.emit_CD = emit_CD
        self.cc = cc
        if self.cc < 90:
            self.namespace = "threadblock"
        else:
            self.namespace = "fusion"

    #
    # Helper functions
    #

    def get_visitor_name(self, node: str):
        """
        Get the visitor name
        """
        meta = self.dag_ir.get_node_meta(node)
        if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
            return f"EVT{meta.name_camel}"
        else:
            return meta.name_camel

    def emit(self):
        node_metas = self.dag_ir.node_metas_topological_order()
        epilogue_str = ""
        # Step 1: emit individual node type decls,
        # then the EVT & DAG connectors
        for meta in node_metas:
            if not meta.disabled:
                epilogue_str += self.emit_node(meta)
            if not self.emit_CD and meta.name == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                epilogue_str += self.emit_dag(meta)
            else:
                epilogue_str += self.emit_evt(meta)

        # Step 2: post-processing & get callback name
        if not self.emit_CD:
            if not self.dag_ir.has_node("C"):
                epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
            output_node = self.dag_ir.get_all_inputs("D")[0]
            # The callback is the src of node D
            callback_name = self.get_visitor_name(output_node)
        else:
            # The callback is the last node in the topological order
            callback_name = self.get_visitor_name(node_metas[-1].name)
        return epilogue_str, callback_name
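The callback-selection rule in `emit` above (use the source of `D` when C/D are owned by the collective, otherwise the last node in topological order) can be illustrated on a toy DAG. All names here (`pick_callback`, the dict-based graph) are illustrative, not the CUTLASS `DAGIR` API:

```python
def pick_callback(topo_order: list, inputs: dict, emit_CD: bool) -> str:
    # inputs maps each node name to the list of node names feeding it.
    if not emit_CD:
        # Tensors C & D are handled by the collective: the callback
        # is the node that feeds the store node "D".
        return inputs["D"][0]
    # Otherwise the callback is the last node in topological order.
    return topo_order[-1]

topo = ["accum", "alpha", "mult", "D"]
inputs = {"mult": ["accum", "alpha"], "D": ["mult"]}
print(pick_callback(topo, inputs, emit_CD=False))  # mult
print(pick_callback(topo, inputs, emit_CD=True))   # D
```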
|
||||
|
||||
def emit_evt(self, node):
|
||||
if self.dag_ir.in_degree(node.name) == 0:
|
||||
return ""
|
||||
|
||||
evt_tmp = f"""
|
||||
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}EVT<
|
||||
{node.name_camel},
|
||||
"""
|
||||
sorted_children = self.dag_ir.get_all_inputs(node.name)
|
||||
evt_node_strs = [f" {self.get_visitor_name(child_name)}" for child_name in sorted_children]
|
||||
evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
|
||||
|
||||
return evt_tmp
|
||||
|
||||
def emit_dag(self, node):
|
||||
subgraph = node.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Emit the Edge Tuple
|
||||
edge_tuples = "cute::tuple<\n"
|
||||
for n in subgraph_nodes[:-1]:
|
||||
in_edges = subgraph.in_edges(n)
|
||||
edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
|
||||
sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
|
||||
edge_tuple = " cute::seq<"
|
||||
edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
|
||||
edge_tuple += ", ".join(edge_str) + ">,\n"
|
||||
|
||||
edge_tuples += edge_tuple
|
||||
edge_tuples += " >"
|
||||
|
||||
# Emit the node list
|
||||
dag_nodes = ""
|
||||
dag_node_strs = []
|
||||
for n in subgraph_nodes[:-1]:
|
||||
n_meta = subgraph.get_node_meta(n)
|
||||
if n_meta.disabled:
|
||||
dag_node_strs.append(f" {self.get_visitor_name(n)}")
|
||||
else:
|
||||
dag_node_strs.append(f" {n_meta.name_camel}")
|
||||
dag_nodes = ",\n".join(dag_node_strs)
|
||||
|
||||
return f"""
|
||||
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}TopologicalVisitor<
|
||||
{DataTypeTag[node.subgraph.element_compute]},
|
||||
{edge_tuples},
|
||||
{dag_nodes}
|
||||
>;
|
||||
"""
|
||||
|
||||
def emit_node(self, node):
|
||||
if isinstance(node, TopoVisitorNode):
|
||||
emission = ""
|
||||
for node in node.subgraph.node_metas_topological_order():
|
||||
if not node.disabled:
|
||||
emission += self.emit_node(node)
|
||||
return emission
|
||||
else:
|
||||
return node.underlying_impl.type_decl
|
||||
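The `emit_evt` method above assembles a nested C++ type alias by interpolating a node's name and its children into an `SmXXEVT<...>` template. A standalone sketch of the same string assembly (the node name `Compute0` and children `Accum`/`C` are hypothetical examples, not part of the diff):

```python
# Standalone sketch of the string assembly performed by emit_evt.
# "Compute0", "fusion", and the children list are hypothetical inputs.
def sketch_emit_evt(name_camel, namespace, cc, children):
    evt_tmp = f"""
using EVT{name_camel} = cutlass::epilogue::{namespace}::Sm{cc}EVT<
    {name_camel},
"""
    # One indented line per child visitor, joined with commas
    evt_node_strs = [f"    {child}" for child in children]
    return evt_tmp + ",\n".join(evt_node_strs) + ">;\n"

decl = sketch_emit_evt("Compute0", "fusion", 90, ["Accum", "C"])
print(decl)
```

The real method resolves each child through `get_visitor_name`, so an interior node with inputs is itself rendered as `EVT<name>`, producing the nested tree of visitors.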
47  python/cutlass/backend/evt/backend/sm80_emitter.py  Normal file
@ -0,0 +1,47 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Emitter for Sm80 Epilogue Visitor
"""

from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass.backend import GemmOperationUniversal


class Sm80Emitter:
    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        self.fusion_callbacks = FusionCallbacks(graph, cc=80)

    def emit(self):
        callback_decl, callback_name = self.fusion_callbacks.emit()
        return callback_name, callback_decl
258  python/cutlass/backend/evt/backend/sm80_nodes.py  Normal file
@ -0,0 +1,258 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from cutlass import DataTypeTag

from cutlass.backend.evt.ir import (
    # Load Node
    AccumulatorImpl,
    AuxLoadImpl,
    ColumnBroadcastImpl,
    LoadNode,
    LoadSrcImpl,
    RowBroadcastImpl,
    ScalarBroadcastImpl,
    # Compute Node
    ComputeImpl,
    # Store Node
    AuxStoreImpl,
    ColumnReductionImpl,
    RowReductionImpl,
    ScalarReductionImpl
)

from cutlass.backend.library import (
    FloatRoundStyleTag,
    FunctionalOp,
    op_tag,
)


class Sm80AccumulatorImpl(AccumulatorImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
        return self._type_decl


class Sm80AuxLoadImpl(AuxLoadImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
    pass


class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl


class Sm80RowBroadcastImpl(RowBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ComputeImpl(ComputeImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl


class Sm80AuxStoreImpl(AuxStoreImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80StoreDImpl(Sm80AuxStoreImpl):
    pass


class Sm80ColumnReductionImpl(ColumnReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80RowReductionImpl(RowReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ScalarReductionImpl(ScalarReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
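Every `Sm80*Impl` class above shares the same memoization pattern: the `type_decl` property renders the C++ `using` alias once, caches it in `_type_decl`, and returns the cached string on later accesses. A minimal sketch of that pattern (the `_SketchNode` class and `Accum` name are hypothetical, not from the diff):

```python
# Minimal sketch of the type_decl caching pattern used by the Sm80*Impl
# classes. _SketchNode is a hypothetical stand-in for a node impl.
class _SketchNode:
    def __init__(self, name_camel: str) -> None:
        self.name_camel = name_camel
        self._type_decl = None  # cache, filled on first access

    @property
    def type_decl(self):
        # Return the cached declaration if it was already rendered
        if self._type_decl is not None:
            return self._type_decl
        # Otherwise render the C++ alias once and cache it
        self._type_decl = f"\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"
        return self._type_decl

node = _SketchNode("Accum")
first = node.type_decl
second = node.type_decl  # served from the cache, same object
```

Caching matters here because `type_decl` is read both when the emitter concatenates declarations and when parent EVT nodes reference child names.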
98  python/cutlass/backend/evt/backend/sm90_emitter.py  Normal file
@ -0,0 +1,98 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
|
||||
Emitter for Sm90 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass import DataTypeTag, EpilogueScheduleTag
|
||||
from cutlass.backend import GemmOperationUniversal
|
||||
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
class CollectiveEpilogue:
|
||||
def __init__(self, tile_description,
|
||||
schedule,
|
||||
element_c,
|
||||
element_d,
|
||||
fusion_callbacks) -> None:
|
||||
|
||||
self.cta_tile_mnk = tile_description.threadblock_shape
|
||||
self.element_c = element_c
|
||||
self.element_d = element_d
|
||||
self.schedule = schedule
|
||||
self.fusion_callbacks = fusion_callbacks
|
||||
|
||||
@property
|
||||
def CtaTileMNK(self) -> str:
|
||||
"""
|
||||
The threadblock shape
|
||||
"""
|
||||
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
||||
|
||||
@property
|
||||
def EpilogueTileType(self) -> str:
|
||||
"""
|
||||
The epilogue tile type
|
||||
"""
|
||||
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
||||
|
||||
@property
|
||||
def Schedule(self) -> str:
|
||||
return EpilogueScheduleTag[self.schedule]
|
||||
|
||||
def emit(self):
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return callback_name, f"""
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
{self.CtaTileMNK}, {self.EpilogueTileType},
|
||||
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
||||
{self.Schedule}
|
||||
>;
|
||||
{callback_decl}
|
||||
"""
|
||||
|
||||
|
||||
class Sm90Emitter:
|
||||
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
||||
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
|
||||
|
||||
self.collective_epilogue = CollectiveEpilogue(
|
||||
tile_description=operation.tile_description,
|
||||
schedule=operation.tile_description.epilogue_schedule,
|
||||
element_c=operation.C.element,
|
||||
element_d=operation.C.element,
|
||||
fusion_callbacks=fusion_callbacks
|
||||
)
|
||||
|
||||
def emit(self):
|
||||
return self.collective_epilogue.emit()
|
||||
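The `CtaTileMNK` property above renders the threadblock shape list as a `cute::Shape` of static integers, which then lands in the `EpilogueDescriptor` alias. A one-liner sketch of that rendering (the `(128, 128, 64)` tile is a hypothetical example, not from the diff):

```python
# Sketch of the CtaTileMNK string rendering in CollectiveEpilogue.
# The tile extents below are hypothetical example values.
cta_tile_mnk = [128, 128, 64]
shape = f"cute::Shape<_{cta_tile_mnk[0]}, _{cta_tile_mnk[1]}, _{cta_tile_mnk[2]}>"
print(shape)
```

The leading underscore maps each Python integer onto the corresponding CuTe static-integer alias (e.g. `_128`), so the emitted C++ carries the tile extents as compile-time constants.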
351  python/cutlass/backend/evt/backend/sm90_nodes.py  Normal file
@ -0,0 +1,351 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from pycute import product

from cutlass import DataTypeSize, DataTypeTag
from cutlass.backend.evt.ir import (
    # Load Node
    AccumulatorImpl,
    AuxLoadImpl,
    ColumnBroadcastImpl,
    LoadNode,
    LoadSrcImpl,
    RowBroadcastImpl,
    ScalarBroadcastImpl,
    # Compute Node
    ComputeImpl,
    ComputeNode,
    # Store Node
    AuxStoreImpl,
    ColumnReductionImpl,
    RowReductionImpl,
    ScalarReductionImpl,
    StoreNode,
    StoreDImpl,
)
from cutlass.backend.library import (
    FloatRoundStyleTag,
    FunctionalOp,
    op_tag,
)


class Sm90AccumulatorImpl(AccumulatorImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
        return self._type_decl


class Sm90LoadSrcImpl(LoadSrcImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch;
"""
        return self._type_decl


class Sm90AuxLoadImpl(AuxLoadImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Load
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)


class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl


class Sm90RowBroadcastImpl(RowBroadcastImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Row Broadcast
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::RowBroadcastDescriptor<EpilogueDescriptor, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
    {self.descriptor}::Stages, typename EpilogueDescriptor::TileShape,
    typename {self.descriptor}::Element, {self.stride_mnl}
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        stages = (stages_c + epi_tiles - 1) // epi_tiles + 1
        return (DataTypeSize[self.element] * cta_tile_mnk[1] * stages // 8, 16)


class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90ComputeImpl(ComputeImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl


class Sm90AuxStoreImpl(AuxStoreImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Store
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)


class Sm90StoreDImpl(StoreDImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        return f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""


class Sm90ColumnReductionImpl(ColumnReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90RowReductionImpl(RowReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm90ScalarReductionImpl(ScalarReductionImpl):

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
        return self._type_decl
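The `get_smem_size` methods above return a `(bytes, alignment)` pair, e.g. for `Sm90AuxLoadImpl`: element bits times circular-buffer stages times the epilogue tile element count, divided by 8, with 128-byte alignment. A worked sketch of that arithmetic (the fp16 element, 2 stages, and `(64, 32)` tile are hypothetical inputs, and `product` mimics the `pycute` helper):

```python
# Worked example of the Sm90AuxLoadImpl.get_smem_size arithmetic:
# bytes = element_bits * stages_c * product(epilogue_tile) // 8.
# The inputs below are hypothetical, not taken from the diff.
def product(shape):
    # Local stand-in for pycute.product: multiply all extents together
    result = 1
    for extent in shape:
        result *= extent
    return result

element_bits = 16            # e.g. DataTypeSize of a 16-bit element
stages_c = 2                 # circular-buffer stages for the C-side load
epilogue_tile_mn = (64, 32)  # epilogue subtile (M, N)
smem_bytes = element_bits * stages_c * product(epilogue_tile_mn) // 8
alignment = 128              # bytes, as returned by get_smem_size
print(smem_bytes, alignment)
```

With these inputs the buffer is 16 * 2 * 2048 / 8 = 8192 bytes; the kernel builder can sum such pairs across nodes to size the epilogue's shared-memory arena.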
165  python/cutlass/backend/evt/epilogue.py  Normal file
@ -0,0 +1,165 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
"""
|
||||
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cuda import cuda
|
||||
import numpy as np
|
||||
|
||||
from cutlass import DataType
|
||||
from cutlass.backend.epilogue import EpilogueFunctorBase
|
||||
import cutlass.backend.evt.backend
|
||||
from cutlass.backend.frontend import TensorFrontend
|
||||
|
||||
|
||||
class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
||||
"""
|
||||
Apply an epilogue functor described by the epilogue EVT
|
||||
|
||||
:param cc: compute capability
|
||||
:param visitor_frontend: user-provide visitor frontend
|
||||
|
||||
"""
|
||||
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
|
||||
# Type of Emitter based on CC
|
||||
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc}Emitter")
|
||||
|
||||
# Visitor Types
|
||||
self.visitor = visitor
|
||||
self.graph = visitor.dag_ir
|
||||
|
||||
# Data types
|
||||
self.element_epilogue = element_compute # element compute
|
||||
self.element_output = self.graph.get_node_meta('D').underlying_impl.element
|
||||
|
||||
# Epilogue Thread Type
|
||||
epilogue_thread_type = self.visitor.epilogue_thread_type
|
||||
if cc == 90:
|
||||
self.arg_c_type = self.visitor.arg_c_type
|
||||
self.arg_d_type = self.visitor.arg_d_type
|
||||
output_names = self.visitor.return_names
|
||||
reduction_names = self.visitor.reduction_names
|
||||
|
||||
# Epilogue stages specialized for sm80 kernel
|
||||
if cc == 80:
|
||||
if hasattr(self.visitor, "epilogue_stages"):
|
||||
self.epilogue_stages = self.visitor.epilogue_stages
|
||||
assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"
|
||||
|
||||
# Epilogue Argument Type
|
||||
class _Arguments(ctypes.Structure):
|
||||
"""
|
||||
Concepts:
|
||||
class _EpilogueArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("epilogue", _Arguments), <- this class
|
||||
("ptr_C", ctypes.c_void_p),
|
||||
("stride_C", StrideBatched_),
|
||||
("ptr_D", ctypes.c_void_p),
|
||||
("stride_D", StrideBatched_)
|
||||
]
|
||||
"""
|
||||
_fields_ = [
|
||||
("output_op", epilogue_thread_type)
|
||||
]
|
||||
|
||||
def __init__(self, kwargs: dict) -> None:
|
||||
# The user-input kwargs is a dict of (name: tensors)
|
||||
# We first convert all of them to device pointers
|
||||
ptr_kwargs = {}
|
||||
for key in kwargs.keys():
|
||||
is_output = key in output_names and key not in reduction_names
|
||||
ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
|
||||
# Initialize the thread arguments
|
||||
self.output_op = epilogue_thread_type(ptr_kwargs)
|
||||
|
||||
def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
|
||||
"""
|
||||
Helper function for extracting device pointer
|
||||
"""
|
||||
# Skip the special tensors
|
||||
if cc == 90:
|
||||
if tensor_name in ["C", "D"]:
|
||||
return 0
|
||||
if tensor_name not in kwargs.keys():
|
||||
raise ValueError(f"Tensor {tensor_name} is not provided.")
|
||||
tensor = kwargs[tensor_name]
|
||||
|
||||
# For float scalar constant, directly return the value
|
||||
if isinstance(tensor, float):
|
||||
return tensor
|
||||
|
||||
# The tensor frontend returns a device buffer for np.ndarray
|
||||
# and device ptr for other frontends
|
||||
buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
|
||||
if isinstance(tensor, np.ndarray):
|
||||
# Remember the host tensor for later synchronization
|
||||
setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
|
||||
setattr(self, f"{tensor_name}_host", tensor)
|
||||
return int(buffer_or_ptr.ptr)
|
||||
else:
|
||||
return int(buffer_or_ptr)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
Synchronize the results from device to host
|
||||
"""
|
||||
for name in output_names:
|
||||
if hasattr(self, f"{name}_host"):
|
||||
host_tensor = getattr(self, f"{name}_host")
|
||||
tensor_ptr = getattr(self, f"{name}_buffer").ptr
|
||||
(err,) = cuda.cuMemcpyDtoH(
|
||||
host_tensor,
|
||||
tensor_ptr,
|
||||
host_tensor.size * host_tensor.itemsize,
|
||||
)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("CUDA Error %s" % str(err))
|
||||
|
||||
self.epilogue_type = _Arguments
|
||||
|
||||
def emit(self, operation):
|
||||
"""
|
||||
Emit the C++ code
|
||||
"""
|
||||
emitter = self.emit_cls(operation, self.graph)
|
||||
return emitter.emit()
|
||||
|
||||
def get_smem_size(self, tile_description):
|
||||
"""
|
||||
Get the shared memory size in bytes
|
||||
"""
|
||||
return self.visitor.get_smem_size(tile_description)
|
||||
33
python/cutlass/backend/evt/frontend/__init__.py
Normal file
@ -0,0 +1,33 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from cutlass.backend.evt.frontend.python_ast import PythonASTFrontend
262
python/cutlass/backend/evt/frontend/frontend_base.py
Normal file
@ -0,0 +1,262 @@
"""
Base class for the Python EVT frontend
"""

from typing import Union

from cutlass import DataType
from cutlass.backend.evt.ir import (
    ComputeNode,
    DAGIR,
    LayoutNode,
    LoadNode,
    StoreNode,
)
from cutlass.backend.evt.passes import (
    EVTGraphDrawer,
    EVTPassManager,
    GetSmemSize,
    PassDAG2Tree,
    PassGetArgumentType,
    PassGetImpl,
    PassFixElementD,
    PassLayoutManipulateElimination,
    PassPreprocessRed,
    PassShapeTypePropagation,
)
from cutlass.backend.utils import device_cc
from cutlass.epilogue.evt_ops import permute, reshape
from cutlass.utils.datatypes import library_type


class EVTFrontendBase:
    layout_fns = {
        "permute": permute,
        "reshape": reshape
    }

    def __init__(self, element_compute=DataType.f32, cc=None, additional_passes=[], **kwargs) -> None:
        self.cc = cc if cc else device_cc()
        self.element_compute = library_type(element_compute)
        self.dag_ir = DAGIR(self.element_compute, self.cc)
        self.compute_cnt = 0
        self.layout_cnt = 0

        self.pass_manager = EVTPassManager(
            self.dag_ir,
            [
                PassPreprocessRed,
                PassGetArgumentType,
                PassShapeTypePropagation,
                PassLayoutManipulateElimination,
                PassGetImpl,
                PassDAG2Tree,
                PassFixElementD
            ] + additional_passes)

        if self.cc == 80:
            self._epilogue_stages = 1
        else:
            self._epilogue_stages = None

    @property
    def epilogue_stages(self):
        return self._epilogue_stages

    @epilogue_stages.setter
    def epilogue_stages(self, stages):
        self._epilogue_stages = stages

    def parse(self, *args, **kwargs):
        raise NotImplementedError("The 'parse' function must be overloaded in the frontend class")

    def trace(self, *args, **kwargs):
        # Parse the input
        self.parse(*args, **kwargs)

        # Run the passes
        self.pass_manager()
        # Set the epilogue type
        self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
        if self.cc == 90:
            self.arg_c_type = self.dag_ir.arg_c_type
            self.arg_d_type = self.dag_ir.arg_d_type
        self.reduction_names = self.dag_ir.reduction_names

    #
    # Helper functions for DAG IR manipulation
    #

    def add_node(self, node):
        self.dag_ir.add_node(node)

    def add_edge(self, src, tgt, weight=0):
        self.dag_ir.add_edge(src, tgt, weight=weight)

    def set_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.tensor = {"tensor": example}

    def set_store_tensor(self, node_name, example):
        """
        Add an example store tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.store_tensor = {"tensor": example}

    def mark_output(self, node_name):
        """
        Mark a store node as an output
        """
        meta = self.dag_ir.get_node_meta(node_name)
        if not isinstance(meta, StoreNode):
            raise ValueError(
                f"Only StoreNodes can be marked as output. "
                f"Got {type(meta).__name__}: {node_name}")
        meta.is_output = True

    #
    # Add nodes with specific types
    #

    def add_load_node(self, name, example):
        """
        Add a Load node to the DAG IR

        :param name: name of the loaded variable
        :type name: str
        :param example: example input
        :type example: np.ndarray|torch.Tensor|cupy.ndarray|float
        """
        if name is None:
            raise ValueError("Name is not provided.")
        if example is None:
            raise ValueError(f"Example input for {name} is not provided.")
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": example}
        # Special logic for the accumulator
        if name == "accum":
            if load_node.tensor.rank == 2:
                new_shape = tuple([1, ] + list(load_node.tensor.shape))
                load_node.tensor.broadcast(new_shape)
            elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
                raise ValueError(f"Expected the example input for 'accum' to be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
        self.add_node(load_node)

    def add_imm(self, value: Union[float, int]):
        """
        Add an immediate scalar value to the DAG IR

        :param value: the value of the immediate scalar
        :type value: float
        """
        try:
            value = float(value)
        except (TypeError, ValueError):
            raise ValueError(f"{type(value).__name__} cannot be converted to float.")

        name = f"imm_{value}".replace('.', '_')
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)
        return name

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.

        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.

        :param op: the layout op
        :type op: evt_ops
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualizing the DAG IR
    #

    def visualize(self, name="dag_ir"):
        """
        Visualize the DAG IR as an SVG file

        :param name: the name of the graph
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        if drawer.dot_available:
            for name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{name}.svg")
        else:
            raise RuntimeError(
                "'dot' was not found in the path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            )

    #
    # Get the shared memory size
    #

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue
        """
        smem_size = GetSmemSize(self.dag_ir)(tile_description)
        return smem_size
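`EVTFrontendBase.trace` hands the parsed graph to `EVTPassManager`, which runs each pass over the shared DAG IR in a fixed order so that later passes can rely on what earlier ones computed. A toy sketch of that pattern, using a plain dict in place of the CUTLASS `DAGIR` (all names here are illustrative, not CUTLASS API):

```python
# A toy pass manager in the spirit of EVTPassManager: each pass is a
# callable that mutates a shared IR object; the manager fixes the order.
class ToyPassManager:
    def __init__(self, ir, passes):
        self.ir = ir
        self.passes = passes

    def __call__(self):
        # Run the passes in the order they were registered
        for p in self.passes:
            p(self.ir)

def pass_fold_constants(ir):
    ir["folded"] = True

def pass_annotate_types(ir):
    # Depends on the result of the previous pass
    ir["typed"] = ir.get("folded", False)

ir = {}
ToyPassManager(ir, [pass_fold_constants, pass_annotate_types])()
```

Reordering the two passes would leave `ir["typed"]` false, which is why the pass list above (`PassPreprocessRed` through `PassFixElementD`) is a fixed sequence rather than a set.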
184
python/cutlass/backend/evt/frontend/python_ast.py
Normal file
@ -0,0 +1,184 @@
"""
Python AST frontend that parses input into the DAG IR
"""

import ast
import inspect
import textwrap

import cutlass
from cutlass import DataType
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass.backend.epilogue import relu
from cutlass.backend.library import FunctionalOp


class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
    def __init__(self, element_compute=DataType.f32, **kwargs):
        super().__init__(element_compute, **kwargs)
        # Flags
        # If this state is True, visit_Constant returns values without creating an imm node
        self.no_imm = False
        self.visiting_return = False

    def parse(self, example_inputs):
        self.example_inputs = example_inputs
        self.source = textwrap.dedent(inspect.getsource(self.__call__))
        self.ast = ast.parse(self.source)
        self.visit(self.ast)

    #
    # Helper functions
    #
    @staticmethod
    def ast_op_to_bindings(op):
        mapping = {
            ast.Add: FunctionalOp.Plus,
            ast.Sub: FunctionalOp.Minus,
            ast.Mult: FunctionalOp.Multiplies,
            ast.Div: FunctionalOp.Divides,
            "relu": relu.binding_type,
            "multiply_add": FunctionalOp.MultiplyAdd,
            "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum)
        }
        return mapping[op]

    #
    # Visiting different node types
    #

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # Visit args and register load nodes
        for arg in node.args.args:
            self.visit(arg)
        for expr in node.body:
            self.visit(expr)

    def visit_arg(self, node: ast.arg):
        # Name of the argument
        name = node.arg
        try:
            example_tensor = self.example_inputs[name]
        except KeyError:
            raise RuntimeError(f"Example input for {name} is not provided.")

        self.add_load_node(name, example_tensor)

    def visit_Name(self, node: ast.Name):
        return node.id

    def visit_Constant(self, node: ast.Constant):
        if self.no_imm:
            return node.value
        else:
            name = self.add_imm(node.value)
            return name

    def visit_Tuple(self, node: ast.Tuple):
        results = []
        for elt in node.elts:
            results.append(self.visit(elt))
        return tuple(results)

    def visit_keyword(self, node: ast.keyword):
        return {node.arg: self.visit(node.value)}

    def visit_BinOp(self, node: ast.BinOp):
        if self.visiting_return:
            raise SyntaxError("The return value cannot be an expression")
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = self.ast_op_to_bindings(type(node.op))
        name = self.add_compute_node(op)

        # Add edges
        # The edge weights are used to sort the input args
        self.add_edge(lhs, name, weight=0)
        self.add_edge(rhs, name, weight=1)
        return name

    def visit_Assign(self, node: ast.Assign):
        target = self.visit(node.targets[0])
        value = self.visit(node.value)
        # Create the assign node
        self.add_store_node(target)

        # Add edges
        self.add_edge(value, target)
        return target

    def visit_Call(self, node: ast.Call):
        if self.visiting_return:
            raise SyntaxError("The return value cannot be an expression")
        func = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]

        if func in self.layout_fns.keys():
            # Parse kwargs
            # By default, visiting an imm automatically creates a load node.
            # However, in a function call, keyword args are used to set
            # specific function attributes such as indices for permute,
            # so no_imm is set to True temporarily.
            self.no_imm = True
            kwargs = {}
            for kw in node.keywords:
                kwargs.update(self.visit(kw))
            self.no_imm = False
            op = self.layout_fns[func]
            name = self.add_layout_node(op, kwargs)
        else:
            op = self.ast_op_to_bindings(func)
            name = self.add_compute_node(op)

        # Add edges
        for idx, arg in enumerate(args):
            self.add_edge(arg, name, weight=idx)
        return name

    def visit_Return(self, node: ast.Return):
        self.visiting_return = True
        results = self.visit(node.value)
        self.visiting_return = False
        self.return_names = results
        if not isinstance(results, tuple):
            results = (results,)
        for rst in results:
            try:
                example_tensor = self.example_inputs[rst]
            except KeyError:
                raise RuntimeError(f"Example input for {rst} is not provided.")
            self.set_store_tensor(rst, example_tensor)
            self.mark_output(rst)
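`PythonASTFrontend` works by retrieving the source of a user-written epilogue function, parsing it with `ast`, and mapping each binary operation and assignment to a named graph node. A cut-down illustration of that idea, using plain lists instead of the CUTLASS DAG IR (the class and function names here are illustrative only):

```python
import ast
import textwrap

class MiniTracer(ast.NodeVisitor):
    """Record one named node per function arg, binary op, and assignment."""
    def __init__(self):
        self.nodes, self.edges, self._cnt = [], [], 0

    def visit_arg(self, node):
        # Function arguments become load nodes
        self.nodes.append(node.arg)

    def visit_Name(self, node):
        return node.id

    def visit_BinOp(self, node):
        # Each binary op becomes a fresh compute node fed by its operands
        lhs, rhs = self.visit(node.left), self.visit(node.right)
        name = f"compute_{self._cnt}"
        self._cnt += 1
        self.nodes.append(name)
        self.edges += [(lhs, name), (rhs, name)]
        return name

    def visit_Assign(self, node):
        # Assignments become store nodes fed by the visited value
        target = node.targets[0].id
        value = self.visit(node.value)
        self.nodes.append(target)
        self.edges.append((value, target))

src = textwrap.dedent("""
    def evt_fn(accum, C):
        D = accum + C
        return D
""")
tracer = MiniTracer()
tracer.visit(ast.parse(src))
```

After the visit, `tracer.nodes` holds `accum`, `C`, `compute_0`, and `D`, and `tracer.edges` connects the operands into `compute_0` and `compute_0` into `D`, which mirrors the load/compute/store structure the real frontend builds.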
53
python/cutlass/backend/evt/ir/__init__.py
Normal file
@ -0,0 +1,53 @@
from cutlass.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass.backend.evt.ir.dag_ir import DAGIR
from cutlass.backend.evt.ir.layout_nodes import LayoutNode
from cutlass.backend.evt.ir.load_nodes import (
    LoadNode,
    AccumulatorImpl,
    LoadSrcImpl,
    AuxLoadImpl,
    RowBroadcastImpl,
    ColumnBroadcastImpl,
    ScalarBroadcastImpl
)
from cutlass.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass.backend.evt.ir.store_nodes import (
    StoreNode,
    StoreDImpl,
    AuxStoreImpl,
    ColumnReductionImpl,
    RowReductionImpl,
    ScalarReductionImpl
)
91
python/cutlass/backend/evt/ir/compute_nodes.py
Normal file
@ -0,0 +1,91 @@
"""
Python registration for compute nodes in EVT
"""

from cutlass.backend.evt.ir.node import NodeBase, ImplBase
from cutlass.backend.library import FloatRoundStyle


class ComputeImplBase(ImplBase):
    """
    Base class for compute implementations
    """
    def __init__(self, node) -> None:
        super().__init__(node)


class ComputeImpl(ComputeImplBase):
    """
    Implementation for a Compute node
    """
    def __init__(self, node) -> None:
        super().__init__(node)

        self.fn = node.fn
        self.element_output = node.element_output
        self.element_compute = node.element_compute
        self.round_style = node.round_style

    @staticmethod
    def match(node, problem_size: tuple):
        return True


class ComputeNode(NodeBase):
    """
    Compute node in the DAG IR
    """
    possible_impls = [
        ComputeImpl
    ]

    def __init__(
            self, name: str, fn, element_output,
            element_compute,
            round_style=FloatRoundStyle.ToNearest) -> None:
        super().__init__(name)
        self.op = "compute"
        self.fn = fn
        self.element_compute = element_compute
        self.round_style = round_style

    def type_propagation(self, *args, **kwargs):
        """
        A compute node computes in `element_compute` and, by default, also outputs in `element_compute`.
        """
        self.element = self.element_compute
        # In general, compute nodes have element_output == element_compute.
        # In certain cases, such as the producer of D, it is overwritten by other passes.
        if not hasattr(self, "element_output"):
            self.element_output = self.element
235
python/cutlass/backend/evt/ir/dag_ir.py
Normal file
@ -0,0 +1,235 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||

"""
DAG IR used by Python EVT
"""

import networkx as nx

from cutlass import DataType
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.utils import device_cc


class DAGIR:
    """
    ``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
    It consists of a series of ``Node``s, each representing an epilogue visitor node.

    In the DAGIR, a ``node`` is the string name of a node; ``node_meta`` is the underlying
    class instance of the node.
    """
    def __init__(self, element_compute=DataType.f32, cc: int=None) -> None:
        # The EVT DAGIR is managed through the networkX DiGraph class
        self._graph = nx.DiGraph()

        self.element_compute = element_compute

        self.reduction_names = []

        self.cc = cc if cc else device_cc()

    #
    # IR manipulator
    #

    def add_node(self, meta: NodeBase):
        """
        Add a node to the DAG IR
        """
        if self.has_node(meta.name):
            raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
        self._graph.add_node(meta.name, meta=meta)

    def add_edge(self, src: str, dst: str, weight: int=0):
        """
        Add an edge src -> dst with the given weight to the DAG IR
        """
        if not self.has_node(src):
            raise SyntaxError(f"Variable '{src}' is undefined.")
        if not self.has_node(dst):
            raise SyntaxError(f"Variable '{dst}' is undefined.")
        self._graph.add_edge(src, dst, weight=weight)

    def remove_node(self, node: str):
        """
        Remove a node from the DAG IR
        """
        self._graph.remove_node(node)

    def remove_edge(self, src: str, dst: str):
        """
        Remove the edge src -> dst
        """
        self._graph.remove_edge(src, dst)

    #
    # Helper functions for getting attrs
    #

    def has_node(self, node: str) -> bool:
        """
        Check whether the node is in the graph
        """
        return self._graph.has_node(node)

    def in_degree(self, node: str):
        """
        Get the in-degree of the node
        """
        return self._graph.in_degree(node)

    def in_edges(self, node: str):
        """
        Get the input edges of the node
        """
        return list(self._graph.in_edges(node))

    def out_degree(self, node: str):
        """
        Get the out-degree of the node
        """
        return self._graph.out_degree(node)

    def out_edges(self, node: str):
        """
        Get the output edges of the node
        """
        return list(self._graph.out_edges(node))

    def get_node_meta(self, node: str):
        """
        Get the metadata of the node
        """
        return self._graph.nodes[node]["meta"]

    def get_edge_weight(self, src, dst):
        """
        Get the weight of the edge src -> dst
        """
        return self._graph.get_edge_data(src, dst)["weight"]

    #
    # High-level helper functions
    #

    def all_reachable_nodes(self, node: str):
        """
        Get all the nodes reachable from the current node (including the node itself)
        """
        return list(nx.dfs_preorder_nodes(self._graph, source=node))

    def get_users(self, node: str):
        """
        Get all users of the current node
        """
        return [edge[1] for edge in self.out_edges(node)]

    def get_all_inputs(self, node: str):
        """
        Get all the input nodes sorted by edge weight
        """
        in_edges = self.in_edges(node)
        edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
        return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]

    def get_all_inputs_meta(self, node: str):
        """
        Get all the input node metas sorted by edge weight
        """
        return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]

    def replace_all_uses_with(self, node1, node2):
        """
        Replace all uses of node1 with node2 and remove node1 from the graph
        """
        for edge in self.out_edges(node1):
            weight = self.get_edge_weight(*edge)
            user = edge[1]
            self.add_edge(node2, user, weight)
            self.remove_edge(node1, user)
        self.remove_node(node1)

    #
    # Node accessor
    #
    def nodes_topological_order(self):
        """
        Get the nodes in a unique lexicographical topological order.
        The ordering is produced by first sorting topologically and then
        breaking ties lexicographically.

        Although topological_sort alone would also work, this generates a unique key
        for each epilogue visitor pattern and ensures the compilation cache can be reused.
        :return: list[str]
        """
        return list(nx.lexicographical_topological_sort(self._graph))
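The uniqueness guarantee described in the docstring above can be sketched in plain Python: Kahn's algorithm where ties among ready nodes are broken by name. This is a hypothetical standalone illustration of the idea, not the networkx implementation:

```python
import heapq

def lex_topological_sort(nodes, edges):
    # Kahn's algorithm; among the nodes whose inputs are all satisfied,
    # always emit the lexicographically smallest name, so the resulting
    # order (and hence any cache key derived from it) is unique.
    indegree = {n: 0 for n in nodes}
    successors = {n: [] for n in nodes}
    for src, dst in edges:
        successors[src].append(dst)
        indegree[dst] += 1
    ready = [n for n in nodes if indegree[n] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        n = heapq.heappop(ready)
        order.append(n)
        for m in successors[n]:
            indegree[m] -= 1
            if indegree[m] == 0:
                heapq.heappush(ready, m)
    return order

# e.g. for alpha -> mul <- accum, mul -> D: "accum" is emitted before
# "alpha" regardless of the order the nodes were inserted in.
```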

    def node_metas_topological_order(self):
        """
        Get the node metas in topological order
        :return: list[NodeBase]
        """
        return [self.get_node_meta(node) for node in self.nodes_topological_order()]

    @property
    def nodes(self):
        """
        Get all nodes
        :return: list[str]
        """
        return list(self._graph.nodes)

    @property
    def nodes_meta(self):
        """
        Get all node metas
        :return: list[NodeBase]
        """
        return [data[1]['meta'] for data in self._graph.nodes.data()]

    @property
    def edges(self):
        """
        Get all edges
        :return: list[(str, str)]
        """
        return list(self._graph.edges)

    #
    # Path
    #
    def has_path(self, src: str, target: str) -> bool:
        """
        Return True if a path exists from src to target
        """
        return nx.has_path(self._graph, src, target)
python/cutlass/backend/evt/ir/layout_algorithm.py (new file, 324 lines)
@ -0,0 +1,324 @@

################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################

"""
Layout algebras
"""

from pycute import Layout, composition, make_layout, flatten, product


def _infer_split(old_shape, new_shape):
    old_shape = _tuple_to_list(old_shape)
    new_shape = _tuple_to_list(new_shape)
    if len(old_shape) == 0 and len(new_shape) == 0:
        return []
    if len(old_shape) == 0:
        if product(tuple(new_shape)) != 1:
            raise ValueError("Invalid reshape size")
        else:
            return new_shape
    if len(new_shape) == 0:
        if product(tuple(old_shape)) != 1:
            raise ValueError("Invalid reshape size")
        else:
            return old_shape
    # This is done recursively, processing only the last dimension at each step
    old_dim = old_shape[-1]
    new_dim = new_shape[-1]
    # Exact match
    if old_dim == new_dim:
        return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
    # Needs split
    if old_dim > new_dim and old_dim % new_dim == 0:
        residual = old_dim // new_dim
        return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
    # Needs merge
    if old_dim < new_dim and new_dim % old_dim == 0:
        residual = new_dim // old_dim
        return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]

    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
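To make the recursion above concrete, here is a standalone sketch of the same factoring rule using `math.prod` in place of `pycute.product`. It is illustrative only; the function name is hypothetical:

```python
from math import prod

def infer_split_sketch(old_shape, new_shape):
    # Factor both shapes into a common refinement, working from the last
    # dimension: equal dims match, a larger dim is split, a smaller merged.
    if not old_shape and not new_shape:
        return []
    if not old_shape:
        assert prod(new_shape) == 1, "Invalid reshape size"
        return list(new_shape)
    if not new_shape:
        assert prod(old_shape) == 1, "Invalid reshape size"
        return list(old_shape)
    old_dim, new_dim = old_shape[-1], new_shape[-1]
    if old_dim == new_dim:
        return infer_split_sketch(old_shape[:-1], new_shape[:-1]) + [new_dim]
    if old_dim > new_dim and old_dim % new_dim == 0:
        return infer_split_sketch(old_shape[:-1] + [old_dim // new_dim], new_shape[:-1]) + [new_dim]
    if old_dim < new_dim and new_dim % old_dim == 0:
        return infer_split_sketch(old_shape[:-1], new_shape[:-1] + [new_dim // old_dim]) + [old_dim]
    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")

# Reshaping (6, 4) -> (3, 8) factors through the common refinement
# (3, 2, 4): 6 = 3 * 2 and 8 = 2 * 4.
```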

def _infer_merge(flatten_shape, shape):
    flatten_shape = _tuple_to_list(flatten_shape)
    shape = _tuple_to_list(shape)
    idx_flat = 0
    merged_shape = []
    for dim in shape:
        # Exact match
        if dim == flatten_shape[idx_flat]:
            merged_shape.append(dim)
            idx_flat += 1
        # Needs grouping
        elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
            residual = dim
            group = []
            while residual > 1:
                group.append(flatten_shape[idx_flat])
                residual = residual // flatten_shape[idx_flat]
                idx_flat += 1
            merged_shape.append(group)
        else:
            raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

    return merged_shape


def _list_to_tuple(nested_list):
    if isinstance(nested_list, (list, tuple)):
        return tuple(_list_to_tuple(item) for item in nested_list)
    return nested_list


def _tuple_to_list(nested_tuple):
    if isinstance(nested_tuple, (list, tuple)):
        return [_tuple_to_list(item) for item in nested_tuple]
    return nested_tuple


def _reverse_tuple(nested_tuple: tuple):
    if isinstance(nested_tuple, tuple):
        return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
    return nested_tuple


def _get_first_lhs_nonzero_stride(stride_list, idx):
    for i in reversed(range(idx)):
        if stride_list[i] != 0:
            return i
    return None


def _get_first_rhs_nonzero_stride(stride_list, idx):
    for i in range(idx + 1, len(stride_list)):
        if stride_list[i] != 0:
            return i
    return None


def reshape(layout, new_shape):
    """
    General reshape of the input layout.
    It takes two steps:
    1. Split the dimensions of the old layout
    2. Merge the split dimensions according to the new shape
    """
    #
    # Step 1: Split the dimensions of the old layout
    #
    # 1.1 Flatten the old and new shapes
    old_flatten_shape = list(flatten(layout.shape))
    new_flatten_shape = list(flatten(new_shape))

    # 1.2 Infer the flattened split shape
    splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)

    # 1.3 Unflatten the split shape based on the old shape
    splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)

    # 1.4 Infer the type of each split.
    # If the split type is row-major (R), the dimension list is reversed because
    # cute::composition only supports column-major splits
    split_type = []  # the type of each split (ColumnMajor or RowMajor)
    permuted_splitted_shape = []
    old_flatten_stride = list(flatten(layout.stride))
    for idx, dim in enumerate(splited_shape):
        if not isinstance(dim, list):
            permuted_splitted_shape.append(dim)
            split_type.append("C")
        else:
            lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
            rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
            # Special case for a single tuple:
            # use column-major by default
            if lhs_stride is None and rhs_stride is None:
                permuted_splitted_shape.append(dim)
                split_type.append("C")
            else:
                if lhs_stride is not None and rhs_stride is not None:
                    # We consider shape[idx]:stride[idx]
                    # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
                    if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                    # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append(list(reversed(dim)))
                        split_type.append("R")
                    # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
                    elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append(list(reversed(dim)))
                            split_type.append("R")
                    # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append(list(reversed(dim)))
                            split_type.append("R")
                    else:
                        raise NotImplementedError()
                elif lhs_stride is None:
                    # If dim's stride exceeds the next nonzero stride, expand in row-major order
                    if old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append(list(reversed(dim)))
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                else:
                    # If dim's stride is below the previous nonzero stride, expand in row-major order
                    if old_flatten_stride[idx] < lhs_stride:
                        permuted_splitted_shape.append(list(reversed(dim)))
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")

    # 1.5 Generate the split layout
    permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))

    # 1.6 Reverse the permutation applied in 1.4 before merging
    splitted_shape = []
    splitted_stride = []
    for shape_dim, stride_dim, type in zip(
            permuted_splitted_layout.shape,
            permuted_splitted_layout.stride,
            split_type):
        if type == "C":
            splitted_shape.append(shape_dim)
            splitted_stride.append(stride_dim)
        else:
            splitted_shape.append(tuple(reversed(shape_dim)))
            splitted_stride.append(tuple(reversed(stride_dim)))
    splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))


    #
    # Step 2: Merge the split dimensions according to the new shape
    #
    # 2.1 Merge layout
    merged_layout = composition(splitted_layout, Layout(new_shape))

    # 2.2 Clean up
    output_layout = composition(merged_layout, Layout(new_shape))
    return output_layout


def permutation(layout, permutation):
    """
    Permute the layout
    """
    new_shape = tuple(layout.shape[idx] for idx in permutation)
    new_stride = tuple(layout.stride[idx] for idx in permutation)
    return Layout(new_shape, new_stride)
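Since a layout here is a (shape, stride) pair, the permutation above amounts to reindexing both tuples in lockstep. A minimal standalone sketch with plain tuples instead of `pycute.Layout` (the function name is hypothetical):

```python
def permute_sketch(shape, stride, indices):
    # Reorder shape and stride together; the underlying memory mapping is
    # unchanged, only the order in which the modes are addressed.
    return (tuple(shape[i] for i in indices),
            tuple(stride[i] for i in indices))

# A row-major (2, 3, 4) tensor has strides (12, 4, 1); moving the last
# mode to the front gives shape (4, 2, 3) with strides (1, 12, 4).
```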


def _broadcast(layout, new_shape):
    if len(layout) == 1 and isinstance(new_shape, int):
        old_dim = layout.shape
        old_stride = layout.stride
        new_dim = new_shape
        if old_dim == new_dim:
            return Layout(old_dim, old_stride)
        elif old_dim == 1:
            return Layout(new_dim, 0)
        else:
            raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")

    # Align the dimensions
    old_shape = layout.shape
    if isinstance(old_shape, int):
        old_shape = (old_shape,)
        sub_layouts = [layout,]
    else:
        sub_layouts = [sub_layout for sub_layout in layout]
    rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
    # Get the broadcast layouts
    broadcast_layouts = []
    try:
        layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    except NotImplementedError:
        layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
        # Discard any partial results from the failed attempt above
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    return make_layout(*broadcast_layouts)


def broadcast(layout, new_shape):
    """
    Broadcast the input layout to the new shape.
    The broadcast shape equals the new shape;
    the stride of each broadcast dimension is 0.
    """
    return _broadcast(layout, new_shape)


def debroadcast(layout, dims):
    """
    Squeeze the 0-stride dimensions in `dims`
    """
    for dim in dims:
        if layout.stride[dim] != 0:
            raise ValueError(f"Dim {dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
    new_shape = tuple(s for idx, s in enumerate(layout.shape) if idx not in dims)
    new_stride = tuple(s for idx, s in enumerate(layout.stride) if idx not in dims)
    return Layout(new_shape, new_stride)


def canonicalization_(shapes, strides):
    if isinstance(shapes, tuple):
        c_shapes = []
        c_strides = []
        for shape, stride in zip(shapes, strides):
            c_shape, c_stride = canonicalization_(shape, stride)
            c_shapes.append(c_shape)
            c_strides.append(c_stride)
        return tuple(c_shapes), tuple(c_strides)
    else:
        if shapes == 1:
            return 1, 0
        else:
            return shapes, strides


def canonicalization(layout):
    """
    Canonicalize the input layout:
    1. set the stride of any extent-1 mode to 0
    """
    new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
    return Layout(new_shape, new_stride)
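The canonicalization rule above can be sketched standalone on plain nested tuples (a hypothetical helper mirroring the recursion in `canonicalization_`):

```python
def canonicalize_sketch(shape, stride):
    # An extent-1 mode is never advanced, so stride 0 is its canonical
    # stride; recurse through nested modes, leaving everything else alone.
    if isinstance(shape, tuple):
        pairs = [canonicalize_sketch(s, d) for s, d in zip(shape, stride)]
        return tuple(p[0] for p in pairs), tuple(p[1] for p in pairs)
    return (1, 0) if shape == 1 else (shape, stride)

# ((2, 1, 4), (4, 4, 1)) canonicalizes to ((2, 1, 4), (4, 0, 1)).
```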
python/cutlass/backend/evt/ir/layout_nodes.py (new file, 336 lines)
@ -0,0 +1,336 @@

#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Layout manipulation nodes and implementations

The layout nodes change the layout of intermediate tensors in the epilogue visitor graph
"""

from copy import deepcopy

from pycute import product, flatten

import cutlass
from cutlass import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.ir.tensor import Tensor


class PermutationImpl:
    """
    Detailed implementation and helper functions for permutation
    """
    def __init__(self, node) -> None:
        assert "indices" in node.kwargs.keys()
        self.indices = list(node.kwargs["indices"])
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_impl(self):
        inverse_impl = deepcopy(self)
        inverse_impl.indices = self.inverse_indices
        inverse_impl.inverse_indices = self.indices
        return inverse_impl

    def update(self, shape):
        num_dim = len(shape)
        indices = self.indices
        num_old_dim = len(indices)
        # Add offset
        for i, idx in enumerate(indices):
            indices[i] = idx + num_dim - num_old_dim
        # Add broadcast dims
        for i in range(num_dim - num_old_dim):
            indices = [i,] + indices

        self.indices = indices
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_indices(self, indices):
        """
        Get the indices of the inverse permutation
        """
        num_dim = len(indices)
        inverse_indices = [0] * num_dim
        for i in range(num_dim):
            inverse_indices[indices[i]] = i
        return inverse_indices
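Inverting a permutation, as done above, just scatters each position back: if `indices[i] == p`, the inverse sends `p` to `i`. A quick standalone check (hypothetical name):

```python
def inverse_indices_sketch(indices):
    # Scatter: position i was mapped to indices[i], so the inverse maps
    # indices[i] back to i. Composing the two yields the identity.
    inverse = [0] * len(indices)
    for i, p in enumerate(indices):
        inverse[p] = i
    return inverse

# The inverse of [2, 0, 1] is [1, 2, 0]; composing gives [0, 1, 2].
```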

    def shape_propagation(self, input_node_meta):
        input_shape = input_node_meta.tensor.shape
        output_shape = tuple(input_shape[idx] for idx in self.indices)
        return output_shape

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the input based on the current shape
        """
        self.update(shape)
        inverse_shape = tuple(shape[idx] for idx in self.inverse_indices)
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the permutation to the users of the current node
        """
        usr_meta.tensor.permute(self.inverse_indices)
        if hasattr(usr_meta, "store_tensor"):
            if usr_meta.store_tensor is not None:
                usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the permutation to the inputs of the current node
        """
        input_meta.tensor.permute(self.indices)
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.permute(self.indices)


class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape
    """
    def __init__(self, node) -> None:
        self.node = node
        assert "new_shape" in node.kwargs.keys()
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        inverse_impl = deepcopy(self)
        inverse_impl.output_shape = self.input_shape
        inverse_impl.input_shape = self.output_shape
        return inverse_impl

    def shape_propagation(self, input_node_meta):
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the input based on the current shape.
        """
        # Step 1: infer split
        flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
        split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)

        # Broadcast shape -> split_output_shape -> flatten_split_shape
        if len(shape) - len(split_output_shape) > 0:
            for _ in range(len(shape) - len(split_output_shape)):
                split_output_shape = [1,] + split_output_shape
                flatten_split_shape = [1,] + flatten_split_shape
                split_input_shape = [1,] + split_input_shape
        broadcast_factor = []
        for dim, old_dim in zip(shape, split_output_shape):
            if not isinstance(dim, list):
                dim = [dim,]
            if not isinstance(old_dim, list):
                old_dim = [old_dim,]
            if product(tuple(dim)) == product(tuple(old_dim)):
                broadcast_factor += [1] * len(old_dim)
            elif product(tuple(old_dim)) == 1:
                assert len(dim) == 1
                broadcast_factor.append(dim[0])
            else:
                raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")

        # flatten_split_shape -> split_input_shape
        factor_idx = 0
        broadcast_split_input_shape = []
        for dim in split_input_shape:
            if isinstance(dim, list):
                new_dim = []
                for d in dim:
                    new_dim.append(d * broadcast_factor[factor_idx])
                    factor_idx += 1
                broadcast_split_input_shape.append(new_dim)
            else:
                broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
                factor_idx += 1
        broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
        node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
        node_meta.tensor.broadcast(broadcast_split_input_shape)
        # Last reshape op to clean up
        broadcast_input_shape = tuple(product(dim) for dim in broadcast_split_input_shape)
        node_meta.tensor.reshape(broadcast_input_shape)
        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)

    def apply_to_user(self, user_meta: NodeBase):
        """
        Propagate the reshape to user nodes
        """
        user_meta.tensor.reshape(tuple(self.input_shape))
        if hasattr(user_meta, "store_tensor"):
            if user_meta.store_tensor is not None:
                user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the reshape to input nodes
        """
        input_meta.tensor.reshape(tuple(self.output_shape))
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flattened split shape that can be merged to both input_shape and output_shape
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return input_shape
        # This is done recursively, processing only the last dimension at each step
        old_dim = input_shape[-1]
        new_dim = output_shape[-1]
        # Exact match
        if old_dim == new_dim:
            return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
        # Needs split
        if old_dim > new_dim and old_dim % new_dim == 0:
            residual = old_dim // new_dim
            return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
        # Needs merge
        if old_dim < new_dim and new_dim % old_dim == 0:
            residual = new_dim // old_dim
            return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]

        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        idx_flat = len(flatten_shape) - 1
        merged_shape = []
        for dim in reversed(shape):
            # Exact match
            if dim == flatten_shape[idx_flat]:
                merged_shape.append(dim)
                idx_flat -= 1
            # Needs grouping
            elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
                residual = dim
                group = []
                while residual > 1:
                    group.append(flatten_shape[idx_flat])
                    residual = residual // flatten_shape[idx_flat]
                    idx_flat -= 1
                merged_shape.append(group[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

        return merged_shape[::-1]


class LayoutNode(NodeBase):
    """
    Layout manipulation nodes
    """
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }
    def __init__(self, name: str, fn, kwargs: dict) -> None:
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        self.kwargs = kwargs
        self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)

    def get_inverse_node(self):
        inverse_node = deepcopy(self)
        inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse_node

    def shape_propagation(self, input_node_metas):
        if self._tensor is not None:
            return
        assert len(input_node_metas) == 1, "Layout node can only have one input node"

        output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])

        self._tensor = Tensor(
            element=self.element_output,
            shape=output_shape, layout_tag=LayoutType.RowMajor
        )

        return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Layout nodes have element_output equal to element_input
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in reversed topological order
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        shape = self.tensor.shape

        for child in input_node_metas:
            self.underlying_impl.broadcast(shape, child)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the layout manipulation to user nodes
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the layout manipulation to input nodes
        """
        self.underlying_impl.apply_to_input(input_meta)
|
||||
294 python/cutlass/backend/evt/ir/load_nodes.py Normal file
@ -0,0 +1,294 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Load nodes and implementations
"""

import ctypes

from cutlass.backend.c_types import tuple_factory
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass.backend.evt.ir.node import NodeBase, ImplBase


class LoadImplBase(ImplBase):
    """
    Base class for load node implementations
    """
    reserved_names = ["accum", "C"]

    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.element
        self.element_output = node.element_output
        self.stride = node.tensor.stride


class AccumulatorImpl(LoadImplBase):
    """
    Accumulator node implementation
    """

    @staticmethod
    def match(node, problem_size: tuple):
        return node.name == "accum" and node.tensor.shape == problem_size


class LoadSrcImpl(LoadImplBase):
    """
    Load C implementation
    """
    @property
    def name_camel(self) -> str:
        return "TensorC"

    @property
    def argument_type_c(self):
        stride_mnl = self.get_stride_mnl()
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", tuple_type)
            ]

            def __init__(self, ptr) -> None:
                self.ptr_C = ptr
                self.stride_C = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        return node.name == "C" and node.tensor.shape == problem_size


class AuxLoadImpl(LoadImplBase):
    """
    Load an arbitrary tensor
    """
    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dAux", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr_aux = ptr
                self.null_default = to_ctype_value(0, element_type)
                self.dAux = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        if (strideMN[0] == 1 and strideMN[1] != 0 or
            strideMN[0] != 0 and strideMN[1] == 1):
            return True
        else:
            return False


class RowBroadcastImpl(LoadImplBase):
    """
    Broadcast a row vector
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_row", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dRow", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr_row = ptr
                self.null_default = to_ctype_value(0, element_type)
                self.dRow = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        if strideMN == (0, 1):
            return True
        else:
            return False


class ColumnBroadcastImpl(LoadImplBase):
    """
    Broadcast a column vector
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_col", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dCol", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr_col = int(ptr)
                self.null_default = to_ctype_value(0, element_type)
                self.dCol = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        if strideMN == (1, 0):
            return True
        else:
            return False


class ScalarBroadcastImpl(LoadImplBase):
    """
    Broadcast a scalar
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        if self.tensor.is_constant:
            value = self.tensor.value

            class _Argument(ctypes.Structure):
                _fields_ = [
                    ("scalars", dtype2ctype[element_type]),
                    ("scalar_ptrs", ctypes.c_void_p),
                    ("dScalar", tuple_type)
                ]

                def __init__(self, kwargs) -> None:
                    self.scalars = to_ctype_value(value, element_type)
                    self.scalar_ptrs = 0
                    self.dScalar = tuple_type(stride_mnl)

        else:
            class _Argument(ctypes.Structure):
                _fields_ = [
                    ("scalars", dtype2ctype[element_type]),
                    ("scalar_ptrs", ctypes.c_void_p),
                    ("dScalar", tuple_type)
                ]

                def __init__(self, kwargs) -> None:
                    scalar_or_ptr = kwargs[name]
                    if isinstance(scalar_or_ptr, float):
                        self.scalars = to_ctype_value(scalar_or_ptr, element_type)
                        self.scalar_ptrs = 0
                    else:
                        self.scalar_ptrs = int(scalar_or_ptr)

                    self.dScalar = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        if strideMN == (0, 0):
            return True
        else:
            return False


class LoadNode(NodeBase):
    """
    Load node
    """
    cnt = 0
    possible_impls = [
        AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
        RowBroadcastImpl, ColumnBroadcastImpl,
        ScalarBroadcastImpl
    ]

    def __init__(self, name: str) -> None:
        if name is None:
            name = f"load{LoadNode.cnt}"
            LoadNode.cnt += 1
        super().__init__(name)
        self.op = "load"

    def type_propagation(self, *args, **kwargs):
        """
        A load node loads a tensor of type `tensor.element` and returns an array of type `tensor.element`.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")

        self.element = self.tensor.element
        self.element_output = self.tensor.element
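The load impls above dispatch purely on the last two strides of the node's tensor: `(0, 1)` selects a row broadcast, `(1, 0)` a column broadcast, `(0, 0)` a scalar broadcast, and any other contiguously strided pattern falls through to `AuxLoadImpl`. A minimal standalone sketch of that dispatch (the `classify_load` helper is hypothetical, for illustration only, not part of CUTLASS):

```python
def classify_load(stride_mn):
    """Mirror the stride-based dispatch used by the load impl `match` methods."""
    if stride_mn == (0, 1):
        return "row_broadcast"     # one row, broadcast down the M dimension
    if stride_mn == (1, 0):
        return "column_broadcast"  # one column, broadcast across the N dimension
    if stride_mn == (0, 0):
        return "scalar_broadcast"  # a single value, broadcast everywhere
    if (stride_mn[0] == 1 and stride_mn[1] != 0) or \
       (stride_mn[0] != 0 and stride_mn[1] == 1):
        return "aux_load"          # a full, contiguously strided auxiliary tensor
    return None

# A row-major [M, N] tensor has stride (N, 1); a bias row has stride (0, 1).
print(classify_load((128, 1)))  # aux_load
print(classify_load((0, 1)))    # row_broadcast
```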
292 python/cutlass/backend/evt/ir/node.py Normal file
@ -0,0 +1,292 @@
(BSD-3-Clause license header identical to the one above)
"""
Base & visitor classes of DAGIR Nodes
"""

import ctypes
from re import sub

from cutlass import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass.backend.evt.ir.tensor import Tensor


class ImplBase:
    """
    Base class for Node Implementations
    """
    def __init__(self, node) -> None:
        self.node = node
        self.name = node.name
        self.tensor = node.tensor
        self._type_decl = None
        self.stride_dtype = "int64_t"

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match function used in get_underlying_impl
        """
        raise NotImplementedError("The `match` function is not defined.")

    @property
    def argument_type(self):
        """
        Default class for the Argument Type
        """
        class _Argument(ctypes.Structure):
            _fields_ = []

            def __init__(self, *args, **kwargs) -> None:
                pass

        return _Argument

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")

    def _emit_cute_tuple(self, py_tuple):
        """
        Emit the cute tuple as C++ code
        """
        if isinstance(py_tuple, int):
            if py_tuple in [0, 1]:
                return f"cute::Int<{py_tuple}>"
            else:
                return f"{self.stride_dtype}"
        elif isinstance(py_tuple, tuple):
            decl = "cute::Stride<"
            for item in py_tuple:
                decl += self._emit_cute_tuple(item) + ", "
            return decl[:-2] + ">"
        else:
            raise ValueError(f"_emit_cute_tuple only accepts tuple or int, got {type(py_tuple).__name__}")

    @property
    def stride_mnl(self):
        """
        Typename StrideMNL
        """
        stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
        return self._emit_cute_tuple(stride)

    def get_non_constant_stride(self, py_tuple):
        if isinstance(py_tuple, int):
            if py_tuple not in [0, 1]:
                return py_tuple
            else:
                return None
        non_constant_stride = []
        for item in py_tuple:
            item_out = self.get_non_constant_stride(item)
            if item_out:
                non_constant_stride.append(item_out)
        return tuple(non_constant_stride)

    def get_stride_mnl(self):
        """
        Get the non-zero stride mnl. This is used in argument construction
        """
        stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
        return stride

    def get_smem_size(self, *args, **kwargs):
        """
        Get the shared memory size and alignment of the current node
        """
        return (0, 1)


class NoOpImpl(ImplBase):
    """
    The NoOpImpl does nothing but forward its input to its users
    """
    def __init__(self, node) -> None:
        super().__init__(node)

    @staticmethod
    def match(node, problem_size: tuple):
        if node.op == "store":
            # A store that is not an output is a no-op
            return not node.is_output


class NodeBase:
    """
    Base class of DAG Nodes
    """
    def __init__(self, name: str) -> None:
        self.name = name
        self.underlying_impl = None

        self._tensor = None

        # Whether the node is disabled for emit
        self.disabled = False

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        return self.underlying_impl.name_camel

    @property
    def tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
        """
        return self._tensor

    @tensor.setter
    def tensor(self, kwargs):
        """
        Set the tensor
        """
        self._tensor = Tensor(**kwargs)

    #
    # Helper functions for type/shape propagation
    #

    def shape_propagation(self, input_node_metas):
        """
        Infer the output shape from the input nodes, following NumPy's
        general broadcasting rules: the shapes of the operands are compared
        element-wise, starting with the trailing (i.e. rightmost) dimension
        and working left. Two dimensions are compatible when
        1. they are equal, or
        2. one of them is 1
        """
        if self._tensor is not None:
            return

        shape = None
        for src in input_node_metas:
            src_shape = src.tensor.shape
            if shape is None:
                shape = src_shape
            else:
                len_difference = len(shape) - len(src_shape)
                if len_difference > 0:
                    for _ in range(len_difference):
                        src_shape = [1, ] + list(src_shape)
                elif len_difference < 0:
                    for _ in range(-len_difference):
                        shape = [1, ] + list(shape)
                broadcasted_shape = []
                # Infer the broadcast shape
                for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
                    if shape_dim == 1:
                        broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
                    elif src_dim == 1:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    elif shape_dim == src_dim:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    else:
                        error_msg = "Dimension mismatch between "
                        for src_ in input_node_metas:
                            error_msg += f"{src_.name}{src_.tensor.shape}, "
                        error_msg = error_msg[:-2] + "."
                        raise RuntimeError(error_msg)
                shape = tuple(broadcasted_shape)

        self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)

    def type_propagation(self, *args, **kwargs):
        """
        Each node is associated with two data types: `element` and `element_output`.
        `element_output` is the type of the array the node returns. `element`
        has a specific meaning for each node type:
        * Load node: data type of the tensor in gmem
        * Compute node: element compute
        * Store node: data type of the tensor in gmem
        This function must be overloaded in the derived classes.
        """
        raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in reversed topological order.
        For example:
            C[l, m, n] = A[m, 1] + B[l, m, n]
        After broadcast propagation, this becomes
            C[l, m, n] = A[l, m, n] + B[l, m, n]
        and each tensor has a proper stride for accessing the underlying tensor.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        for child in input_node_metas:
            child.tensor.broadcast(self.tensor.shape)

    def get_underlying_impl(self, problem_size: tuple):
        """
        Get the underlying implementation of the current node.
        """
        if self.tensor is None:
            raise RuntimeError(f"The layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")

        for impl in self.possible_impls:
            if impl.match(self, problem_size):
                self.underlying_impl = impl(self)
                break

        if self.underlying_impl is None:
            raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")


#
# Visitor Nodes & Impls
#

class TopoVisitorImpl(ImplBase):
    """
    Impl for the topological visitor
    """
    def __init__(self, node) -> None:
        super().__init__(node.output_node)
        self.name = node.name
        self.element_output = node.output_node.element_output


class TopoVisitorNode(NodeBase):
    def __init__(self, name: str, subgraph, output_node) -> None:
        super().__init__(name)
        self.subgraph = subgraph
        self.output_node = output_node
        self.op = "dag"
        self.underlying_impl = TopoVisitorImpl(self)
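`shape_propagation` above implements NumPy's general broadcasting rules by hand, left-padding the shorter shape with 1s and comparing trailing dimensions. The same inference can be written as a compact standalone helper (a sketch for illustration; `broadcast_shapes` is a hypothetical name, not the method above):

```python
def broadcast_shapes(*shapes):
    """Right-align shapes and broadcast them element-wise, NumPy-style."""
    ndim = max(len(s) for s in shapes)
    # Left-pad every shape with 1s so they all have the same rank
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        target = max(dims)
        # Each dimension must equal the target or be 1 (broadcastable)
        if any(d not in (1, target) for d in dims):
            raise RuntimeError(f"Dimension mismatch between {shapes}")
        out.append(target)
    return tuple(out)

print(broadcast_shapes((3, 5, 1), (4,)))  # (3, 5, 4)
print(broadcast_shapes((6, 1), (1, 7)))   # (6, 7)
```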
276 python/cutlass/backend/evt/ir/store_nodes.py Normal file
@ -0,0 +1,276 @@
(BSD-3-Clause license header identical to the one above)
"""
Store nodes and implementations
"""

import ctypes

from cutlass import DataType
from cutlass.backend.c_types import tuple_factory
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass.backend.evt.ir.tensor import Tensor
from cutlass.backend.library import FloatRoundStyle, FunctionalOp


class StoreImplBase(ImplBase):
    """
    Base class for store node implementations
    """
    reserved_names = ["D"]

    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.element
        self.element_output = node.element_output
        self.stride = node.store_tensor.stride


class StoreDImpl(StoreImplBase):
    """
    Store D implementation
    """

    @property
    def argument_type_d(self):
        stride_mnl = self.get_stride_mnl()
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", tuple_type)
            ]

            def __init__(self, ptr: int) -> None:
                self.ptr_D = ptr
                self.stride_D = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name == "D" and node.store_tensor.shape == problem_size:
            return True
        return False


class AuxStoreImpl(StoreImplBase):
    def __init__(self, node) -> None:
        super().__init__(node)
        self.round_style = FloatRoundStyle.ToNearest

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("dAux", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr_aux = ptr
                self.dAux = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False

        strideMN = node.store_tensor.stride[-2:]
        if (strideMN[0] == 1 and strideMN[1] != 0 or
            strideMN[0] != 0 and strideMN[1] == 1):
            return True
        else:
            return False


class ReductionImplBase(StoreImplBase):
    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.store_tensor.element
        self.element_compute = node.element_compute
        self.reg_reduce_fn = self.node.reg_reduce_fn
        self.gmem_reduce_fn = self.node.gmem_reduce_fn
        self.round_style = node.round_style
        self.stride_dtype = "int"

    def get_reduce_identity(self):
        """
        Return the reduction identity of the current reduce_fn
        """
        maxes = {
            DataType.f32: (2 ** 31) - 1,
            DataType.f16: (2 ** 15),
            DataType.s32: (2 ** 31) - 1,
            DataType.s8: (2 ** 7) - 1
        }
        mins = {
            DataType.f32: -maxes[DataType.f32],
            DataType.f16: -maxes[DataType.f16],
            DataType.s32: -maxes[DataType.s32],
            DataType.s8: -maxes[DataType.s8]
        }
        if self.reg_reduce_fn == FunctionalOp.Maximum:
            if self.element_compute not in mins:
                raise Exception(f"No min entry for data type {self.element_compute}")
            return to_ctype_value(mins[self.element_compute], self.element_compute)
        elif self.reg_reduce_fn == FunctionalOp.Multiplies:
            return to_ctype_value(1., self.element_compute)
        elif self.reg_reduce_fn == FunctionalOp.Minimum:
            if self.element_compute not in maxes:
                raise Exception(f"No max entry for data type {self.element_compute}")
            return to_ctype_value(maxes[self.element_compute], self.element_compute)
        else:
            return to_ctype_value(0., self.element_compute)

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_compute = self.element_compute
        reduce_identity = self.get_reduce_identity()

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr", ctypes.c_void_p),
                ("reduce_identity", dtype2ctype[element_compute]),
                ("dMNL", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr = ptr
                self.reduce_identity = reduce_identity
                self.dMNL = tuple_type(stride_mnl)

        return _Argument


class ColumnReductionImpl(ReductionImplBase):

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False

        strideMN = node.store_tensor.stride[-2:]
        if strideMN == (1, 0):
            return True
        else:
            return False


class RowReductionImpl(ReductionImplBase):

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False

        strideMN = node.store_tensor.stride[-2:]
        if strideMN == (0, 1):
            return True
        else:
            return False


class ScalarReductionImpl(ReductionImplBase):

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False

        strideMN = node.store_tensor.stride[-2:]
        if strideMN == (0, 0):
            return True
        else:
            return False


class StoreNode(NodeBase):
    """
    Store node
    """
    possible_impls = [
        AuxStoreImpl, RowReductionImpl,
        ColumnReductionImpl, ScalarReductionImpl,
        NoOpImpl, StoreDImpl
    ]

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.op = "store"
        self.is_output = False
        self._store_tensor = None

    @property
    def store_tensor(self) -> Tensor:
        """
        Return the tensor to be stored (concept: cutlass.backend.evt.ir.tensor)
        """
        return self._store_tensor

    @store_tensor.setter
    def store_tensor(self, kwargs):
        """
        Set the store tensor
        """
        self._store_tensor = Tensor(**kwargs)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The store node has element_output = element_input
        """
        if self.is_output:
            if self.store_tensor is None:
                raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
            self.element = self.store_tensor.element
        assert len(input_node_metas) == 1, "Store node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        super().broadcast_propagation(input_node_metas)
        if self.is_output:
            self._store_tensor.broadcast(self.tensor.shape)
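Every `_Argument` class in these files follows one idiom: a `ctypes.Structure` built inside a property, so that values known at emit time (strides, reduction identities) are captured in the enclosing closure and baked into the class. A self-contained sketch of that pattern (the factory and field names here are illustrative, not the exact CUTLASS layout):

```python
import ctypes

def make_argument_type(stride):
    """Build a ctypes.Structure class that bakes in an emit-time stride."""
    class _Argument(ctypes.Structure):
        _fields_ = [
            ("ptr", ctypes.c_void_p),      # device pointer supplied at launch
            ("stride_m", ctypes.c_int64),  # baked-in leading stride
        ]

        def __init__(self, ptr: int) -> None:
            self.ptr = ptr
            self.stride_m = stride  # captured from the enclosing closure

    return _Argument

Arg = make_argument_type(stride=128)
a = Arg(0x1000)
print(a.stride_m)  # 128
```

Because the structure's layout matches the C++ kernel's argument struct byte for byte, an instance can be passed directly to a launched kernel.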
130 python/cutlass/backend/evt/ir/tensor.py Normal file
@ -0,0 +1,130 @@
(BSD-3-Clause license header identical to the one above)
|
||||
"""
High-level tensor abstraction used by the EVT IR
"""

from cutlass import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import (
    Layout,
    broadcast,
    canonicalization,
    permutation,
    reshape,
    _reverse_tuple
)
from cutlass.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type


class Tensor:
    """
    Abstracts a tensor's element type, shape, and layout
    """
    def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None:
        if element is not None and tensor is not None:
            raise Exception("Must not specify both element and tensor")
        elif shape is not None and tensor is not None:
            raise Exception("Must not specify both shape and tensor")
        elif layout_tag is not None and tensor is not None:
            raise Exception("Must not specify both layout_tag and tensor")
        elif (element is None or layout_tag is None or shape is None) and (tensor is None):
            raise Exception("Must specify either (element, shape, layout_tag) or tensor")

        if isinstance(tensor, Tensor):
            # Directly copy all the attributes
            self.__dict__.update(vars(tensor))
        else:
            if tensor is None:
                self.element = library_type(element)
            else:
                self.element, layout_tag = get_datatype_and_layout(tensor)
                shape = get_tensor_shape(tensor)
            if layout_tag == LayoutType.RowMajor:
                self.layout = Layout(shape[::-1])
            elif layout_tag == LayoutType.ColumnMajor:
                self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
            self.layout = canonicalization(self.layout)

        self.is_constant = is_constant
        # Save the tensor value if it is constant
        if is_constant and tensor is not None:
            self.value = tensor

    @property
    def shape(self):
        """
        Returns the RowMajor layout shape
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Returns the RowMajor layout stride
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Returns the rank of the tensor
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to shape
        """
        assert isinstance(shape, tuple)
        self.layout = broadcast(self.layout, _reverse_tuple(shape))

    def reshape(self, shape):
        """
        Reshape self.layout to shape
        """
        assert isinstance(shape, tuple)
        reverse_shape = _reverse_tuple(shape)
        self.layout = reshape(self.layout, reverse_shape)

    def permute(self, indices):
        """
        Permute self.layout according to indices
        """
        # self.layout stores dimensions in reversed (column-major) order,
        # so remap each index and reverse the index list before permuting
        length = len(indices)
        indices = [length - idx - 1 for idx in indices]
        self.layout = permutation(self.layout, indices[::-1])
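The index remapping in `permute` above can be checked with a small standalone sketch. Plain tuples stand in for the CUTLASS `Layout` type here; `permute_reversed` is an illustrative helper, not part of the library:

```python
def permute_reversed(shape, indices):
    # Mirror of Tensor.permute: dimensions are kept in reversed order
    # internally, so remap each index and reverse the index list before
    # applying the permutation to the reversed shape.
    length = len(indices)
    remapped = [length - idx - 1 for idx in indices]
    reversed_shape = shape[::-1]
    permuted_reversed = tuple(reversed_shape[i] for i in remapped[::-1])
    # Reverse back to recover the row-major view of the result
    return permuted_reversed[::-1]

shape = (2, 3, 5)
indices = (2, 0, 1)
direct = tuple(shape[i] for i in indices)  # plain row-major permutation
assert permute_reversed(shape, indices) == direct  # both give (5, 2, 3)
```

The double reversal cancels out, so the internal column-major bookkeeping is invisible to callers working with row-major shapes.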
42 python/cutlass/backend/evt/passes/__init__.py Normal file
@ -0,0 +1,42 @@
from cutlass.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass.backend.evt.passes.pass_manager import EVTPassManager
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass.backend.evt.passes.smem_size_calculator import GetSmemSize
158 python/cutlass/backend/evt/passes/graph_drawer.py Normal file
@ -0,0 +1,158 @@
import subprocess

import pydot

from cutlass import DataTypeTag
from cutlass.backend.evt.ir.dag_ir import DAGIR


_COLOR_MAP = {
    "load": '"AliceBlue"',
    "compute": "LemonChiffon1",
    "accumulator": "LightGrey",
    "store": "PowderBlue",
    "layout": "lightseagreen",
    "dag": "darkorange"
}


class EVTGraphDrawer:
    """
    Visualize an EVT DAGIR with graphviz
    """
    def __init__(
        self,
        graph: DAGIR,
        name: str
    ):
        self._name = name
        self._dot_graphs = {}

        self._dot_graphs[name] = self._to_dot(graph, name)
        self.dot_available = self._check_dot_availability()

    def _check_dot_availability(self):
        """
        Check whether the graphviz 'dot' executable is installed
        """
        try:
            # Run the 'dot' command and capture its output
            result = subprocess.run(
                ["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            # 'dot -V' prints its version banner to stderr
            if result.returncode == 0 and "dot - graphviz" in result.stderr:
                return True
        except FileNotFoundError:
            pass
        return False

    def _get_node_style(self, node):
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op in _COLOR_MAP:
            template["fillcolor"] = _COLOR_MAP[node.op]
        else:
            raise NotImplementedError(f"Unknown node op: {node.op}")
        if node.disabled:
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        label = "{" + f"name={node.name}|op={node.op}"
        if node.op == "layout":
            label += f"|fn={node.fn.__name__}"
            for key in node.kwargs:
                label += f"|{key}={node.kwargs[key]}"
        if node.underlying_impl is not None:
            label += f"|impl={type(node.underlying_impl).__name__}"
            if node.op == "load":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
            elif node.op == "compute":
                label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "store":
                label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "dag":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
        if node.tensor is not None:
            shape = node.tensor.shape
            stride = node.tensor.stride
            label += f"|shape={shape}|stride={stride}"

        if hasattr(node, "store_tensor"):
            if node.store_tensor is not None:
                store_shape = node.store_tensor.shape
                store_stride = node.store_tensor.stride
                label += f"|store_shape={store_shape}|store_stride={store_stride}"

        label += "}"
        return label

    def _to_dot(
        self,
        graph: DAGIR,
        name: str
    ):
        dot_graph = pydot.Dot(name, rankdir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_node = pydot.Node(
                node.name, label=label, **style
            )
            dot_graph.add_node(dot_node)
            if node.op == "dag":
                dot_subgraph = self._to_dot(node.subgraph, name=node.name)
                self._dot_graphs[node.name] = dot_subgraph

        # Add edges
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))

        return dot_graph

    def get_dot_graph(self) -> list:
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]

    def get_dot_graph_by_name(self, name) -> pydot.Dot:
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> pydot.Dot:
        return self._dot_graphs[self._name]
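The `_check_dot_availability` helper above follows a common pattern for probing whether a CLI tool is on the `PATH`; it generalizes to any executable. A standalone sketch (the `tool_available` name and its default flag are illustrative, not library API):

```python
import subprocess

def tool_available(cmd, version_flag="--version"):
    # Returns True if `cmd` can be launched and exits cleanly.
    # FileNotFoundError means the executable is not on the PATH at all.
    try:
        result = subprocess.run(
            [cmd, version_flag],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        return result.returncode == 0
    except FileNotFoundError:
        return False

assert not tool_available("definitely-not-a-real-tool-xyz")
```

Note that `dot -V` writes its banner to stderr rather than stdout, which is why the class above inspects `result.stderr`.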
116 python/cutlass/backend/evt/passes/pass_argument_type.py Normal file
@ -0,0 +1,116 @@
"""
Construct the epilogue visitor argument type
"""

from cutlass.backend.c_types import visitor_factory
from cutlass.backend.evt.ir import TopoVisitorNode
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation


class PassGetArgumentType(EVTPassBase):
    """
    Construct the epilogue visitor argument type
    """
    dependencies = [
        PassShapeTypePropagation,  # The layout of all nodes must be set
        PassDAG2Tree,              # The type of each node must be set
        PassGetImpl                # The DAG subgraphs must be set
    ]

    def requires(self) -> None:
        # Check that "D" is in the node list
        if self.cc == 90 and (not self.dag_ir.has_node("D")):
            raise SyntaxError(
                "Sm90 EVT requires the epilogue to have a returned tensor D, "
                "but the variable 'D' is not found in the return values.")

    def call(self):
        nodes = self.dag_ir.nodes_topological_order()
        self.argument_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.argument_types[node] = meta.underlying_impl.argument_type
            if node == "D" and self.cc == 90:
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_argument_type(node)
            else:
                self.get_evt_argument_type(node)

        self.cc_specific_method(self.set_argument_type)()

    def get_evt_argument_type(self, node):
        # The input nodes are already sorted by edge weight
        input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(
                input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])

    def get_dag_argument_type(self, node):
        meta = self.dag_ir.get_node_meta(node)
        subgraph = meta.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Visit the unvisited nodes in the subgraph
        for n in subgraph_nodes:
            m = subgraph.get_node_meta(n)
            if m.disabled:
                continue
            else:
                self.argument_types[n] = m.underlying_impl.argument_type
        input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])

    def set_argument_type(self):
        pass

    def sm90_set_argument_type(self):
        self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
        # Get the tensor D argument type
        self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d

        # Get the tensor C argument type
        if self.dag_ir.has_node("C"):
            self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
        else:
            self.dag_ir.arg_c_type = self.dag_ir.arg_d_type

    def sm80_set_argument_type(self):
        nodes = self.dag_ir.nodes_topological_order()
        self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
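The pass above builds each node's argument type by prepending the argument types of its inputs, walking the graph in topological order so children are always typed first. The nesting idea can be sketched standalone; plain dicts and tuples stand in for `DAGIR` and `visitor_factory`, and `argument_types_topo` is an illustrative helper, not library API:

```python
def argument_types_topo(inputs, topo_order, leaf_type):
    # Assemble each node's argument type as (child types..., own type),
    # mirroring how visitor arguments nest along the epilogue tree.
    types = {}
    for node in topo_order:
        children = inputs.get(node, [])
        if children:
            types[node] = tuple(types[c] for c in children) + (leaf_type,)
        else:
            types[node] = leaf_type
    return types

# mul consumes C and alpha: its argument type nests both children's types
inputs = {"mul": ["C", "alpha"]}
types = argument_types_topo(inputs, ["C", "alpha", "mul"], float)
assert types["mul"] == (float, float, float)
```

In the real pass, `visitor_factory` produces ctypes structures rather than tuples, but the ordering guarantee is the same: a node's type is only constructed once every input's type exists.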
147 python/cutlass/backend/evt/passes/pass_dag_2_tree.py Normal file
@ -0,0 +1,147 @@
"""
Merge non-tree sub-graphs of the DAG IR into single DAG nodes. Each fused sub-graph
is implemented with the topological visitor, while the rest of the graph is
implemented with the tree visitor.
"""

from copy import deepcopy

from cutlass.backend.evt.ir import DAGIR, TopoVisitorNode
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation


class PassDAG2Tree(EVTPassBase):
    """
    Convert the DAG IR to a tree by fusing non-tree sub-graphs
    """
    dependencies = [
        PassShapeTypePropagation,
        PassGetImpl
    ]

    def call(self):
        # Step 1: find the nodes that have multiple parents
        multi_parent_nodes = []

        for node in self.dag_ir.nodes_topological_order():
            if self.dag_ir.out_degree(node) > 1:
                multi_parent_nodes.append(node)
        # Step 2: find the lowest common ancestor (LCA) of all its parents
        for node in multi_parent_nodes:
            # A multi-parent node could already have been fused by a previous node
            if not self.dag_ir.has_node(node):
                continue
            # A node not covered by the previous fusions can have its out degree changed:
            # Case 1: it has <= 1 edge to the previously fused subgraph: no degree change
            # Case 2: it has more than one edge to the previously fused subgraph: degree drops
            if self.dag_ir.out_degree(node) <= 1:
                continue

            # Otherwise, the node still has multiple users and must be fused
            reachable_nodes = []
            # Complexity: O(Dout*N)
            for parent in self.dag_ir.get_users(node):
                reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
            # Get the commonly reachable nodes
            common_items = set.intersection(*reachable_nodes)

            # If a common ancestor exists, find the lowest one
            if len(common_items) > 0:
                topo_order = self.dag_ir.nodes_topological_order()
                lca = None
                topo_idx = -1
                for item in common_items:
                    if lca is None:
                        lca = item
                        topo_idx = topo_order.index(item)
                    else:
                        if topo_idx > topo_order.index(item):
                            lca = item
                            topo_idx = topo_order.index(item)
                # The lca is the output node of the DAG node
                # Get the nodes to be fused
                node_to_fuse = set.union(*reachable_nodes).difference(common_items)
                node_to_fuse.add(lca)
                # Get all the input and output nodes
                all_input_nodes = []
                all_output_nodes = []
                for fused in node_to_fuse:
                    all_input_nodes.append(set(self.dag_ir.get_all_inputs(fused)))
                    all_output_nodes.append(set(self.dag_ir.get_users(fused)))
                all_input_nodes = set.union(*all_input_nodes)
                all_output_nodes = set.union(*all_output_nodes)

                new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)

                # Create the subgraph
                subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
                subgraph = DAGIR()
                for sub_node in subgraph_.nodes:
                    meta = deepcopy(self.dag_ir.get_node_meta(sub_node))
                    if sub_node not in node_to_fuse:
                        meta.disabled = True
                    subgraph.add_node(meta)
                for edge in subgraph_.edges:
                    subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))

                # Create the fused node
                dag_node = TopoVisitorNode(
                    name=f"dag_{lca}", subgraph=subgraph,
                    output_node=self.dag_ir.get_node_meta(lca))
                self.dag_ir.add_node(dag_node)

                # Add input edges
                for idx, input_node in enumerate(all_input_nodes):
                    self.dag_ir.add_edge(input_node, dag_node.name, weight=idx)

                # Replace all uses with the DAG node (only 1 output node)
                self.dag_ir.replace_all_uses_with(lca, dag_node.name)

                # Remove all fused nodes
                node_to_fuse.remove(lca)
                for fused in node_to_fuse:
                    self.dag_ir.remove_node(fused)

            else:
                raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.")

    def ensures(self) -> None:
        # Ensure that after the pass, the resulting DAG is a tree
        for node in self.dag_ir.nodes:
            out_degree = self.dag_ir.out_degree(node)
            if out_degree > 1:
                raise RuntimeError(f"PassDAG2Tree failed. Node {node} still has out degree = {out_degree}")
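The LCA search in `call` above boils down to intersecting the reachable sets of every user of a multi-parent node and taking the earliest member in topological order. A standalone sketch of that idea, with plain adjacency dicts standing in for the `DAGIR` API (function names here are illustrative):

```python
def reachable(graph, start):
    # All nodes reachable from `start`, including `start` itself
    seen, stack = set(), [start]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(graph.get(n, []))
    return seen

def lowest_common_descendant(graph, topo_order, users):
    # Intersect the reachable sets of all users; the "lowest" common
    # node is the one that appears earliest in topological order.
    common = set.intersection(*(reachable(graph, u) for u in users))
    return min(common, key=topo_order.index) if common else None

# Diamond: a feeds b and c, which both feed d, then e
graph = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": ["e"], "e": []}
topo = ["a", "b", "c", "d", "e"]
assert lowest_common_descendant(graph, topo, ["b", "c"]) == "d"
```

Everything reachable from the users but outside the common set (plus the LCA itself) is exactly the region the pass fuses into a `TopoVisitorNode`, leaving the LCA as the fused node's single output.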
@ -1,6 +1,6 @@
 #################################################################################################
 #
-# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # Redistribution and use in source and binary forms, with or without
@ -30,40 +30,35 @@
 #
 #################################################################################################

-from cuda import cuda, cudart
+"""
+Fix the element_output of the producer of D.
+
+In the Sm90 epilogue visitor, the node writing D to gmem does not have an internal
+element converter, so the compute node producing D must have element_output = type(D).
+"""
+
+from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
+from cutlass.backend.evt.passes.pass_manager import EVTPassBase


-class GpuTimer:
-    def __init__(self) -> None:
-        self.events = [
-            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
-            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
-        ]
+class PassFixElementD(EVTPassBase):
+    """
+    In the Sm90 epilogue visitor, the node writing D to gmem does not have an
+    internal element converter, so the compute node producing D must have
+    element_output = type(D)
+    """
+    dependencies = [
+        PassLayoutManipulateElimination
+    ]

-    def start(self, stream=cuda.CUstream(0)):
-        (err,) = cuda.cuEventRecord(self.events[0], stream)
-        if err != cuda.CUresult.CUDA_SUCCESS:
-            raise RuntimeError("CUDA Error %s" % str(err))
+    def get_producer(self, node, element_D):
+        node_meta = self.dag_ir.get_node_meta(node)
+        if node_meta.op == "compute":
+            node_meta.element_output = element_D
+        elif node_meta.op == "store":
+            self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)

-    def stop(self, stream=cuda.CUstream(0)):
-        (err,) = cuda.cuEventRecord(self.events[1], stream)
-        if err != cuda.CUresult.CUDA_SUCCESS:
-            raise RuntimeError("CUDA Error %s" % str(err))
-
-    def stop_and_wait(self, stream=cuda.CUstream(0)):
-        self.stop(stream)
-        if stream:
-            (err,) = cuda.cuStreamSynchronize(stream)
-            if err != cuda.CUresult.CUDA_SUCCESS:
-                raise RuntimeError("CUDA Error %s" % str(err))
-        else:
-            (err,) = cudart.cudaDeviceSynchronize()
-            if err != cuda.CUresult.CUDA_SUCCESS:
-                raise RuntimeError("CUDA Error %s" % str(err))
-
-    def duration(self, iterations=1):
-        err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
-        if err != cuda.CUresult.CUDA_SUCCESS:
-            raise RuntimeError("CUDA Error %s" % str(err))
-        return duration / float(iterations)
+    def call(self):
+        if self.dag_ir.has_node("D"):
+            node_d_meta = self.dag_ir.get_node_meta("D")
+            element_D = node_d_meta.store_tensor.element
+            self.get_producer("D", element_D)
89 python/cutlass/backend/evt/passes/pass_get_impl.py Normal file
@ -0,0 +1,89 @@
"""
Infer the underlying implementation of each node.

While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have a different underlying implementation based
on its layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying implementation of each node.
"""

import cutlass.backend.evt.backend as evt_backend
from cutlass.backend.evt.ir import DAGIR, LoadNode
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation


class PassGetImpl(EVTPassBase):
|
||||
"""
|
||||
While the frontend only distinguish between Load/Store/Compute Node,
|
||||
each of these nodes can have different underlying implementation based
|
||||
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
||||
This pass infers the underlying impl of each node
|
||||
"""
|
||||
dependencies = [
|
||||
PassShapeTypePropagation, # The shape and type info are required for inference
|
||||
PassFixElementD
|
||||
]
|
||||
|
||||
def __init__(self, dag_ir: DAGIR) -> None:
|
||||
super().__init__(dag_ir)
|
||||
self.no_op_elimination = PassNoOpElimination(dag_ir)
|
||||
|
||||
def requires(self) -> None:
|
||||
# Verify "accum" is in the arg list
|
||||
if not self.dag_ir.has_node("accum"):
|
||||
raise SyntaxError("Cannot find 'accum' in the argument list.")
|
||||
|
||||
def call(self):
|
||||
# The loop structure of the epilogue is determined by the
|
||||
# accumulator shape
|
||||
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
|
||||
problem_size = accumulator.tensor.shape
|
||||
|
||||
for node_meta in self.dag_ir.node_metas_topological_order():
|
||||
node_meta.get_underlying_impl(problem_size)
|
||||
|
||||
def ensures(self) -> None:
|
||||
# Some nodes will be lowered to NoOp, eliminate them
|
||||
self.no_op_elimination()
|
||||
# Lower to cc-specific impl
|
||||
for node_meta in self.dag_ir.nodes_meta:
|
||||
node_impl_ccs = getattr(evt_backend, f"sm{self.cc}_nodes")
|
||||
node_meta.underlying_impl = getattr(
|
||||
node_impl_ccs,
|
||||
f"Sm{self.cc}" + node_meta.underlying_impl.__class__.__name__
|
||||
)(node_meta)
|
||||
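The `ensures()` step above lowers each generic node implementation to an architecture-specific class purely by name: it concatenates `f"Sm{cc}"` with the generic class name and looks the result up in the backend module. A minimal, self-contained sketch of that dispatch, where all class and module names are hypothetical stand-ins rather than the real `cutlass.backend.evt` types:

```python
# Hypothetical stand-ins for a generic node impl and its Sm90 lowering.
class AuxLoadImpl:
    """Generic (architecture-agnostic) implementation."""
    pass

class Sm90AuxLoadImpl(AuxLoadImpl):
    """Architecture-specific implementation; wraps the node metadata."""
    def __init__(self, node_meta):
        self.node_meta = node_meta

class FakeSm90Nodes:
    """Stands in for the per-cc node namespace (e.g. sm90_nodes)."""
    Sm90AuxLoadImpl = Sm90AuxLoadImpl

def lower_impl(backend_nodes, cc, underlying_impl, node_meta):
    # The same string-based lookup used in PassGetImpl.ensures()
    cls_name = f"Sm{cc}" + underlying_impl.__class__.__name__
    return getattr(backend_nodes, cls_name)(node_meta)

impl = lower_impl(FakeSm90Nodes, 90, AuxLoadImpl(), node_meta="meta")
assert isinstance(impl, Sm90AuxLoadImpl)
```

The design choice is that adding support for a new compute capability only requires defining classes that follow the `Sm{cc}{ImplName}` naming convention; no dispatch table needs updating.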
python/cutlass/backend/evt/passes/pass_layout_elimination.py (new file, 217 lines)
@ -0,0 +1,217 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
Eliminate layout manipulation nodes
"""

from copy import deepcopy

from cutlass.backend.evt.ir import DAGIR, LayoutNode
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation


class PassLayoutManipulateElimination(EVTPassBase):
    """
    Eliminate layout manipulation nodes
    """
    dependencies = [PassShapeTypePropagation]

    def __init__(self, dag_ir: DAGIR) -> None:
        super().__init__(dag_ir)
        self.copy_cnt = 0

    def call(self):
        self.layout_nodes_worklist = self.get_all_layout_nodes()
        # Loop until all layout nodes are eliminated
        while len(self.layout_nodes_worklist) > 0:
            node = self.layout_nodes_worklist.pop(0)
            # Step 1: get the propagation direction
            direction = self.get_propagation_direction(node)
            self.visited = []
            getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
            # Eliminate the current node
            input_node = self.dag_ir.get_all_inputs(node)[0]
            self.dag_ir.replace_all_uses_with(node, input_node)

    def get_all_layout_nodes(self):
        layout_nodes = []
        for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
            if isinstance(node_meta, LayoutNode):
                layout_nodes.append(node_meta.name)
        return layout_nodes

    def get_propagation_direction(self, node: str):
        """
        The logic propagates all layout nodes away from the accumulator node.
        """
        self.visited = []
        self.get_influenced_users(node)
        nodes_influenced_dir_users = self.visited
        self.visited = []
        self.get_influenced_inputs(node)
        nodes_influenced_dir_inputs = self.visited

        if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
            return "inputs"
        elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
            return "users"
        else:
            raise RuntimeError("Unsolved propagation direction")

    # Get all influenced nodes if we propagate along the user direction
    def get_influenced_users(self, node: str):
        if node in self.visited:
            return
        self.visited.append(node)

        users = self.dag_ir.get_users(node)
        for user in users:
            self.get_influenced_users(user)
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            for input in user_inputs:
                self.get_influenced_inputs(input)

    # Get all influenced nodes if we propagate along the input direction
    def get_influenced_inputs(self, node: str):
        if node in self.visited:
            return
        self.visited.append(node)

        inputs = self.dag_ir.get_all_inputs(node)
        for input in inputs:
            self.get_influenced_inputs(input)
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            for user in input_users:
                self.get_influenced_users(user)

    def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
        copied_node_meta = deepcopy(layout_node_meta)
        copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
        self.copy_cnt += 1
        copied_node_meta.name = copied_node
        self.dag_ir.add_node(copied_node_meta)
        # Add edges
        target_inputs = self.dag_ir.get_all_inputs(target)
        for src in target_inputs:
            self.dag_ir.remove_edge(src, target)
            self.dag_ir.add_edge(src, copied_node)
        self.dag_ir.add_edge(copied_node, target)
        self.layout_nodes_worklist.append(copied_node)

    def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
        copied_node_meta = deepcopy(layout_node_meta)
        copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
        self.copy_cnt += 1
        copied_node_meta.name = copied_node
        self.dag_ir.add_node(copied_node_meta)
        # Add edges
        users = self.dag_ir.get_users(target)
        for user in users:
            self.dag_ir.remove_edge(target, user)
            self.dag_ir.add_edge(copied_node, user)
        self.dag_ir.add_edge(target, copied_node)
        self.layout_nodes_worklist.append(copied_node)

    # Propagate the layout `node` along the user direction
    def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate layout node to users
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)

        node_meta = self.dag_ir.get_node_meta(node)
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # A layout node is not transparent to another layout node
                self.add_copy_before(layout_node_meta, node)
                return
            else:
                layout_node_meta.apply_to_user(node_meta)

        users = self.dag_ir.get_users(node)
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        for user in users:
            self.propagate_to_users(layout_node_meta, user)
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            for input in user_inputs:
                self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)

    # Propagate the layout `node` along the input direction
    def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate layout node to inputs
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)

        node_meta = self.dag_ir.get_node_meta(node)
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # A layout node is not transparent to another layout node
                self.add_copy_after(layout_node_meta, node)
                return
            else:
                layout_node_meta.apply_to_input(node_meta)
        inputs = self.dag_ir.get_all_inputs(node)
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        for input in inputs:
            self.propagate_to_inputs(layout_node_meta, input)
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            for user in input_users:
                self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
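The direction heuristic in `get_propagation_direction` asks which side of a layout node can reach the accumulator: the node is then pushed toward the other side. A simplified sketch on a toy adjacency structure (the real pass also hops across sibling edges via `get_influenced_inputs`/`get_influenced_users`; the dict-based graph here is a hypothetical stand-in, not the real `DAGIR` API):

```python
def reaches(graph, start, key):
    """Collect every node reachable from `start` along `graph[node][key]` edges."""
    visited = set()
    stack = [start]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        stack.extend(graph[node][key])
    return visited

# accum -> permute -> D, with C feeding the permute as a second input
graph = {
    "accum":   {"users": ["permute"], "inputs": []},
    "C":       {"users": ["permute"], "inputs": []},
    "permute": {"users": ["D"], "inputs": ["accum", "C"]},
    "D":       {"users": [], "inputs": ["permute"]},
}

via_inputs = reaches(graph, "permute", "inputs")
via_users = reaches(graph, "permute", "users")

# "accum" sits upstream of the permute, so the layout node is pushed
# toward its users (away from the accumulator).
direction = "users" if "accum" in via_inputs else "inputs"
assert direction == "users"
```

Keeping the accumulator's layout fixed is what lets the epilogue loop structure stay tied to the accumulator shape while permutes and reshapes are absorbed into the surrounding loads and stores.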
python/cutlass/backend/evt/passes/pass_manager.py (new file, 163 lines)
@ -0,0 +1,163 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
Pass manager for DAG IR.
"""

from typing import Any

import networkx as nx

from cutlass.backend.evt.ir import DAGIR


class EVTPassBase:
    """
    Base class for EVT passes
    """
    dependencies = []

    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    def requires(self) -> None:
        """
        This function will be called before the pass is run.
        """
        pass

    def call(self) -> None:
        """
        The pass that is run through the self.dag_ir
        """
        raise NotImplementedError(
            f"call is not overwritten in Pass {self.__class__.__name__}")

    def ensures(self) -> None:
        """
        This function will be called after the pass is run.
        """
        pass

    def __call__(self) -> Any:
        self.requires()
        self.call()
        self.ensures()

    def cc_specific_method(self, func):
        """
        This enables defining functions that behave differently under different compute
        capabilities (cc). The simplest example of using this function is the following:

        .. highlight:: python
        .. code-block:: python

            class ExamplePass(EVTPassBase):

                def call(self):
                    # This automatically selects the smXX_func based on the current cc
                    self.cc_specific_method(self.func)()

                # Interface func, can be empty
                def func(self):
                    pass

                # Sm90-specific func
                def sm90_func(self):
                    # sm90-specific method
                    return

                # Sm80-specific func
                def sm80_func(self):
                    # sm80-specific method
                    return
        """
        func_name = f"sm{self.cc}_{func.__name__}"
        if hasattr(self, func_name):
            return getattr(self, func_name)
        else:
            raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")


class EVTPassManager(nx.DiGraph):
    """
    Topological-order-based pass manager.
    Each registered pass has a list of dependencies. The pass manager organizes
    the passes as a DAG and launches the compiler passes in topological order.
    """
    def __init__(self, dag_ir: DAGIR, pass_list):
        super().__init__()
        self.dag_ir = dag_ir
        for pass_cls in pass_list:
            self.add_pass(pass_cls)

        self.sorted_passes = self.schedule()

    def get_callable(self, pass_name):
        """
        Return the callable of the pass
        """
        return self.nodes[pass_name]["callable"]

    def add_pass(self, pass_cls):
        """
        Add a pass to the pass manager

        :param pass_cls: the class of the pass
        :type pass_cls: derived class of EVTPassBase
        """
        name = pass_cls.__name__
        pass_callable = pass_cls(self.dag_ir)
        self.add_node(name, callable=pass_callable)

    def schedule(self):
        """
        Schedule the added passes in topological order
        """
        # Add edges
        for pass_name in self.nodes:
            callable = self.get_callable(pass_name)
            for dependency_cls in callable.dependencies:
                self.add_edge(
                    dependency_cls.__name__,
                    type(callable).__name__)

        # Topological sort
        return list(nx.topological_sort(self))

    def __call__(self) -> Any:
        """
        Launch the registered passes
        """
        for pass_name in self.sorted_passes:
            callable = self.get_callable(pass_name)
            callable()
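The scheduling idea above can be reproduced with the standard library alone. This sketch mirrors `EVTPassManager.schedule()` using `graphlib.TopologicalSorter` instead of `networkx`; the pass classes are trivial placeholders introduced only for the example:

```python
from graphlib import TopologicalSorter

# Placeholder passes that mirror the `dependencies` class-attribute convention.
class PassA:
    dependencies = []

class PassB:
    dependencies = [PassA]

class PassC:
    dependencies = [PassA, PassB]

def schedule(pass_list):
    # Map each pass name to the set of names it depends on, then sort.
    graph = {p.__name__: {d.__name__ for d in p.dependencies} for p in pass_list}
    return list(TopologicalSorter(graph).static_order())

order = schedule([PassC, PassB, PassA])
assert order.index("PassA") < order.index("PassB") < order.index("PassC")
```

Because dependencies are declared per pass rather than as one global ordering, new passes can be inserted without editing a central pipeline list; any ordering consistent with the DAG is valid.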
python/cutlass/backend/evt/passes/pass_no_op_elimination.py (new file, 53 lines)
@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
No-op elimination pass
"""

from typing import Any

from cutlass.backend.evt.ir import NoOpImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase


class PassNoOpElimination(EVTPassBase):
    """
    The no-op elimination pass removes nodes with NoOpImpl from the DAG IR
    """
    dependencies = []

    def call(self) -> Any:
        for node in self.dag_ir.nodes_topological_order():
            node_meta = self.dag_ir.get_node_meta(node)
            if isinstance(node_meta.underlying_impl, NoOpImpl):
                self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])
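The rewiring performed by `replace_all_uses_with` can be illustrated on a toy edge list (a hypothetical IR representation, not the real `DAGIR` API): every out-edge of the eliminated no-op is re-sourced to its single input, and the edge into the no-op is dropped:

```python
def replace_all_uses_with(edges, node, replacement):
    """Rewire a list of (src, dst) edges so that `node` disappears."""
    new_edges = []
    for src, dst in edges:
        if src == node:
            # Users of the no-op now consume its replacement instead
            new_edges.append((replacement, dst))
        elif dst == node:
            # Drop the edge feeding the eliminated node
            continue
        else:
            new_edges.append((src, dst))
    return new_edges

edges = [("accum", "noop"), ("noop", "D")]
result = replace_all_uses_with(edges, "noop", "accum")
assert result == [("accum", "D")]
```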
python/cutlass/backend/evt/passes/pass_preprocess_red.py (new file, 98 lines)
@ -0,0 +1,98 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
Preprocess the reduction nodes.

The parser treats a reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store().
This pass fuses these into a single store node, and then replaces all uses of the
current node with the new store node.
"""

from cutlass.backend.evt.ir import ComputeNode, StoreNode
from cutlass.backend.evt.passes.pass_manager import EVTPassBase


class PassPreprocessRed(EVTPassBase):
    """
    Preprocess reduction nodes
    """

    def call(self):
        # Step 1: find the compute nodes with op=red
        red_compute_nodes = []
        for node_meta in self.dag_ir.nodes_meta:
            if isinstance(node_meta, ComputeNode):
                if isinstance(node_meta.fn, tuple):
                    # To keep the frontend simple, the reduction nodes
                    # are parsed into compute nodes by default.
                    # The simple heuristic to distinguish between compute
                    # and reduction nodes is that a compute node holds a single
                    # function, while a reduction node holds a tuple of functions
                    # for in-register reduction and atomic global memory reduction.
                    red_compute_nodes.append(node_meta.name)

        # Step 2: for each such compute node, merge it with the succeeding store
        for node in red_compute_nodes:
            # Verify
            users = self.dag_ir.get_users(node)
            inputs = self.dag_ir.get_all_inputs(node)
            # Has a single user
            assert len(users) == 1
            assert len(inputs) == 1
            user = users[0]
            input = inputs[0]

            user_meta = self.dag_ir.get_node_meta(user)
            # Must be a store node
            assert isinstance(user_meta, StoreNode)
            # With output degree == 0
            assert self.dag_ir.out_degree(user) == 0
            # Register the reduce op
            node_meta = self.dag_ir.get_node_meta(node)
            user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
            user_meta.element_compute = node_meta.element_compute
            user_meta.round_style = node_meta.round_style

            # Replace all uses
            self.dag_ir.remove_edge(input, node)
            input_users = self.dag_ir.get_users(input)
            for iu in input_users:
                weight = self.dag_ir.get_edge_weight(input, iu)
                self.dag_ir.add_edge(user, iu, weight)
                self.dag_ir.remove_edge(input, iu)
            self.dag_ir.add_edge(input, user)
            self.dag_ir.remove_node(node)

            # Register the reduction name
            self.dag_ir.reduction_names.append(user)
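The tuple-vs-function heuristic described in the comments above is easy to state in isolation. A minimal sketch using a hypothetical stand-in node class (not the real `ComputeNode` IR type):

```python
class FakeComputeNode:
    """Hypothetical stand-in: holds a name and the parsed `fn` attribute."""
    def __init__(self, name, fn):
        self.name = name
        self.fn = fn

nodes = [
    FakeComputeNode("relu", max),                # ordinary compute: one function
    FakeComputeNode("row_reduce", (sum, sum)),   # reduction: (reg_fn, gmem_fn) tuple
]

# The pass's detection step: a tuple-valued fn marks a reduction node.
red_nodes = [n.name for n in nodes if isinstance(n.fn, tuple)]
assert red_nodes == ["row_reduce"]
```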
@ -0,0 +1,59 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
Shape and type propagation pass
"""

from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed


class PassShapeTypePropagation(EVTPassBase):
    """
    Propagate the shape and type of all nodes
    """
    dependencies = [PassPreprocessRed]

    def call(self):
        # Propagate the node shape and type
        for node in self.dag_ir.nodes_topological_order():
            node_meta: NodeBase = self.dag_ir.get_node_meta(node)
            input_node_metas = self.dag_ir.get_all_inputs_meta(node)
            node_meta.type_propagation(input_node_metas)
            node_meta.shape_propagation(input_node_metas)

        for node in reversed(self.dag_ir.nodes_topological_order()):
            node_meta: NodeBase = self.dag_ir.get_node_meta(node)
            input_node_metas = self.dag_ir.get_all_inputs_meta(node)
            node_meta.broadcast_propagation(input_node_metas)
python/cutlass/backend/evt/passes/smem_size_calculator.py (new file, 200 lines)
@ -0,0 +1,200 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################

"""
Compute the shared memory size in bytes
"""

from pycute import shape_div, product

import cutlass
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass.backend.library import DataTypeSize


class GetSmemSize:
    """
    Get the size in bytes of the shared memory used by the kernel
    """
    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    #
    # Sm90 epilogue specific
    #

    def sm90_epilogue_tile(self, tile_description):
        # Get the epilogue tile size
        schedule = tile_description.epilogue_schedule
        if schedule == cutlass.EpilogueScheduleType.TmaWarpSpecialized:
            epilogue_tile_mn = (64, 32)
        elif schedule == cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative:
            epilogue_tile_mn = (128, 32)
        else:
            raise NotImplementedError(f"Unsupported schedule: {schedule}")

        # Get the pipeline stages
        stages_d = 2
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
        else:
            element_c = None

        element_d = self.dag_ir.get_node_meta("D").element
        reuse_smem_c = (element_c == element_d)
        stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles

        # Record the epilogue tile
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = element_c is not None

    def sm90_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of the sm90 collective epilogue
        """
        self.sm90_epilogue_tile(tile_description)
        # Get the Fusion Storage
        nodes = self.dag_ir.nodes_topological_order()
        self.smem_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.smem_types[node] = meta.underlying_impl.get_smem_size(
                    self.cta_tile_mnk, self.epilogue_tile_mn,
                    self.stages_c, self.stages_d, self.epi_tiles)
            if node == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_smem_type(node)
            else:
                self.get_evt_smem_type(node)

        thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
        # Get the Tensor Storage
        tensors = []
        if self.is_source_supported:
            smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
            tensors.append((smem_C, 128))
        else:
            tensors.append((0, 1))
        if self.reuse_smem_c:
            tensors.append((0, 128))
        else:
            smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
            tensors.append((smem_D, 128))
        tensors.append((thread_smem_size, 128))

        tensor_smem_size = self.get_struct_size(tensors)
        # Get the pipeline storage size
        # sizeof(uint64_t) * stages_c * 2, alignment of uint64_t
        # 2 is for FullBarrier and EmptyBarrier
        pipeline_smem_size = (8 * self.stages_c * 2, 8)

        # Get the SharedStorage size
        smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
        return smem_size[0]

    def __call__(self, tile_description):
        return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)

    #
    # Helper functions
    #

    @staticmethod
    def get_visitor_size(members: list, ebo: bool):
        """
        Get the size of a struct in bytes
        """
        offset = 0
        max_alignment = 1
        if len(members) > 0:
            # Get the max alignment
            for _, alignment in members:
                max_alignment = max(max_alignment, alignment)

            for type_size, _ in members:
                if type_size != 0:
                    offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
                if type_size == 0 and not ebo:
                    offset += 1
                else:
                    offset += type_size
            offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
            return (offset, max_alignment)
        else:
            # Struct size is at least 1
            return (1, 1)

    def get_struct_size(self, members: list):
        """
        Get the size of a struct in bytes
        """
        return self.get_visitor_size(members, False)

    def get_evt_smem_type(self, node):
        # Sort the input nodes by edge weight
        input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
        input_types.append(self.smem_types[node])
        if len(input_types) > 1:
            ebo = len(input_types) > 4
|
||||
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
||||
|
||||
def get_dag_smem_type(self, node):
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
subgraph = meta.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Visit the unvisited nodes in subgraph
|
||||
for n in subgraph_nodes:
|
||||
m = subgraph.get_node_meta(n)
|
||||
if m.disabled:
|
||||
continue
|
||||
else:
|
||||
self.smem_types[n] = m.underlying_impl.get_smem_size(
|
||||
self.cta_tile_mnk, self.epilogue_tile_mn,
|
||||
self.stages_c, self.stages_d, self.epi_tiles)
|
||||
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
|
||||
if len(input_types) > 0:
|
||||
ebo = len(input_types) > 4
|
||||
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
||||
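The `get_visitor_size` helper above mirrors C++ struct-layout rules: each non-empty member is placed at an offset padded to the struct's maximum alignment, an empty member still occupies one byte unless empty-base optimization (EBO) applies, and the total size is rounded up to the alignment. A standalone sketch of the same arithmetic (the `struct_size` name and the member lists below are illustrative, not from the source):

```python
def struct_size(members, ebo=False):
    """Compute (size, alignment) of a struct whose members are (size, alignment)
    pairs, mimicking C++ padding rules as in get_visitor_size above."""
    if not members:
        return (1, 1)  # a C++ struct is at least one byte
    max_alignment = max(align for _, align in members)
    offset = 0
    for size, _ in members:
        if size != 0:
            # pad the running offset up to the struct's max alignment
            offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
        if size == 0 and not ebo:
            offset += 1  # empty member still occupies storage without EBO
        else:
            offset += size
    # round the total size up to the alignment
    offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
    return (offset, max_alignment)
```

For example, a struct holding a 4-byte member followed by an 8-byte member with 8-byte alignment packs to 16 bytes, exactly as `struct { int a; double b; };` would in C++.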
@@ -36,10 +36,12 @@ import numpy as np
 from cutlass.backend.memory_manager import device_mem_alloc, todevice
 from cutlass.backend.utils.software import CheckPackages
 
-if CheckPackages().check_torch():
+torch_available = CheckPackages().check_torch()
+if torch_available:
     import torch
 
-if CheckPackages().check_cupy():
+cupy_available = CheckPackages().check_cupy()
+if cupy_available:
     import cupy as cp
 
 
@@ -94,3 +96,19 @@ class CupyFrontend:
     @staticmethod
     def argument(cupy_ndarray: "cp.ndarray"):
         return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))
+
+
+class TensorFrontend:
+    """
+    Universal frontend for client-provided tensors
+    """
+
+    @staticmethod
+    def argument(tensor, is_output=False):
+        if isinstance(tensor, np.ndarray):
+            return NumpyFrontend.argument(tensor, is_output)
+        elif torch_available and isinstance(tensor, torch.Tensor):
+            return TorchFrontend.argument(tensor)
+        elif cupy_available and isinstance(tensor, cp.ndarray):
+            return CupyFrontend.argument(tensor)
+        else:
+            raise NotImplementedError("Unknown Tensor Type")
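The `TensorFrontend.argument` dispatcher added here routes a client-provided tensor to the NumPy, PyTorch, or CuPy backend based on its runtime type, guarded by import-time availability flags. A minimal standalone sketch of the same pattern (the `tensor_pointer` name and the pointer extraction shown are illustrative, not the CUTLASS API):

```python
import numpy as np

# Availability flag computed once at import time, as in the diff above.
try:
    import torch
    torch_available = True
except ImportError:
    torch_available = False


def tensor_pointer(tensor):
    """Dispatch on the tensor's runtime type and return a raw pointer value."""
    if isinstance(tensor, np.ndarray):
        return tensor.ctypes.data        # host pointer from NumPy
    elif torch_available and isinstance(tensor, torch.Tensor):
        return tensor.data_ptr()         # device/host pointer from PyTorch
    else:
        raise NotImplementedError("Unknown tensor type")
```

Checking the availability flag before the `isinstance` test keeps the dispatcher importable even when the optional framework is not installed.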
File diff suppressed because it is too large
@@ -31,14 +31,21 @@
 #################################################################################################
 
 """
-Common data types and string names for them. This file is similar to /tools/library/scripts/library.py,
-but uses the Pybind-bound CUTLASS data types as many keys to the dictionary.
+Common data types and string names/tags for them
 """
 
 import enum
 
-import cutlass_bindings
-from cutlass import EpilogueScheduleType, KernelScheduleType, TileSchedulerType
+from cutlass import (
+    ComplexTransform,
+    DataType,
+    DataTypeSize,
+    EpilogueScheduleType,
+    KernelScheduleType,
+    MathOperation,
+    OpcodeClass,
+    TileSchedulerType
+)
 
 
 # The following block implements enum.auto() for Python 3.5 variants that don't include it such
@@ -58,121 +65,6 @@ except ImportError:
             return i
 
 
-ShortDataTypeNames = {
-    cutlass_bindings.int32: "i",
-    cutlass_bindings.float16: "h",
-    cutlass_bindings.float32: "s",
-    cutlass_bindings.float64: "d",
-    cutlass_bindings.dtype.cf32: "c",
-    cutlass_bindings.dtype.cf64: "z",
-}
-
-
-DataTypeNames = {
-    cutlass_bindings.dtype.b1: "b1",
-    cutlass_bindings.dtype.u4: "u4",
-    cutlass_bindings.dtype.u8: "u8",
-    cutlass_bindings.dtype.u16: "u16",
-    cutlass_bindings.dtype.u32: "u32",
-    cutlass_bindings.dtype.u64: "u64",
-    cutlass_bindings.dtype.s4: "s4",
-    cutlass_bindings.int8: "s8",
-    cutlass_bindings.dtype.s16: "s16",
-    cutlass_bindings.int32: "s32",
-    cutlass_bindings.dtype.s64: "s64",
-    cutlass_bindings.float16: "f16",
-    cutlass_bindings.bfloat16: "bf16",
-    cutlass_bindings.float32: "f32",
-    cutlass_bindings.tfloat32: "tf32",
-    cutlass_bindings.float64: "f64",
-    cutlass_bindings.dtype.cf16: "cf16",
-    cutlass_bindings.dtype.cbf16: "cbf16",
-    cutlass_bindings.dtype.cf32: "cf32",
-    cutlass_bindings.dtype.ctf32: "ctf32",
-    cutlass_bindings.dtype.cf64: "cf64",
-    cutlass_bindings.dtype.cu4: "cu4",
-    cutlass_bindings.dtype.cu8: "cu8",
-    cutlass_bindings.dtype.cu16: "cu16",
-    cutlass_bindings.dtype.cu32: "cu32",
-    cutlass_bindings.dtype.cu64: "cu64",
-    cutlass_bindings.dtype.cs4: "cs4",
-    cutlass_bindings.dtype.cs8: "cs8",
-    cutlass_bindings.dtype.cs16: "cs16",
-    cutlass_bindings.dtype.cs32: "cs32",
-    cutlass_bindings.dtype.cs64: "cs64",
-}
-
-
-DataTypeTag = {
-    cutlass_bindings.dtype.b1: "cutlass::uint1b_t",
-    cutlass_bindings.dtype.u4: "cutlass::uint4b_t",
-    cutlass_bindings.dtype.u8: "uint8_t",
-    cutlass_bindings.dtype.u16: "uint16_t",
-    cutlass_bindings.dtype.u32: "uint32_t",
-    cutlass_bindings.dtype.u64: "uint64_t",
-    cutlass_bindings.dtype.s4: "cutlass::int4b_t",
-    cutlass_bindings.int8: "int8_t",
-    cutlass_bindings.dtype.s16: "int16_t",
-    cutlass_bindings.int32: "int32_t",
-    cutlass_bindings.dtype.s64: "int64_t",
-    cutlass_bindings.float16: "cutlass::half_t",
-    cutlass_bindings.bfloat16: "cutlass::bfloat16_t",
-    cutlass_bindings.float32: "float",
-    cutlass_bindings.tfloat32: "cutlass::tfloat32_t",
-    cutlass_bindings.float64: "double",
-    cutlass_bindings.dtype.cf16: "cutlass::complex<cutlass::half_t>",
-    cutlass_bindings.dtype.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
-    cutlass_bindings.dtype.cf32: "cutlass::complex<float>",
-    cutlass_bindings.dtype.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
-    cutlass_bindings.dtype.cf64: "cutlass::complex<double>",
-    cutlass_bindings.dtype.cu4: "cutlass::complex<cutlass::uint4b_t>",
-    cutlass_bindings.dtype.cu8: "cutlass::complex<cutlass::uint8_t>",
-    cutlass_bindings.dtype.cu16: "cutlass::complex<cutlass::uint16_t>",
-    cutlass_bindings.dtype.cu32: "cutlass::complex<cutlass::uint32_t>",
-    cutlass_bindings.dtype.cu64: "cutlass::complex<cutlass::uint64_t>",
-    cutlass_bindings.dtype.cs4: "cutlass::complex<cutlass::int4b_t>",
-    cutlass_bindings.dtype.cs8: "cutlass::complex<cutlass::int8_t>",
-    cutlass_bindings.dtype.cs16: "cutlass::complex<cutlass::int16_t>",
-    cutlass_bindings.dtype.cs32: "cutlass::complex<cutlass::int32_t>",
-    cutlass_bindings.dtype.cs64: "cutlass::complex<cutlass::int64_t>",
-}
-
-
-DataTypeSize = {
-    cutlass_bindings.dtype.b1: 1,
-    cutlass_bindings.dtype.u4: 4,
-    cutlass_bindings.dtype.u8: 8,
-    cutlass_bindings.dtype.u16: 16,
-    cutlass_bindings.dtype.u32: 32,
-    cutlass_bindings.dtype.u64: 64,
-    cutlass_bindings.dtype.s4: 4,
-    cutlass_bindings.int8: 8,
-    cutlass_bindings.dtype.s16: 16,
-    cutlass_bindings.int32: 32,
-    cutlass_bindings.dtype.s64: 64,
-    cutlass_bindings.float16: 16,
-    cutlass_bindings.bfloat16: 16,
-    cutlass_bindings.float32: 32,
-    cutlass_bindings.tfloat32: 32,
-    cutlass_bindings.float64: 64,
-    cutlass_bindings.dtype.cf16: 32,
-    cutlass_bindings.dtype.cbf16: 32,
-    cutlass_bindings.dtype.cf32: 64,
-    cutlass_bindings.dtype.ctf32: 32,
-    cutlass_bindings.dtype.cf64: 128,
-    cutlass_bindings.dtype.cu4: 8,
-    cutlass_bindings.dtype.cu8: 16,
-    cutlass_bindings.dtype.cu16: 32,
-    cutlass_bindings.dtype.cu32: 64,
-    cutlass_bindings.dtype.cu64: 128,
-    cutlass_bindings.dtype.cs4: 8,
-    cutlass_bindings.dtype.cs8: 16,
-    cutlass_bindings.dtype.cs16: 32,
-    cutlass_bindings.dtype.cs32: 64,
-    cutlass_bindings.dtype.cs64: 128,
-}
-
-
 class DataTypeSizeBytes:
     """
     Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
@@ -193,193 +85,15 @@ class DataTypeSizeBytes:
         bits = DataTypeSize[datatype]
         if bits < 8:
             raise Exception(
-                "Data type {} is less than one byte in size.".format(datatype)
+                f"Data type {datatype} is less than one byte in size."
             )
         elif bits % 8 != 0:
             raise Exception(
-                "Data type {} is not an integer number of bytes.".format(datatype)
+                f"Data type {datatype} is not an integer number of bytes."
             )
         return bits // 8
 
 
-ComplexTransformTag = {
-    cutlass_bindings.complex_transform.none: "cutlass::ComplexTransform::kNone",
-    cutlass_bindings.complex_transform.conj: "cutlass::ComplexTransform::kConjugate",
-}
-
-
-RealComplexBijection = [
-    (cutlass_bindings.float16, cutlass_bindings.dtype.cf16),
-    (cutlass_bindings.float32, cutlass_bindings.dtype.cf32),
-    (cutlass_bindings.float64, cutlass_bindings.dtype.cf64),
-]
-
-
-def is_complex(data_type):
-    for r, c in RealComplexBijection:
-        if data_type == c:
-            return True
-    return False
-
-
-def get_complex_from_real(real_type):
-    for r, c in RealComplexBijection:
-        if real_type == r:
-            return c
-    return cutlass_bindings.dtype.invalid
-
-
-def get_real_from_complex(complex_type):
-    for r, c in RealComplexBijection:
-        if complex_type == c:
-            return r
-    return cutlass_bindings.dtype.invalid
-
-
-class ComplexMultiplyOp(enum.Enum):
-    multiply_add = enum_auto()
-    gaussian = enum_auto()
-
-
-class MathOperation(enum.Enum):
-    multiply_add = enum_auto()
-    multiply_add_saturate = enum_auto()
-    xor_popc = enum_auto()
-    multiply_add_fast_bf16 = enum_auto()
-    multiply_add_fast_f16 = enum_auto()
-    multiply_add_fast_f32 = enum_auto()
-    multiply_add_complex_fast_f32 = enum_auto()
-    multiply_add_complex = enum_auto()
-    multiply_add_complex_gaussian = enum_auto()
-
-
-MathOperationNames = {
-    MathOperation.multiply_add: "multiply_add",
-    MathOperation.multiply_add_saturate: "multiply_add_saturate",
-    MathOperation.xor_popc: "xor_popc",
-    MathOperation.multiply_add_fast_bf16: "multiply_add_fast_bf16",
-    MathOperation.multiply_add_fast_f16: "multiply_add_fast_f16",
-    MathOperation.multiply_add_fast_f32: "multiply_add_fast_f32",
-    MathOperation.multiply_add_complex_fast_f32: "multiply_add_complex_fast_f32",
-    MathOperation.multiply_add_complex: "multiply_add_complex",
-    MathOperation.multiply_add_complex_gaussian: "multiply_add_complex_gaussian",
-}
-
-
-MathOperationTag = {
-    MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
-    MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
-    MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
-    MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
-    MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
-    MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32",
-    MathOperation.multiply_add_complex_fast_f32: "cutlass::arch::OpMultiplyAddComplexFastF32",
-    MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
-    MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
-}
-
-
-LayoutTag = {
-    cutlass_bindings.ColumnMajor: "cutlass::layout::ColumnMajor",
-    cutlass_bindings.RowMajor: "cutlass::layout::RowMajor",
-    cutlass_bindings.layout.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
-    cutlass_bindings.layout.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
-    cutlass_bindings.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
-    cutlass_bindings.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
-    cutlass_bindings.layout.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
-    cutlass_bindings.layout.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
-    cutlass_bindings.TensorNHWC: "cutlass::layout::TensorNHWC",
-    cutlass_bindings.layout.TensorNDHWC: "cutlass::layout::TensorNDHWC",
-    cutlass_bindings.layout.TensorNCHW: "cutlass::layout::TensorNCHW",
-    cutlass_bindings.layout.TensorNGHWC: "cutlass::layout::TensorNGHWC",
-    cutlass_bindings.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
-    cutlass_bindings.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
-    cutlass_bindings.layout.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
-    cutlass_bindings.layout.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
-}
-
-
-TransposedLayout = {
-    cutlass_bindings.ColumnMajor: cutlass_bindings.RowMajor,
-    cutlass_bindings.RowMajor: cutlass_bindings.ColumnMajor,
-    cutlass_bindings.layout.ColumnMajorInterleaved2: cutlass_bindings.layout.RowMajorInterleaved2,
-    cutlass_bindings.layout.RowMajorInterleaved2: cutlass_bindings.layout.ColumnMajorInterleaved2,
-    cutlass_bindings.ColumnMajorInterleaved32: cutlass_bindings.RowMajorInterleaved32,
-    cutlass_bindings.RowMajorInterleaved32: cutlass_bindings.ColumnMajorInterleaved32,
-    cutlass_bindings.layout.ColumnMajorInterleaved64: cutlass_bindings.layout.RowMajorInterleaved64,
-    cutlass_bindings.layout.RowMajorInterleaved64: cutlass_bindings.layout.ColumnMajorInterleaved64,
-    cutlass_bindings.TensorNHWC: cutlass_bindings.TensorNHWC,
-}
-
-
-ShortLayoutTypeNames = {
-    cutlass_bindings.ColumnMajor: "n",
-    cutlass_bindings.layout.ColumnMajorInterleaved2: "n2",
-    cutlass_bindings.ColumnMajorInterleaved32: "n32",
-    cutlass_bindings.layout.ColumnMajorInterleaved64: "n64",
-    cutlass_bindings.RowMajor: "t",
-    cutlass_bindings.layout.RowMajorInterleaved2: "t2",
-    cutlass_bindings.RowMajorInterleaved32: "t32",
-    cutlass_bindings.layout.RowMajorInterleaved64: "t64",
-    cutlass_bindings.TensorNHWC: "nhwc",
-    cutlass_bindings.layout.TensorNDHWC: "ndhwc",
-    cutlass_bindings.layout.TensorNCHW: "nchw",
-    cutlass_bindings.layout.TensorNGHWC: "nghwc",
-    cutlass_bindings.TensorNC32HW32: "nc32hw32",
-    cutlass_bindings.layout.TensorNC64HW64: "nc64hw64",
-    cutlass_bindings.TensorC32RSK32: "c32rsk32",
-    cutlass_bindings.layout.TensorC64RSK64: "c64rsk64",
-}
-
-
-ShortComplexLayoutNames = {
-    (cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.none): "n",
-    (cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.conj): "c",
-    (cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.none): "t",
-    (cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.conj): "h",
-}
-
-
-OpcodeClassNames = {
-    cutlass_bindings.OpClass.Simt: "simt",
-    cutlass_bindings.OpClass.TensorOp: "tensorop",
-    cutlass_bindings.OpClass.WmmaTensorOp: "wmma_tensorop",
-    cutlass_bindings.OpClass.SparseTensorOp: "sptensorop",
-}
-
-
-OpcodeClassTag = {
-    cutlass_bindings.OpClass.Simt: "cutlass::arch::OpClassSimt",
-    cutlass_bindings.OpClass.TensorOp: "cutlass::arch::OpClassTensorOp",
-    cutlass_bindings.OpClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
-    cutlass_bindings.OpClass.SparseTensorOp: "cutlass::arch::OpClassSparseTensorOp",
-}
-
-
-class OperationKind(enum.Enum):
-    Gemm = enum_auto()
-    Conv2d = enum_auto()
-    Conv3d = enum_auto()
-
-
-OperationKindNames = {
-    OperationKind.Gemm: "gemm",
-    OperationKind.Conv2d: "conv2d",
-    OperationKind.Conv3d: "conv3d",
-}
-
-
-ArchitectureNames = {
-    50: "maxwell",
-    60: "pascal",
-    61: "pascal",
-    70: "volta",
-    75: "turing",
-    80: "ampere",
-    90: "hopper",
-}
-
-
 SharedMemPerCC = {
     70: 96 << 10,  # 96KB of SMEM
     72: 96 << 10,  # 96KB of SMEM
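The `DataTypeSizeBytes` check above converts a bit width to a byte count, rejecting sub-byte types and widths that are not whole multiples of a byte. A standalone sketch of that conversion (the `size_in_bytes` name is illustrative, and `ValueError` stands in for the bare `Exception` used in the source):

```python
def size_in_bytes(bits: int) -> int:
    """Convert a bit width to bytes, rejecting widths that do not fill whole bytes."""
    if bits < 8:
        raise ValueError(f"Width {bits} is less than one byte in size.")
    elif bits % 8 != 0:
        raise ValueError(f"Width {bits} is not an integer number of bytes.")
    return bits // 8
```

So a 32-bit type maps to 4 bytes, while 4-bit (`s4`) or 1-bit (`b1`) types raise, which is why callers that might handle sub-byte types must keep working in bits.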
@@ -392,52 +106,8 @@ SharedMemPerCC = {
 }
 
 
-class GemmKind(enum.Enum):
-    Gemm = enum_auto()
-    Sparse = enum_auto()
-    Universal = enum_auto()
-    PlanarComplex = enum_auto()
-    PlanarComplexArray = enum_auto()
-    Grouped = enum_auto()
-
-
-GemmKindNames = {
-    GemmKind.Gemm: "gemm",
-    GemmKind.Sparse: "spgemm",
-    GemmKind.Universal: "gemm",
-    GemmKind.PlanarComplex: "gemm_planar_complex",
-    GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
-    GemmKind.Grouped: "gemm_grouped",
-}
-
-
-class SwizzlingFunctor(enum.Enum):
-    Identity1 = enum_auto()
-    Identity2 = enum_auto()
-    Identity4 = enum_auto()
-    Identity8 = enum_auto()
-    Horizontal = enum_auto()
-    BatchedIdentity1 = enum_auto()
-    StridedDgradIdentity1 = enum_auto()
-    StridedDgradIdentity4 = enum_auto()
-    StridedDgradHorizontal = enum_auto()
-
-
-SwizzlingFunctorTag = {
-    cutlass_bindings.IdentitySwizzle1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
-    SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
-    SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
-    SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
-    SwizzlingFunctor.Horizontal: "cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle",
-    SwizzlingFunctor.BatchedIdentity1: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
-    SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
-    SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>",
-    SwizzlingFunctor.StridedDgradHorizontal: "cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle",
-}
-
-
 class SchedulerMode(enum.Enum):
-    Device = (enum_auto(),)
+    Device = enum_auto()
     Host = enum_auto()
@@ -450,61 +120,98 @@ SchedulerModeTag = {
 ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
 
 
-ConvKindTag = {
-    cutlass_bindings.conv.Operator.fprop: "cutlass::conv::Operator::kFprop",
-    cutlass_bindings.conv.Operator.dgrad: "cutlass::conv::Operator::kDgrad",
-    cutlass_bindings.conv.Operator.wgrad: "cutlass::conv::Operator::kWgrad",
-}
+class FunctionalOp(enum.Enum):
+    AtomicAdd = enum_auto()
+    AtomicMaximum = enum_auto()
+    Divides = enum_auto()
+    Maximum = enum_auto()
+    Minimum = enum_auto()
+    Minus = enum_auto()
+    Multiplies = enum_auto()
+    MultiplyAdd = enum_auto()
+    Plus = enum_auto()
+
+
+FunctionalOpTag = {
+    FunctionalOp.AtomicAdd: "cutlass::atomic_add",
+    FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
+    FunctionalOp.Divides: "cutlass::divides",
+    FunctionalOp.Maximum: "cutlass::maximum",
+    FunctionalOp.Minimum: "cutlass::minimum",
+    FunctionalOp.Minus: "cutlass::minus",
+    FunctionalOp.Multiplies: "cutlass::multiplies",
+    FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
+    FunctionalOp.Plus: "cutlass::plus",
+}
 
 
-ConvKindNames = {
-    cutlass_bindings.conv.Operator.fprop: "fprop",
-    cutlass_bindings.conv.Operator.dgrad: "dgrad",
-    cutlass_bindings.conv.Operator.wgrad: "wgrad",
-}
+class ActivationOp(enum.Enum):
+    DGelu = enum_auto()
+    Gelu = enum_auto()
+    GeluTaylor = enum_auto()
+    HardSwish = enum_auto()
+    Identity = enum_auto()
+    LeakyReLU = enum_auto()
+    ReLU = enum_auto()
+    Sigmoid = enum_auto()
+    SiLU = enum_auto()
+    Tanh = enum_auto()
+
+
+ActivationOpTag = {
+    ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
+    ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
+    ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
+    ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
+    ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
+    ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
+    ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
+    ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
+    ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
+    ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
+}
 
 
-IteratorAlgorithmTag = {
-    cutlass_bindings.conv.IteratorAlgorithm.analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
-    cutlass_bindings.conv.IteratorAlgorithm.optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
-    cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "cutlass::conv::IteratorAlgorithm::kFixedChannels",
-    cutlass_bindings.conv.IteratorAlgorithm.few_channels: "cutlass::conv::IteratorAlgorithm::kFewChannels",
-}
+def op_tag(op) -> str:
+    """
+    Dispatches `op` to the appropriate *Tag dictionary depending on whether
+    `op` is an ActivationOp or FunctionalOp. This is useful for cases in which
+    either type can be used.
+
+    :param op: operation to emit a tag for
+    :type op: ActivationOp | FunctionalOp
+
+    :return: tag corresponding to op
+    :rtype: str
+    """
+    if isinstance(op, ActivationOp):
+        return ActivationOpTag[op]
+    elif isinstance(op, FunctionalOp):
+        return FunctionalOpTag[op]
+    else:
+        raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
 
 
-IteratorAlgorithmNames = {
-    cutlass_bindings.conv.IteratorAlgorithm.analytic: "analytic",
-    cutlass_bindings.conv.IteratorAlgorithm.optimized: "optimized",
-    cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "fixed_channels",
-    cutlass_bindings.conv.IteratorAlgorithm.few_channels: "few_channels",
-}
+class FloatRoundStyle(enum.Enum):
+    ToNearest = enum_auto()
+    ToNearestSatfinite = enum_auto()
+    Indeterminate = enum_auto()
+    TowardZero = enum_auto()
+    TowardInfinity = enum_auto()
+    TowardNegInfinity = enum_auto()
+    HalfUlpTruncDntz = enum_auto()
+    HalfUlpTruncate = enum_auto()
 
 
-class StrideSupport(enum.Enum):
-    Strided = enum_auto()
-    Unity = enum_auto()
-
-
-StrideSupportTag = {
-    StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
-    StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
-}
-
-
-StrideSupportNames = {
-    StrideSupport.Strided: "",
-    StrideSupport.Unity: "unity_stride",
-}
-
-
-class ConvMode(enum.Enum):
-    CrossCorrelation = enum_auto()
-    Convolution = enum_auto()
-
-
-ConvModeTag = {
-    ConvMode.CrossCorrelation: "cutlass::conv::Mode::kCrossCorrelation",
-    ConvMode.Convolution: "cutlass::conv::Mode::kConvolution",
-}
+FloatRoundStyleTag = {
+    FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
+    FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
+    FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
+    FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
+    FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
+    FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
+    FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
+    FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
+}
@@ -519,7 +226,7 @@ class MathInstruction:
         element_a,
         element_b,
         element_accumulator,
-        opcode_class=cutlass_bindings.OpClass.Simt,
+        opcode_class=OpcodeClass.Simt,
         math_operation=MathOperation.multiply_add,
     ):
         """
@@ -529,7 +236,7 @@ class MathInstruction:
         :param element_b: data type of operand B
         :param element_accumulator: data type used in accumulation
         :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
-        :type opcode_class: cutlass_bindings.OpClass
+        :type opcode_class: cutlass_library.library.OpcodeClass
         :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
         :type math_operation: MathOperation
         """
@@ -556,7 +263,7 @@ class TileDescription:
         cluster_shape=[1, 1, 1],
         kernel_schedule: KernelScheduleType = None,
         epilogue_schedule: EpilogueScheduleType = None,
-        tile_scheduler: TileSchedulerType = None,
+        tile_scheduler: TileSchedulerType = None
     ):
         """
         :param threadblock_shape: shape of a threadblock tile
@@ -610,7 +317,7 @@ class TileDescription:
         else:
             attrs[key] = getattr(self, key)
 
-        mi = MathInstruction(
+        attrs["math_instruction"] = MathInstruction(
             attrs["instruction_shape"],
             self.math_instruction.element_a,
             self.math_instruction.element_b,
@@ -619,11 +326,10 @@ class TileDescription:
             self.math_instruction.math_operation
         )
 
-        return TileDescription(
-            attrs["threadblock_shape"], attrs["stages"],
-            attrs["warp_count"], mi, attrs["cluster_shape"],
-            attrs["kernel_schedule"], attrs["epilogue_schedule"]
-        )
+        # Remove the instruction shape
+        del attrs["instruction_shape"]
+
+        return TileDescription(**attrs)
 
     @property
     def num_threads(self):
@@ -660,6 +366,15 @@ class TileDescription:
 
         return name
 
+    def procedural_name_2x(self):
+        """
+        Returns a name identifying the tile description
+
+        :return: name identifying the tile description
+        :rtype: str
+        """
+        return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
+
     def __str__(self):
         """
        Returns a string containing each of the tile description's values
@@ -695,8 +410,7 @@ class TileDescription:
 
 
 class TensorDescription:
-    def __init__(self, element, layout, alignment=1,
-                 complex_transform=cutlass_bindings.complex_transform.none):
+    def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
         self.element = element
         self.layout = layout
         self.alignment = min(128 // DataTypeSize[self.element], alignment)
@@ -751,7 +465,7 @@ class ApiVersion(enum.Enum):
     v3x = enum_auto()
 
 
-def api_version(arch, opclass, datatype):
+def api_version(arch, opclass, dtype):
     """
     Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
     or 3.x for code emission.
@@ -759,15 +473,16 @@ def api_version(arch, opclass, datatype):
     :param arch: compute capability of device on which to run
     :type arch: int
     :param opclass: class of the operation being performed
-    :type opclass: cutlass_bindings.OpClass
-    :param datatype: data type to be used in operation (assumes that ElementA and ElementB are the same)
+    :type opclass: cutlass.OpcodeClass
+    :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
+    :type dtype: cutlass.DataType
 
     :return: API version to be used in code emission
     :rtype: ApiVersion
     """
     if (arch >= 90 and
-            opclass == cutlass_bindings.OpClass.TensorOp and
-            (datatype != cutlass_bindings.float64)):
+            opclass == OpcodeClass.TensorOp and
+            (dtype != DataType.f64)):
         return ApiVersion.v3x
     else:
         return ApiVersion.v2x
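The `api_version` rule above selects the CUTLASS 3.x API only for SM90 Tensor Core operations on non-f64 data types; everything else falls back to the 2.x API. A standalone sketch of that dispatch with stand-in enums (the enum definitions below are illustrative placeholders, not the `cutlass` package's own):

```python
import enum


class OpcodeClass(enum.Enum):
    Simt = enum.auto()
    TensorOp = enum.auto()


class DataType(enum.Enum):
    f16 = enum.auto()
    f64 = enum.auto()


class ApiVersion(enum.Enum):
    v2x = enum.auto()
    v3x = enum.auto()


def api_version(arch, opclass, dtype):
    # 3.x kernels only exist for SM90 Tensor Core ops on non-f64 types
    if arch >= 90 and opclass == OpcodeClass.TensorOp and dtype != DataType.f64:
        return ApiVersion.v3x
    return ApiVersion.v2x
```

So an f16 TensorOp GEMM on SM90 emits 3.x code, while the same operation on SM80, a SIMT operation, or an f64 operation all route through the 2.x emitters.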
@ -1,877 +0,0 @@
################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################

import ast
import ctypes
import inspect
import textwrap
from typing import Generic, TypeVar

from cuda import cuda, cudart
import numpy as np
from treelib import Tree

from cutlass.backend.epilogue import (
    AccumulatorOp,
    BinaryOp,
    ColumnBroadcastOp,
    ColumnReductionOp,
    RowBroadcastOp,
    RowReductionOp,
    TensorInputOp,
    TensorOutputOp,
    UnaryOp,
)
from cutlass.backend.frontend import NumpyFrontend
from cutlass.backend.utils.software import SubstituteTemplate
import cutlass.backend as backend

################################################################################
# Type annotation for input arguments
################################################################################

Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")


class NDArray(np.ndarray, Generic[Ttype, Dtype]):
    pass


################################################################################
# Operations
################################################################################

operators = {
    ast.Add: "Add",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Mult: "Mult",
}
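To make the mapping above concrete: Python's `ast` module exposes the operator of a parsed binary expression as a node type such as `ast.Add` or `ast.Mult`, and `operators[type(node.op)]` is how the parser below turns that into a name. A minimal, standalone sketch (the expression `D = alpha * accum + beta * C` is an illustrative epilogue, not part of this file):

```python
import ast

# Same shape as the `operators` mapping above: ast operator node types to names.
operators = {
    ast.Add: "Add",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Mult: "Mult",
}

# Parse a tiny epilogue-like statement and read off the operator names.
tree = ast.parse("D = alpha * accum + beta * C")
assign = tree.body[0]   # the ast.Assign node
top = assign.value      # alpha * accum + beta * C -> ast.BinOp with op=ast.Add

print(operators[type(top.op)])       # Add
print(operators[type(top.left.op)])  # Mult (alpha * accum)
```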


################################################################################
# AST Node abstractions
################################################################################

class UnaryNode:
    cnt = 0

    # Concept: this is created by the BinOp Node in python ast
    def __init__(
        self,
        element_accumulator,
        element_compute,
        elements_per_access,
        node,
        args,
    ) -> None:
        if isinstance(node, BinOpNode):
            self.op = node.op
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                self.op = node.func.id
            elif isinstance(node.func, ast.Attribute):
                self.op = node.func.value.id
            else:
                raise TypeError
        else:
            raise TypeError
        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
        self.id = self.op + str(UnaryNode.cnt)
        self.args = args
        UnaryNode.cnt += 1

        self.type = "tensor"

        self.epilogue_op = getattr(backend, self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = UnaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        epilogue_ops = []
        for arg in self.args:
            try:
                epilogue_ops.append(kwargs[arg])
            except KeyError:
                epilogue_ops.append(arg)  # direct arguments like constants
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(*epilogue_ops),
            *visitor_args,
        )


class BinOpNode:
    cnt = 0

    # Concept: this is created by the BinOp Node in python ast
    def __init__(
        self,
        element_accumulator,
        element_compute,
        elements_per_access,
        node,
    ) -> None:
        self.op = operators[type(node.op)]
        self.tag = "Binary" + self.op + str(BinOpNode.cnt)
        self.id = self.op + str(BinOpNode.cnt)
        self.args = None
        BinOpNode.cnt += 1

        self.type = "tensor"

        self.epilogue_op = getattr(backend, "Vector" + self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = BinaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(self.args),
            *visitor_args,
        )


class NameNode:
    # Concept: this is created by the Name Node in python ast
    def __init__(self, node) -> None:
        try:
            self.id = node.id
        except AttributeError:
            self.id = node.targets[0].id
        self.tag = self.id


class ScalarInputNode(NameNode):
    # Concept: scalar
    def __init__(self, node) -> None:
        super().__init__(node)
        self.tag = "Scalar:" + self.tag
        self.type = "scalar"


class AccumulatorNode(NameNode):
    # Concept: VisitorOpAccumulator
    def __init__(
        self,
        element_accumulator,
        elements_per_access,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = "Accum:" + self.tag
        self.type = "tensor"

        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = AccumulatorOp(
            self.element_accumulator,
            self.elements_per_access,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type()


class TensorInputNode(NameNode):
    # Concept: VisitorOpTensorInput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = "TensorInput:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, *args):
        self.epilogue_node = TensorInputOp(self.element_accumulator)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
            kwargs["problem_size"][0] * kwargs["problem_size"][1],
        )


class RowBroadcastNode(NameNode):
    # Concept: VisitorOpRowBroadcast
    def __init__(
        self,
        element_accumulator,
        element_fragment,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = "RowBroadcast:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        self.epilogue_node = RowBroadcastOp(
            self.element_accumulator,
            self.element_fragment,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
        )


class ColumnBroadcastNode(NameNode):
    # Concept: VisitorOpColumnBroadcast
    def __init__(
        self,
        element_accumulator,
        element_fragment,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = "ColumnBroadcast:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        self.epilogue_node = ColumnBroadcastOp(
            self.element_accumulator,
            self.element_fragment,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][0],
        )


class TensorOutputNode(NameNode):
    # Concept: VisitorOpTensorOutput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = "TensorOutput:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, visitors):
        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
            *visitor_args,
            kwargs["problem_size"][0] * kwargs["problem_size"][1],
        )


class RowReductionNode:
    # Concept: RowReductionOp
    def __init__(
        self,
        element_accumulator,
        element_reduction,
        element_reduction_accumulator,
        id,
        factor,
    ) -> None:
        self.id = id
        self.tag = "RowReduction:" + self.id
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.factor = factor

    def get_epilogue_node(self, visitors):
        self.epilogue_node = RowReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )
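The `get_batch_stride` expression above is a ceiling division: a row reduction emits one partial result per `factor`-wide tile of columns, so the reduced buffer holds `rows * ceil(columns / factor)` elements. A standalone sketch of the arithmetic (the sample problem size is made up for illustration):

```python
def row_reduction_batch_stride(problem_size, factor):
    # rows * ceil(columns / factor), matching RowReductionNode.get_batch_stride;
    # (n + factor - 1) // factor is the usual integer ceiling division.
    return problem_size[0] * ((problem_size[1] + factor - 1) // factor)

# Hypothetical example: a 128 x 300 output reduced with a tile width of 128
# needs ceil(300 / 128) = 3 partials per row.
print(row_reduction_batch_stride((128, 300), 128))  # 128 * 3 = 384
```

`ColumnReductionNode.get_batch_stride` below is the same formula with the row and column roles swapped.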


class ColumnReductionNode:
    # Concept: ColumnReductionOp
    def __init__(
        self,
        element_accumulator,
        element_reduction,
        element_reduction_accumulator,
        id,
        factor,
    ) -> None:
        self.id = id
        self.tag = "ColumnReduction:" + self.id
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.factor = factor

    def get_epilogue_node(self, visitors):
        self.epilogue_node = ColumnReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )


################################################################################
# Epilogue parser function
################################################################################

class EpilogueAST(ast.NodeVisitor):
    def __init__(
        self,
        epilogue,
        tile_description,
        element_accumulator,
        elements_per_access,
        element_compute,
        element_output,
    ) -> None:
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        self.epilogue = epilogue

        self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
        self.ast_tree = ast.parse(self.source)
        self.epilogue_tree = Tree()

        # print(ast.dump(self.ast_tree, indent=4))  # For debug purposes

        # input arguments
        self.input_args = {}
        # return nodes
        self.returns = []
        # reduction source nodes
        self.reduction_source = {}

        # stack used to keep the parent node id
        self.stack = []

        # visit the AST
        self.visit(self.ast_tree)

    # visit the name node
    def visit_Name(self, node):
        # append the return ids into self.returns
        if self.stack[-1] == "return":
            self.returns.append(node.id)
        else:
            # accum is produced from the accumulator node
            if node.id == "accum":
                name_node = AccumulatorNode(
                    self.element_accumulator,
                    self.elements_per_access,
                    node,
                )
            else:
                # for input nodes
                if node.id in self.input_args.keys():
                    type = self.input_args[node.id][0]
                    if type == "tensor":
                        name_node = TensorInputNode(
                            self.element_accumulator,
                            node,
                        )
                    elif type == "row":
                        name_node = RowBroadcastNode(
                            self.element_accumulator,
                            self.element_compute,
                            node,
                        )
                    elif type == "column":
                        name_node = ColumnBroadcastNode(
                            self.element_accumulator,
                            self.element_compute,
                            node,
                        )
                    elif type == "scalar":
                        name_node = ScalarInputNode(node)
                    else:
                        raise ValueError(type)
                # for output nodes
                else:
                    name_node = TensorOutputNode(
                        self.element_accumulator,
                        node,
                    )
            self.epilogue_tree.create_node(
                name_node.tag,
                name_node.id,
                data=name_node,
                parent=self.stack[-1],
            )

    def visit_Assign(self, node):
        pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
        if pre_assign_node is None:
            # The assign is to a root node
            # skip the reduction nodes
            if isinstance(node.value, ast.Call):
                if isinstance(node.value.func, ast.Name):
                    func_type = node.value.func.id
                elif isinstance(node.value.func, ast.Attribute):
                    func_type = node.value.func.value.id
                else:
                    raise TypeError
                if func_type == "reduction_op":
                    self.reduction_source[node.value.args[0].id] = [
                        node.value.args[1].value,
                        node.value.args[2].value,
                        node.targets[0].id,
                    ]
                    return
            name_node = TensorOutputNode(self.element_accumulator, node)
            self.epilogue_tree.create_node(
                name_node.tag,
                name_node.id,
                data=name_node,
            )
            self.stack.append(name_node.id)
        else:
            if (
                node.targets[0].id in self.returns
                or node.targets[0].id in self.reduction_source.keys()
            ):
                self.stack.append(node.targets[0].id)
            else:
                self.stack.append(
                    pre_assign_node.predecessor(self.epilogue_tree.identifier)
                )
                self.epilogue_tree.remove_node(node.targets[0].id)

        # get child tag
        self.visit(node.value)
        self.stack.pop()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            func_type = node.func.id
        elif isinstance(node.func, ast.Attribute):
            func_type = node.func.value.id
        else:
            raise TypeError
        if func_type == "reduction_op":
            self.visit(node.args[0])
        else:
            arg_list = []
            for idx, arg in enumerate(node.args):
                if idx == 0:
                    continue
                if isinstance(arg, ast.Constant):
                    arg_list.append(arg.value)
                elif isinstance(arg, ast.Name):
                    arg_list.append(arg.id)
                else:
                    raise TypeError

            unary_node = UnaryNode(
                self.element_accumulator,
                self.element_compute,
                self.elements_per_access,
                node,
                arg_list,
            )
            self.epilogue_tree.create_node(
                unary_node.tag,
                unary_node.id,
                parent=self.stack[-1],
                data=unary_node,
            )
            self.stack.append(unary_node.id)
            self.visit(node.args[0])
            self.stack.pop()

    def visit_BinOp(self, node):
        binop = BinOpNode(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            node,
        )
        self.epilogue_tree.create_node(
            binop.tag,
            binop.id,
            data=binop,
            parent=self.stack[-1],
        )
        self.stack.append(binop.id)
        self.visit(node.left)
        self.visit(node.right)
        self.stack.pop()

    def visit_Return(self, node):
        self.stack.append("return")
        self.visit(node.value)
        self.stack.pop()

    # A function definition
    def visit_FunctionDef(self, node: ast.FunctionDef):
        # visit args
        for arg in node.args.args:
            if arg.arg == "self":
                continue
            if isinstance(arg.annotation, ast.Constant):
                self.input_args[arg.arg] = [
                    arg.annotation.value,
                ]
        # visit the assigns in reverse order
        for idx in range(len(node.body)):
            self.visit(node.body[-1 - idx])

    #
    # Tree optimization passes
    #

    # pass 1: lower Binary to Unary
    def pass_binary_2_unary(self, tree, nid):
        node = tree.get_node(nid)
        if isinstance(node.data, BinOpNode):
            lhs_node = tree.get_node(node.successors(tree.identifier)[0])
            left_type = lhs_node.data.type
            rhs_node = tree.get_node(node.successors(tree.identifier)[1])
            right_type = rhs_node.data.type

            if left_type == "scalar" and right_type == "tensor":
                node.data = UnaryNode(
                    self.element_accumulator,
                    self.element_compute,
                    self.elements_per_access,
                    node.data,
                    [
                        lhs_node.data.id,
                    ],
                )
                node.tag = node.data.tag
                tree.remove_node(lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)

            elif left_type == "tensor" and right_type == "scalar":
                node.data = UnaryNode(
                    self.element_accumulator,
                    self.element_compute,
                    self.elements_per_access,
                    node.data,
                    [
                        rhs_node.id,
                    ],
                )
                node.tag = node.data.tag
                tree.remove_node(rhs_node.data.id)
                self.pass_binary_2_unary(tree, lhs_node.data.id)

            else:
                self.pass_binary_2_unary(tree, lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)
        else:
            for child in node.successors(tree.identifier):
                self.pass_binary_2_unary(tree, child)

    # pass 2: inject reduction nodes
    def pass_inject_reduction(self, tree, nid):
        node = tree.get_node(nid)
        if isinstance(node.data, TensorOutputNode):
            if node.data.id in self.reduction_source.keys():
                direction = self.reduction_source[node.data.id][0]
                target = self.reduction_source[node.data.id][-1]
                if direction == "row":
                    reduction_node = RowReductionNode(
                        self.element_accumulator,
                        self.element_output,
                        self.element_accumulator,
                        target,
                        self.tile_description.threadblock_shape[1],
                    )
                elif direction == "column":
                    reduction_node = ColumnReductionNode(
                        self.element_accumulator,
                        self.element_output,
                        self.element_accumulator,
                        target,
                        self.tile_description.threadblock_shape[0],
                    )
                else:
                    raise ValueError(direction)
                child_nid = node.successors(tree.identifier)[0]
                # if this output node is injected only for reduction
                if node.data.id not in self.returns:
                    # get reduction config from disc
                    node.data = reduction_node
                    node.tag = reduction_node.tag
                    self.pass_inject_reduction(tree, child_nid)
                # if this output node is also a tensor output, inject reduction as its child
                else:
                    # get child node
                    tree.create_node(
                        reduction_node.tag,
                        reduction_node.id,
                        data=reduction_node,
                        parent=node.data.id,
                    )
                    tree.move_node(
                        child_nid,
                        reduction_node.id,
                    )
                    child = tree.get_node(child_nid)
                    for grand_child in child.successors(tree.identifier):
                        self.pass_inject_reduction(tree, grand_child)
            else:
                for child in node.successors(tree.identifier):
                    self.pass_inject_reduction(tree, child)
        else:
            for child in node.successors(tree.identifier):
                self.pass_inject_reduction(tree, child)

    def pass_inject_epilogue_op(self, tree, nid):
        node = tree.get_node(nid)
        visitors = []
        for child in node.successors(tree.identifier):
            visitors.append(self.pass_inject_epilogue_op(tree, child))

        node.data.get_epilogue_node(visitors)
        return node.data.epilogue_node

    def get_arguments(self, tree, nid, kwargs):
        node = tree.get_node(nid)
        visitor_args = []
        for child in node.successors(tree.identifier):
            visitor_args.append(self.get_arguments(tree, child, kwargs))

        node.data.get_argument(visitor_args, kwargs)
        return node.data.argument
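To make the parser's input concrete: `EpilogueAST` reads the source of an epilogue's `__call__`, where string annotations (`'tensor'`, `'row'`, `'column'`, `'scalar'`) classify the inputs and the name `accum` denotes the accumulator. The sketch below shows only what the `ast` layer sees for a hypothetical epilogue (the names `alpha`, `C`, `D` are illustrative; none of the CUTLASS machinery runs here, and the source string stands in for what `inspect.getsource` would return):

```python
import ast
import textwrap

# Source of a hypothetical epilogue __call__; the string annotations are the
# markers that EpilogueAST.visit_FunctionDef records into input_args.
source = textwrap.dedent("""
    def __call__(self, accum, alpha: 'scalar', C: 'tensor'):
        D = alpha * accum + C
        return D
""")
func = ast.parse(source).body[0]

# Collect the annotated argument names the way visit_FunctionDef does.
input_args = {
    arg.arg: arg.annotation.value
    for arg in func.args.args
    if arg.arg != "self" and isinstance(arg.annotation, ast.Constant)
}
print(input_args)  # {'alpha': 'scalar', 'C': 'tensor'}
```

`visit_Name` would then classify `alpha` as a scalar input, `C` as a tensor input, `accum` as the accumulator, and the unannotated assignment target `D` as a tensor output.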


class EpilogueVisitTree:
    KernelTemplate = """
${visitor}

using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""

    def __init__(
        self,
        elementwise_functor,
        tile_description,
        element_accumulator,
        elements_per_access,
        element_compute,
        element_output,
    ) -> None:
        # data types
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        self.elementwise_functor = elementwise_functor

    def initialize(self):
        function = EpilogueAST(
            self,
            self.tile_description,
            self.element_accumulator,
            self.elements_per_access,
            self.element_compute,
            self.element_output,
        )
        tree = function.epilogue_tree
        self.tree = tree
        function.pass_binary_2_unary(self.tree, self.tree.root)
        function.pass_inject_reduction(self.tree, self.tree.root)
        function.pass_inject_epilogue_op(self.tree, self.tree.root)

        visitor = self.tree.get_node(self.tree.root).data.epilogue_node
        self.visitor = visitor

        class _Argument(ctypes.Structure):
            _fields_ = [
                (
                    "visitor_arg",
                    visitor.argument_type,
                )
            ]

            def __init__(self, **kwargs) -> None:
                # process input args
                _kwargs = {}
                for input_key in function.input_args.keys():
                    if input_key == "accum":
                        continue
                    if function.input_args[input_key][0] == "scalar":
                        continue
                    # tensor input
                    else:
                        setattr(
                            self,
                            "buffer_tensor_" + input_key,
                            NumpyFrontend.argument(
                                kwargs[input_key],
                                False,
                            ),
                        )
                        setattr(
                            self,
                            input_key + "_ptr",
                            int(
                                getattr(
                                    self,
                                    "buffer_tensor_" + input_key,
                                ).ptr
                            ),
                        )
                        _kwargs[input_key + "_ptr"] = getattr(
                            self,
                            input_key + "_ptr",
                        )
                # process the return args
                for ret in function.returns:
                    setattr(
                        self,
                        "buffer_tensor_" + ret,
                        NumpyFrontend.argument(kwargs[ret], True),
                    )
                    setattr(
                        self,
                        ret + "_ptr",
                        int(
                            getattr(
                                self,
                                "buffer_tensor_" + ret,
                            ).ptr
                        ),
                    )
                    _kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr")
                    setattr(
                        self,
                        "host_tensor_" + ret,
                        kwargs[ret],
                    )

                _kwargs.update(kwargs)
                function.get_arguments(tree, tree.root, _kwargs)
                self.visitor_arg = tree.get_node(tree.root).data.argument

            def sync(self, stream_sync=True):
                if stream_sync:
                    (err,) = cudart.cudaDeviceSynchronize()
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

                for ret in function.returns:
                    (err,) = cuda.cuMemcpyDtoH(
                        getattr(
                            self,
                            "host_tensor_" + ret,
                        ),
                        cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
                        getattr(
                            self,
                            "host_tensor_" + ret,
                        ).size
                        * getattr(
                            self,
                            "host_tensor_" + ret,
                        ).itemsize,
                    )
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _Argument

    def emit(self, operation):
        values = {
            "visitor": self.visitor.emit(operation),
            "operation_name": operation.procedural_name(),
            "visitor_name": self.visitor.instance_name,
        }
        return SubstituteTemplate(self.KernelTemplate, values)
@ -30,24 +30,24 @@
#
################################################################################


import ctypes
from typing import Union

from cuda import cuda, cudart
import numpy as np

from cutlass import (
    DataTypeNames,
    DataTypeSize,
    DataTypeTag,
    LayoutType,
)
from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass.backend.library import TensorDescription
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.shape import MatrixCoord

if CheckPackages().check_torch():
    import torch

@ -80,10 +80,9 @@ class ReductionArguments:
        self.bias = False

        self.operation = operation
        # pointer to the workspace
        self.ptr_workspace = workspace

        # number of split-k partitions
        self.partitions = partitions

        if isinstance(destination, np.ndarray):

@ -112,19 +111,18 @@ class ReductionArguments:
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        # get arguments
        self.get_arguments()

    @staticmethod
    def get_tensor_ref(
        extent: "tuple[int]",
        device_ptr: cuda.CUdeviceptr,
        layout: LayoutType,
    ):
        if layout == LayoutType.RowMajor:
            return TensorRef2D_(int(device_ptr), extent[1])
        else:
            raise ValueError(f"Unknown layout type {layout}")
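`get_tensor_ref` only supports row-major 2D tensors, where the leading dimension passed to `TensorRef2D_` is the column extent: element `(i, j)` lives at linear offset `i * ld + j`. A small sketch of that addressing rule, independent of the ctypes `TensorRef2D_` type:

```python
def row_major_offset(coord, extent):
    # Row-major 2D addressing: the leading dimension is the column extent
    # (extent[1]), mirroring TensorRef2D_(ptr, extent[1]) above.
    ld = extent[1]
    return coord[0] * ld + coord[1]

# Illustrative 4 x 8 tensor: element (2, 3) sits at offset 2 * 8 + 3.
print(row_major_offset((2, 3), (4, 8)))  # 19
```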

    def get_arguments(self):
        ref_workspace = ReductionArguments.get_tensor_ref(
@ -133,13 +131,13 @@ class ReductionArguments:
                self.problem_size.column,
            ],
            device_ptr=self.ptr_workspace,
            layout=LayoutType.RowMajor,
        )
        if self.bias:
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[0, 0],
                device_ptr=self.ptr_source,
                layout=LayoutType.RowMajor,
            )
        else:
            ref_source = ReductionArguments.get_tensor_ref(
@ -148,7 +146,7 @@ class ReductionArguments:
                self.problem_size.column,
            ],
            device_ptr=self.ptr_source,
            layout=LayoutType.RowMajor,
        )

        ref_destination = ReductionArguments.get_tensor_ref(
@ -157,7 +155,7 @@ class ReductionArguments:
                self.problem_size.column,
            ],
            device_ptr=self.ptr_destination,
            layout=LayoutType.RowMajor,
        )

        self.c_arguments = self.operation.argument_type(
@ -176,7 +174,7 @@ class ReductionArguments:
    def sync(self):
        (err,) = cudart.cudaDeviceSynchronize()
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

        if hasattr(self, "host_D"):
            (err,) = cuda.cuMemcpyDtoH(
@ -258,15 +256,15 @@ extern "C" {

    def plan(self, arguments: ReductionArguments):
        block_shape = [
            self.operation.shape.column // self.elements_per_access,
            self.operation.shape.row,
            1,
        ]
        grid_shape = [
            (arguments.problem_size.row + self.operation.shape.row - 1)
            // self.operation.shape.row,
            (arguments.problem_size.column + self.operation.shape.column - 1)
            // self.operation.shape.column,
            1,
        ]
        return LaunchConfiguration(
@ -282,20 +280,17 @@ extern "C" {
            value=self.shared_memory_capacity,
        )
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {err}")


class ReductionOperation:
    """
    CUTLASS reduction Operation
    """

    def __init__(
        self,
        shape: MatrixCoord,
        C: TensorDescription,
        element_accumulator,
        element_workspace=None,
@ -304,45 +299,33 @@ class ReductionOperation:
        count: int = 1,
        partitions_per_stage: int = 4,
    ) -> None:
        """Constructor"""

        self.shape = shape
        #: epilogue functor (default: LinearCombination)
        self.epilogue_functor = epilogue_functor
        #: datatype of accumulator
        self.element_accumulator = element_accumulator

        if element_workspace is None:
            #: datatype of workspace
            self.element_workspace = element_accumulator
        else:
            #: datatype of workspace
            self.element_workspace = element_workspace

        if element_compute is None:
            #: datatype of compute
            self.element_compute = element_accumulator
        else:
            #: datatype of compute
            self.element_compute = element_compute

        #: datatype of output
        self.element_output = C.element

        #: operand C
        self.C: TensorDescription = C

        # Reduce op processing size
        self.count: int = count

        # Number of partitions to reduce per stage
        self.partitions_per_stage: int = partitions_per_stage

        self.rt_module: ReductionRT = ReductionRT(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def extended_name(self):
        extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"

@ -356,15 +339,14 @@ class ReductionOperation:
            },
        )

    def configuration_name(self):
        """The full procedural name indicates architecture, extended name, tile size"""

        configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"

        threadblock = "%dx%d" % (
            self.shape.row,
            self.shape.column,
        )

        return SubstituteTemplate(
@ -375,7 +357,6 @@ class ReductionOperation:
            },
        )

    def procedural_name(self):
        """The full procedural name indicates architecture, extended name, tile size"""
        return self.configuration_name()
@ -384,14 +365,11 @@ class ReductionOperation:
        """
        Configure and launch the CUDA kernel with input arguments
        """
        launch_config = self.rt_module.plan(arguments)

        # get the host and device workspace
        host_workspace = arguments.host_workspace
        device_workspace = None

        # launch the kernel
        err = self.rt_module.run(
            host_workspace,
            device_workspace,
@ -399,7 +377,7 @@ class ReductionOperation:
        )

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

        return err

@ -421,7 +399,7 @@ class EmitReductionInstance:
        ]
        self.template = """
// Reduction kernel instance
using ${operation_name}_base =
|
||||
typename cutlass::reduction::kernel::ReduceSplitK<
|
||||
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
|
||||
${epilogue_functor},
|
||||
@ -436,19 +414,14 @@ struct ${operation_name}${operation_suffix}:
|
||||
"""
|
||||
|
||||
def emit(self, operation: ReductionOperation):
|
||||
epilogue_vector_length = int(
|
||||
min(
|
||||
operation.C.alignment * DataTypeSize[operation.C.element],
|
||||
128,
|
||||
)
|
||||
/ DataTypeSize[operation.C.element]
|
||||
)
|
||||
vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
|
||||
epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
|
||||
|
||||
values = {
|
||||
"operation_name": operation.configuration_name(),
|
||||
"operation_suffix": self.operation_suffix,
|
||||
"shape_row": str(operation.shape.row()),
|
||||
"shape_column": str(operation.shape.column()),
|
||||
"shape_row": str(operation.shape.row),
|
||||
"shape_column": str(operation.shape.column),
|
||||
"epilogue_functor": operation.epilogue_functor.emit(),
|
||||
"element_output": DataTypeTag[operation.element_output],
|
||||
"epilogue_vector_length": str(epilogue_vector_length),
|
||||
|
||||
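The epilogue vector-length change above caps the per-access width at 128 bits and then converts back to a count of elements. A minimal standalone sketch of that arithmetic, with `alignment` and `element_bits` as illustrative stand-ins for `operation.C.alignment` and `DataTypeSize[operation.C.element]`:

```python
def epilogue_vector_length_elems(alignment, element_bits):
    # Cap the per-access width at 128 bits...
    vector_length_bits = min(alignment * element_bits, 128)
    # ...then express it as a number of elements of the output type.
    return vector_length_bits // element_bits
```

Integer floor division replaces the original `int(min(...) / size)` form, avoiding the float round trip while producing the same result.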
@@ -1,807 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import os
import re
import subprocess
from time import sleep

from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np

from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.backend.test.profiler import GpuTimer
from cutlass.backend.utils.software import SubstituteTemplate
def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand):
    ptr = tensor.__array_interface__["data"][0]
    if operand == "a":
        tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
            conv_kind, problem_size
        )
    elif operand == "b":
        tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
            conv_kind, problem_size
        )
    elif operand in ["c", "d"]:
        tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
            conv_kind, problem_size
        )
    else:
        raise ValueError("unknown operand: " + operand)

    layout = tensor_layout.packed(tensor_coord)

    if tensor.dtype == np.float64:
        return cutlass_bindings.TensorRefF64NHWC(ptr, layout)
    elif tensor.dtype == np.float32:
        return cutlass_bindings.TensorRefF32NHWC(ptr, layout)
    elif tensor.dtype == np.float16:
        return cutlass_bindings.TensorRefF16NHWC(ptr, layout)
    elif tensor.dtype == bfloat16:
        return cutlass_bindings.TensorRefBF16NHWC(ptr, layout)
    elif tensor.dtype == np.int32:
        return cutlass_bindings.TensorRefS32NHWC(ptr, layout)
    elif tensor.dtype == np.int8:
        if tensor_layout == cutlass_bindings.TensorNC32HW32:
            return cutlass_bindings.TensorRefS8NC32HW32(ptr, layout)
        elif tensor_layout == cutlass_bindings.TensorC32RSK32:
            return cutlass_bindings.TensorRefS8C32RSK32(ptr, layout)
        else:
            return cutlass_bindings.TensorRefS8NHWC(ptr, layout)
    else:
        raise ValueError("unsupported data type")

def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand):
    tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand)

    if operand == "a":
        tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent(
            conv_kind, problem_size
        )
    elif operand == "b":
        tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent(
            conv_kind, problem_size
        )
    elif operand in ["c", "d"]:
        tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
            conv_kind, problem_size
        )
    else:
        raise ValueError("unknown operand: " + operand)

    if tensor.dtype == np.float64:
        return cutlass_bindings.TensorViewF64NHWC(tensor_ref, tensor_coord)
    elif tensor.dtype == np.float32:
        return cutlass_bindings.TensorViewF32NHWC(tensor_ref, tensor_coord)
    elif tensor.dtype == np.float16:
        return cutlass_bindings.TensorViewF16NHWC(tensor_ref, tensor_coord)
    elif tensor.dtype == bfloat16:
        return cutlass_bindings.TensorViewBF16NHWC(tensor_ref, tensor_coord)
    elif tensor.dtype == np.int32:
        return cutlass_bindings.TensorViewS32NHWC(tensor_ref, tensor_coord)
    elif tensor.dtype == np.int8:
        if tensor_layout == cutlass_bindings.TensorNC32HW32:
            return cutlass_bindings.TensorViewS8NC32HW32(tensor_ref, tensor_coord)
        elif tensor_layout == cutlass_bindings.TensorC32RSK32:
            return cutlass_bindings.TensorViewS8C32RSK32(tensor_ref, tensor_coord)
        else:
            return cutlass_bindings.TensorViewS8NHWC(tensor_ref, tensor_coord)
    else:
        raise ValueError("unsupported data type")

class Conv2dLauncher:
    """
    Launcher that runs the operation on a given problem size
    """

    def __init__(
        self,
        operation: "Conv2dOperation",
        seed: int = 2080,
        interleaved=False,
        verification=True,
        profiling=False,
        warmup_iterations=500,
        iterations=500,
        compilation_mode="nvcc",
        **kwargs,
    ) -> None:
        self.enable_cached_results = True
        self.interleaved = interleaved

        # Create the reduction kernel
        self.reduction_operation = ReductionOperation(
            shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment),
            C=operation.C,
            element_accumulator=operation.tile_description.math_instruction.element_accumulator,
            element_compute=operation.epilogue_functor.element_epilogue,
            epilogue_functor=operation.epilogue_functor,
            count=operation.C.alignment,
        )

        #: verify the output result
        self.verification = verification
        #: profile the kernel's runtime
        self.profiling = profiling

        self.timer = GpuTimer()

        self.warmup_iterations = warmup_iterations
        self.iterations = iterations

        self.sleep_time = kwargs.get("sleep", 0)

        #
        # Compile the operator
        #

        if compilation_mode == "nvcc":
            compiler.nvcc()
        elif compilation_mode == "nvrtc":
            compiler.nvrtc()
        else:
            raise Exception(f"Unexpected compilation mode {compilation_mode}")

        compiler.add_module([operation, self.reduction_operation])

        self.operation = operation

        self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element)
        self.layout_A = operation.A.layout
        self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element)
        self.layout_B = operation.B.layout
        self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element)
        self.layout_C = operation.C.layout
        self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element)
        self.layout_D = operation.C.layout

        accumulator_size = DataTypeSize[
            operation.tile_description.math_instruction.element_accumulator
        ]
        element_size = DataTypeSize[operation.A.element]

        if element_size <= 8:
            self.randomization_max = 1
        elif element_size == 16:
            if accumulator_size <= 16:
                self.randomization_max = 2
            else:
                self.randomization_max = 4
        else:
            self.randomization_max = 7

        # Seed
        self.seed = seed

        self.conv_kind = operation.conv_kind

        #
        # Get the host reference function
        #

        self.element_compute = operation.epilogue_functor.element_epilogue

        self.host_conv2d = cutlass_bindings.test.conv.host.conv2d

        self.timer = GpuTimer()

    @staticmethod
    def numpy_type(type):
        if type == cutlass_bindings.float64:
            return np.float64
        elif type == cutlass_bindings.float32:
            return np.float32
        elif type == cutlass_bindings.float16:
            return np.float16
        elif type == cutlass_bindings.bfloat16:
            return bfloat16
        elif type == cutlass_bindings.int32:
            return np.int32
        elif type == cutlass_bindings.int8:
            return np.int8
        else:
            raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])

    def print_problem_size(self, p, split_k_mode=1):
        print(
            "nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d"
            % (
                p.N,
                p.H,
                p.W,
                p.C,
                p.K,
                p.R,
                p.S,
                p.C,
                p.pad_h,
                p.pad_w,
                p.stride_h,
                p.stride_w,
                p.dilation_h,
                p.dilation_w,
                p.split_k_slices,
                split_k_mode,
            )
        )

    def uniform_init(self, size, dtype):
        if dtype in [np.float32, np.float16, bfloat16, np.float64]:
            return np.ceil(
                np.random.uniform(
                    low=-self.randomization_max - 0.5, high=self.randomization_max - 0.5, size=size
                ).astype(dtype)
            )
        else:
            return np.random.uniform(
                low=-self.randomization_max - 1, high=self.randomization_max + 1, size=size
            ).astype(dtype)

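For the float path, `uniform_init` relies on the ceiling of a uniform draw over [-max - 0.5, max - 0.5) landing on whole numbers in [-max, max], which keeps test tensors exactly representable. A minimal sketch of that trick with illustrative names (`uniform_int_valued` and `max_value` are not part of the source):

```python
import numpy as np

def uniform_int_valued(size, max_value, dtype=np.float32):
    # Draws in [-max_value - 0.5, max_value - 0.5) round up to whole numbers
    # in [-max_value, max_value].
    draws = np.random.uniform(low=-max_value - 0.5, high=max_value - 0.5, size=size)
    return np.ceil(draws).astype(dtype)
```

Integer-valued floats avoid rounding error in the host-reference comparison, so mismatches indicate real kernel bugs rather than accumulation-order noise.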
    def eq_gemm_size(self, problem_size):
        n = problem_size.N
        p = problem_size.P
        q = problem_size.Q
        k = problem_size.K
        r = problem_size.R
        s = problem_size.S
        c = problem_size.C
        h = problem_size.H
        w = problem_size.W
        if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
            return cutlass_bindings.gemm.GemmCoord(n * p * q, k, r * s * c)
        elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
            return cutlass_bindings.gemm.GemmCoord(n * h * w, c, k * r * s)
        else:
            return cutlass_bindings.gemm.GemmCoord(k, r * s * c, n * p * q)

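The mapping in `eq_gemm_size` is the standard implicit-GEMM equivalence for fprop, dgrad, and wgrad. A self-contained sketch using plain (M, N, K) tuples in place of `cutlass_bindings.gemm.GemmCoord` (the string-keyed `conv_kind` is an illustrative stand-in for the enum):

```python
def implicit_gemm_mnk(conv_kind, n, h, w, c, k, r, s, p, q):
    # fprop: output activations (N*P*Q) x filters (K), reduced over R*S*C
    if conv_kind == "fprop":
        return (n * p * q, k, r * s * c)
    # dgrad: input activations (N*H*W) x channels (C), reduced over K*R*S
    elif conv_kind == "dgrad":
        return (n * h * w, c, k * r * s)
    # wgrad: filters (K) x (R*S*C), reduced over output positions N*P*Q
    else:
        return (k, r * s * c, n * p * q)
```

These M/N/K extents are what the byte and FLOP models below consume.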
    def bytes(self, problem_size, alpha, beta):
        mnk = self.eq_gemm_size(problem_size)

        bytes_ = (
            (DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k()
            + (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k()
            + (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()
        )

        if beta != 0:
            bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n()

        return bytes_

    def flops(self, problem_size):
        mnk = self.eq_gemm_size(problem_size)

        flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2
        flops_epilogue_ = mnk.m() * mnk.n() * 2

        # Adjust mainloop FLOPs for dgrad stride
        if self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
            flops_mainloop_ = flops_mainloop_ // (
                problem_size.stride_h * problem_size.stride_w
            )

        flops_total_ = flops_mainloop_ + flops_epilogue_

        return flops_total_

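The FLOP model above counts 2*M*N*K for the mainloop multiply-accumulates plus a 2*M*N epilogue term, dividing the mainloop term by the stride product for dgrad. A hedged standalone sketch of the same arithmetic (function and parameter names are illustrative, not from the source):

```python
def conv_flops(m, n, k, stride_h=1, stride_w=1, dgrad=False):
    # Each of the M*N*K multiply-accumulates counts as 2 FLOPs.
    mainloop = 2 * m * n * k
    # Strided dgrad skips gradient contributions, reducing useful work.
    if dgrad:
        mainloop //= stride_h * stride_w
    # Epilogue: one multiply and one add per output element.
    return mainloop + 2 * m * n
```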
    def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
        if self.element_compute == cutlass_bindings.float16:
            alpha = cutlass_bindings.float16(alpha)
            beta = cutlass_bindings.float16(beta)
        elif self.element_compute == cutlass_bindings.int32:
            alpha = int(alpha)
            beta = int(beta)

        # Whether a cached result has been loaded
        cached_result_loaded = False

        if self.enable_cached_results:
            # Get the problem key
            cached_test_key = cutlass_bindings.test.conv.host.CreateCachedConv2dTestKey(
                self.conv_kind,
                problem_size,
                alpha,
                beta,
                getTensorView(
                    tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
                ),
                getTensorView(
                    tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
                ),
                getTensorView(
                    tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
                ),
            )

            cached_test_result = cutlass_bindings.test.conv.host.CachedTestResult()

            conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (
                self.operation.arch,
                self.seed,
            )

            cached_results = cutlass_bindings.test.conv.host.CachedTestResultListing(
                conv2d_result_cache_name
            )
            cached = cached_results.find(cached_test_key)
            cached_result_loaded = cached[0]
            if cached_result_loaded:
                cached_test_result = cached[1]

        if not cached_result_loaded:
            # Compute the conv2d on the host
            tensor_D_ref = np.ones_like(tensor_C)
            tensor_ref_A = getTensorRef(
                tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
            )
            tensor_ref_B = getTensorRef(
                tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
            )
            tensor_ref_C = getTensorRef(
                tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
            )
            tensor_ref_D_ref = getTensorRef(
                tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
            )

            self.host_conv2d(
                self.conv_kind,
                problem_size,
                tensor_ref_A,
                tensor_ref_B,
                tensor_ref_C,
                tensor_ref_D_ref,
                alpha,
                beta,
            )

            tensor_view_D_ref = getTensorView(
                tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
            )

            if self.enable_cached_results:
                cached_test_result.D = cutlass_bindings.test.conv.host.TensorHash(
                    tensor_view_D_ref
                )
                cached_results = (
                    cutlass_bindings.test.conv.host.CachedTestResultListing(
                        conv2d_result_cache_name
                    )
                )
                cached_results.append(cached_test_key, cached_test_result)
                cached_results.write(conv2d_result_cache_name)
            else:
                return tensor_D_ref

        return cached_test_result.D

    def equal(self, tensor_D, tensor_D_ref, problem_size):
        tensor_view_D = getTensorView(
            tensor_D, self.layout_D, self.conv_kind, problem_size, "d"
        )
        if self.enable_cached_results:
            # Compare the device output against the cached hash of the reference
            tensor_D_hash = cutlass_bindings.test.conv.host.TensorHash(tensor_view_D)

            return tensor_D_hash == tensor_D_ref
        else:
            tensor_view_D_ref = getTensorView(
                tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
            )
            return cutlass_bindings.test.conv.host.equals(
                tensor_view_D, tensor_view_D_ref
            )

    def run_cutlass_profiler(
        self,
        problem_size,
        split_k_mode=cutlass_bindings.conv.SplitKMode.Serial,
        alpha=1.0,
        beta=0.0,
    ):
        if split_k_mode == cutlass_bindings.conv.SplitKMode.Serial:
            split_k_mode_ = "serial"
        else:
            split_k_mode_ = "parallel"

        cutlass_path = os.getenv("CUTLASS_PATH")
        assert (
            cutlass_path is not None
        ), "Environment variable 'CUTLASS_PATH' is not defined."

        values = {
            "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
            "kernel_name": self.operation.procedural_name(),
            "verification_providers": "device",
            "provider": "cutlass",
            "n": str(problem_size.N),
            "h": str(problem_size.H),
            "w": str(problem_size.W),
            "c": str(problem_size.C),
            "k": str(problem_size.K),
            "r": str(problem_size.R),
            "s": str(problem_size.S),
            "p": str(problem_size.P),
            "q": str(problem_size.Q),
            "pad_h": str(problem_size.pad_h),
            "pad_w": str(problem_size.pad_w),
            "stride_h": str(problem_size.stride_h),
            "stride_w": str(problem_size.stride_w),
            "dilation_h": str(problem_size.dilation_h),
            "dilation_w": str(problem_size.dilation_w),
            "split_k_slices": str(problem_size.split_k_slices),
            "split_k_mode": split_k_mode_,
            "alpha": str(alpha),
            "beta": str(beta),
            "warmup": str(self.warmup_iterations),
            "profile": str(self.iterations),
        }

        cmd_template = (
            "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}"
            " --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}"
            " --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h=${stride_h} --stride_w=${stride_w}"
            " --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}"
            " --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}"
        )

        cmd = SubstituteTemplate(cmd_template, values)
        result = subprocess.getoutput(cmd)

        m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result)
        runtime = float(m.group("runtime"))

        m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
        bytes = int(m.group("bytes"))

        m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
        flops = int(m.group("flops"))

        # Check that the problem size matches
        assert bytes == self.bytes(problem_size, alpha, beta)
        assert flops == self.flops(problem_size)

        return runtime

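`run_cutlass_profiler` builds its command line by substituting `${name}` placeholders from the `values` dictionary. An illustrative stand-in for `SubstituteTemplate` under the assumption that it performs simple `${name}` replacement (this sketch is not the library's actual implementation):

```python
def substitute_template(template, values):
    # Replace each ${key} placeholder with its string value.
    for key, value in values.items():
        template = template.replace("${%s}" % key, value)
    return template
```

A template key missing from `values` is left as a literal `${...}` in the output, which is why the `--stride_h=${stride_h}` placeholder must be spelled with the leading `$`.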
    def run(
        self,
        problem_size,
        split_k_mode=cutlass_bindings.conv.SplitKMode.Serial,
        alpha=1.0,
        beta=0.0,
    ):
        assert get_allocated_size() == 0, (
            "%d bytes of pool memory were not released in the previous run"
            % get_allocated_size()
        )

        #
        # Initialize input and output tensors
        #
        tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(
            self.conv_kind, problem_size
        )
        tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(
            self.conv_kind, problem_size
        )
        tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
            self.conv_kind, problem_size
        )

        np.random.seed(self.seed)

        tensor_A = self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A)
        tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B)
        tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C)
        tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D)

        #
        # Launch kernel
        #

        arguments = Conv2dArguments(
            operation=self.operation,
            problem_size=problem_size,
            A=tensor_A,
            B=tensor_B,
            C=tensor_C,
            D=tensor_D,
            output_op=self.operation.epilogue_type(alpha, beta),
            split_k_slices=problem_size.split_k_slices,
            split_k_mode=split_k_mode,
        )

        if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
            implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(
                self.operation.conv_kind, arguments.problem_size
            )
            reduction_arguments = ReductionArguments(
                self.reduction_operation,
                problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
                partitions=problem_size.split_k_slices,
                workspace=arguments.ptr_D,
                destination=tensor_D,
                source=tensor_C,
                output_op=self.reduction_operation.epilogue_type(alpha, beta),
            )

        self.operation.run(arguments)
        if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
            self.reduction_operation.run(reduction_arguments)

        passed = True
        if self.verification:
            if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
                reduction_arguments.sync()
            else:
                arguments.sync()

            tensor_D_ref = self.host_reference(
                problem_size, tensor_A, tensor_B, tensor_C, alpha, beta
            )

            passed = self.equal(tensor_D, tensor_D_ref, problem_size)

            if not passed:
                self.print_problem_size(problem_size, split_k_mode)

        if self.profiling:
            sleep(self.sleep_time)
            for _ in range(self.warmup_iterations):
                self.operation.run(arguments)
                if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
                    self.reduction_operation.run(reduction_arguments)

            self.timer.start()
            for _ in range(self.iterations):
                self.operation.run(arguments)
                if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
                    self.reduction_operation.run(reduction_arguments)
            self.timer.stop_and_wait()
            runtime = self.timer.duration(self.iterations)

        # Free memory
        del arguments
        if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel:
            del reduction_arguments

        assert get_allocated_size() == 0, (
            "%d bytes of pool memory were not released after the current run"
            % get_allocated_size()
        )
        if self.profiling:
            return runtime
        return passed

########################################################################################################
|
||||
# TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
|
||||
# TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
|
||||
# Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
||||
# (conv_blacklist_sizes)
|
||||
############################################################################################################
|
||||
|
||||
|
||||
def test_all_conv2d_from_compilation_mode(
|
||||
operation: Conv2dOperation,
|
||||
conv_test_sizes,
|
||||
interleaved,
|
||||
compilation_mode):
|
||||
|
||||
passed = True
|
||||
|
||||
testbed = Conv2dLauncher(operation, interleaved=interleaved, compilation_mode=compilation_mode)
|
||||
|
||||
#
|
||||
# Get conv problem sizes to run conv operator
|
||||
#
|
||||
|
||||
conv_problems = cutlass_bindings.test.conv.TestbedConv2dProblemSizes(64)
|
||||
|
||||
# Vector of conv2d problem sizes to avoid duplicate runs
|
||||
conv_tested_sizes = []
|
||||
|
||||
# Flatten 2D problem_vectors into a 1D problem sizes
|
||||
problem_sizes = conv_problems.conv2d_default_sizes
|
||||
|
||||
problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes
|
||||
|
||||
# Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0)
|
||||
for conv_problem in problem_sizes:
|
||||
if conv_problem in conv_tested_sizes:
|
||||
continue
|
||||
|
||||
# skip channel dimension % 32 != 0 for interleaved case
|
||||
if interleaved:
|
||||
if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0:
|
||||
continue
|
||||
|
||||
#
|
||||
# Procedurally disable certain cases
|
||||
#
|
||||
|
||||
# CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
|
||||
if (
|
||||
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
|
||||
and operation.stride_support == StrideSupport.Unity
|
||||
):
|
||||
if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)):
|
||||
continue
|
||||
|
||||
if not interleaved:
|
||||
# Fixed channels algorithm requires channel count to match access size
|
||||
if (
|
||||
operation.iterator_algorithm
|
||||
== cutlass_bindings.conv.IteratorAlgorithm.fixed_channels
|
||||
):
|
||||
if conv_problem.C != operation.A.alignment:
|
||||
continue
|
||||
|
||||
# Few channels algorithm requires channel count to match access size
|
||||
if (
|
||||
operation.iterator_algorithm
|
||||
== cutlass_bindings.conv.IteratorAlgorithm.few_channels
|
||||
):
|
||||
if conv_problem.C % operation.A.alignment:
|
||||
continue
|
||||
|
||||
# CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
|
||||
# Although strided dgrad works for all stride combinations, we are only going
|
||||
# to run strided dgrad for non-unity strides
|
||||
|
||||
if (
|
||||
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
|
||||
and operation.stride_support == StrideSupport.Strided
|
||||
):
|
||||
if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1):
|
||||
continue
|
||||
|
||||
#
|
||||
# Test
|
||||
#
|
||||
|
||||
# push back tested problem size to avoid re-running duplicates
|
||||
conv_tested_sizes.append(conv_problem)
|
||||
|
||||
passed = testbed.run(conv_problem)
|
||||
|
||||
if not passed:
|
||||
return False
|
||||
|
||||
if interleaved:
|
||||
return True
|
||||
#
|
||||
# filter the cases for split K
|
||||
#
|
||||
|
||||
# Small-channels convolution can't run here.
|
||||
if operation.iterator_algorithm in [
|
||||
cutlass_bindings.conv.IteratorAlgorithm.fixed_channels,
|
||||
cutlass_bindings.conv.IteratorAlgorithm.few_channels,
|
||||
]:
|
||||
return True
|
||||
|
||||
# CUTLASS DGRAD's *stride* specialization does not support split-k mode
|
||||
if (
|
||||
operation.conv_kind == cutlass_bindings.conv.Operator.dgrad
|
||||
and operation.stride_support == StrideSupport.Strided
|
||||
):
|
||||
conv_problem = cutlass_bindings.conv.Conv2dProblemSize(
|
||||
cutlass_bindings.Tensor4DCoord(1, 56, 56, 8),
|
||||
cutlass_bindings.Tensor4DCoord(8, 1, 1, 8),
|
||||
cutlass_bindings.Tensor4DCoord(0, 0, 0, 0),
|
||||
            cutlass_bindings.MatrixCoord(2, 2),
            cutlass_bindings.MatrixCoord(1, 1),
            cutlass_bindings.conv.Mode.cross_correlation,
            1,
            1,
        )
        passed = testbed.run(conv_problem)

        return passed

    # Sweep split-k-slices using serial and parallel reduction with non-unity alpha and non-zero beta
    # for a single conv2d problem size. Convolution unit tests take a long time to run, so sweep only
    # the parameters that are absolutely necessary to catch functional bugs. The code below provides
    # the option to sweep alpha and beta for local testing, but runs only one value for each.

    conv2d_split_k_test_size = cutlass_bindings.conv.Conv2dProblemSize(
        cutlass_bindings.Tensor4DCoord(1, 17, 11, 288),
        cutlass_bindings.Tensor4DCoord(160, 3, 3, 288),
        cutlass_bindings.Tensor4DCoord(1, 1, 1, 1),
        cutlass_bindings.MatrixCoord(1, 1),
        cutlass_bindings.MatrixCoord(1, 1),
        cutlass_bindings.conv.Mode.cross_correlation,
        1,
        1,
    )

    split_k_modes = [
        cutlass_bindings.conv.SplitKMode.Parallel,
        cutlass_bindings.conv.SplitKMode.Serial,
    ]

    split_k_slices = [1, 2, 3, 4, 201]
    problem_alpha = [
        2.0,
    ]
    problem_beta = [
        2.0,
    ]

    for split_k_mode in split_k_modes:
        for split_k_slice in split_k_slices:
            for alpha in problem_alpha:
                for beta in problem_beta:
                    passed = testbed.run(
                        conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
                        split_k_mode,
                        alpha,
                        beta,
                    )

    return passed


def test_all_conv2d(
    operation: Conv2dOperation,
    conv_test_sizes=[],
    interleaved=False,
    compilation_modes=["nvcc", "nvrtc"]):

    for compilation_mode in compilation_modes:
        passed = test_all_conv2d_from_compilation_mode(operation, conv_test_sizes, interleaved, compilation_mode)

        if not passed:
            return False

    return True
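The four nested sweep loops above enumerate a full Cartesian grid of split-k modes, slice counts, alphas, and betas. As a minimal pure-Python sketch (with a hypothetical `run_one` standing in for `testbed.run`), the same grid can be flattened with `itertools.product`:

```python
import itertools

# Hypothetical sweep parameters mirroring the conv2d split-k sweep above.
split_k_modes = ["Parallel", "Serial"]
split_k_slices = [1, 2, 3, 4, 201]
alphas = [2.0]
betas = [2.0]

def run_one(mode, slices, alpha, beta):
    # Stand-in for testbed.run: a real sweep would launch the kernel here.
    return True

# One flat loop over the Cartesian product replaces four nested loops.
results = [
    run_one(mode, slices, alpha, beta)
    for mode, slices, alpha, beta in itertools.product(
        split_k_modes, split_k_slices, alphas, betas
    )
]
print(len(results))  # 2 modes * 5 slice counts * 1 alpha * 1 beta = 10 cases
```

This keeps the sweep definition declarative: adding a parameter value only grows a list, not the nesting depth.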
@ -1,276 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np

from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmGroupedArguments, GemmOperationGrouped
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.test.gemm_testbed import getTensorRef, getTensorView, transpose


class TestbedGrouped:
    def __init__(self, operation: GemmOperationGrouped, seed: int = 2080) -> None:
        compiler.add_module([operation])

        self.seed = seed

        self.operation = operation

        element_size = DataTypeSize[operation.A.element]

        self.dtype_A = self.numpy_type(operation.A.element)
        self.dtype_B = self.numpy_type(operation.B.element)
        self.dtype_C = self.numpy_type(operation.C.element)
        self.dtype_D = self.numpy_type(operation.C.element)

        if element_size == 1:
            self.scope_max = 1
            self.scope_min = 0
        elif element_size <= 8:
            self.scope_max = 1
            self.scope_min = -1
        elif element_size == 16:
            self.scope_max = 4
            self.scope_min = -4
        else:
            self.scope_max = 8
            self.scope_min = -8

        #: compute type
        self.compute_type = operation.epilogue_functor.element_epilogue

        self.accumulator_type = (
            operation.tile_description.math_instruction.element_accumulator
        )

    @staticmethod
    def numpy_type(type):
        if type == cutlass_bindings.float64:
            return np.float64
        elif type == cutlass_bindings.float32:
            return np.float32
        elif type == cutlass_bindings.float16:
            return np.float16
        elif type == cutlass_bindings.bfloat16:
            return bfloat16
        elif type == cutlass_bindings.int32:
            return np.int32
        elif type == cutlass_bindings.int8:
            return np.int8
        else:
            raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])

    def uniform_init(self, size, dtype):
        if dtype in [np.float32, np.float16, bfloat16, np.float64]:
            return np.ceil(
                np.random.uniform(
                    low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size
                ).astype(dtype)
            )
        else:
            return np.random.uniform(
                low=self.scope_min - 1, high=self.scope_max + 1, size=size
            ).astype(dtype)

    def print_problem_size(self, p):
        problem_size = "problem: %d, %d, %d\n" % (p.m(), p.n(), p.k())
        print(problem_size)

    def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool:
        assert get_allocated_size() == 0, (
            "%d bytes of pool memory were not released in the previous run"
            % get_allocated_size()
        )

        # initialize
        passed = False
        np.random.seed(self.seed)

        # generate the problem sizes
        problem_sizes = []
        tensor_As = []
        tensor_Bs = []
        tensor_Cs = []
        tensor_Ds = []
        tensor_D_refs = []

        for i in range(problem_count):
            if self.dtype_A == np.int8:
                if i == 0:
                    problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 32)
                else:
                    problem_size = cutlass_bindings.gemm.GemmCoord(
                        16 * np.random.randint(0, 64) + 48,
                        16 * np.random.randint(0, 64) + 48,
                        16 * np.random.randint(0, 64) + 48,
                    )
            else:
                if i == 0:
                    problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 8)
                else:
                    problem_size = cutlass_bindings.gemm.GemmCoord(
                        8 * np.random.randint(0, 64) + 24,
                        8 * np.random.randint(0, 64) + 24,
                        8 * np.random.randint(0, 64) + 24,
                    )

            tensor_As.append(
                self.uniform_init(
                    size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A
                )
            )
            tensor_Bs.append(
                self.uniform_init(
                    size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B
                )
            )
            tensor_Cs.append(
                self.uniform_init(
                    size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C
                )
            )

            tensor_Ds.append(
                np.zeros(
                    shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D
                )
            )

            tensor_D_refs.append(
                np.ones(
                    shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D
                )
            )

            problem_sizes.append(problem_size)

        arguments = GemmGroupedArguments(
            operation=self.operation,
            problem_sizes=problem_sizes,
            A=tensor_As,
            B=tensor_Bs,
            C=tensor_Cs,
            D=tensor_Ds,
            output_op=self.operation.epilogue_type(alpha, beta),
        )

        self.operation.run(arguments)

        arguments.sync()

        #
        # Reference check
        #
        alpha = self.compute_type(alpha).value()
        beta = self.compute_type(beta).value()
        init_acc = self.accumulator_type(0).value()

        for idx, problem_size in enumerate(problem_sizes):
            if self.operation.switched:
                tensor_ref_A = getTensorRef(
                    tensor_As[idx],
                    problem_size,
                    "a",
                    transpose(self.operation.B.layout),
                )
                tensor_ref_B = getTensorRef(
                    tensor_Bs[idx],
                    problem_size,
                    "b",
                    transpose(self.operation.A.layout),
                )
                tensor_ref_C = getTensorRef(
                    tensor_Cs[idx],
                    problem_size,
                    "c",
                    transpose(self.operation.C.layout),
                )
                tensor_ref_D_ref = getTensorRef(
                    tensor_D_refs[idx],
                    problem_size,
                    "d",
                    transpose(self.operation.C.layout),
                )
            else:
                tensor_ref_A = getTensorRef(
                    tensor_As[idx], problem_size, "a", self.operation.A.layout
                )
                tensor_ref_B = getTensorRef(
                    tensor_Bs[idx], problem_size, "b", self.operation.B.layout
                )
                tensor_ref_C = getTensorRef(
                    tensor_Cs[idx], problem_size, "c", self.operation.C.layout
                )
                tensor_ref_D_ref = getTensorRef(
                    tensor_D_refs[idx], problem_size, "d", self.operation.C.layout
                )

            tensor_view_D_ref = getTensorView(
                tensor_D_refs[idx], problem_size, "d", self.operation.C.layout
            )

            cutlass_bindings.test.gemm.host.gemm(
                problem_size,
                alpha,
                tensor_ref_A,
                tensor_ref_B,
                beta,
                tensor_ref_C,
                tensor_ref_D_ref,
                init_acc,
            )

            tensor_view_D = getTensorView(
                tensor_Ds[idx], problem_size, "d", self.operation.C.layout
            )

            passed = cutlass_bindings.test.gemm.host.equals(
                tensor_view_D, tensor_view_D_ref
            )

            try:
                assert passed
            except AssertionError:
                self.print_problem_size(problem_size)

        del arguments

        assert get_allocated_size() == 0, (
            "%d bytes of pool memory were not released after the current run"
            % get_allocated_size()
        )

        return passed
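The `TestbedGrouped` constructor above (like `GemmUniversalLauncher` later in this diff) picks the random-initialization range from the operand's bit-width, so that accumulated products stay exactly representable. A minimal pure-Python sketch of that selection logic (the function name is illustrative, not part of the bindings):

```python
def init_scope(element_size_bits):
    """Map an operand's bit-width to the (min, max) range used for random init."""
    if element_size_bits == 1:
        return (0, 1)      # 1-bit types: only 0/1 are representable
    elif element_size_bits <= 8:
        return (-1, 1)     # int8 and smaller: keep accumulations tiny
    elif element_size_bits == 16:
        return (-4, 4)     # fp16/bf16: small range limits rounding error
    else:
        return (-8, 8)     # fp32/fp64/int32: a wider range is safe

print(init_scope(8))   # (-1, 1)
print(init_scope(16))  # (-4, 4)
```

Narrower ranges for narrower types keep the reference check exact: every product and partial sum fits in the accumulator without rounding.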
@ -1,765 +0,0 @@

import os
import re
import subprocess
from time import sleep

from bfloat16 import bfloat16
from cuda import cuda, cudart
import cutlass_bindings
import numpy as np

from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import (
    DataTypeSize,
    DataTypeSizeBytes,
    MathOperation,
    ShortDataTypeNames,
)
from cutlass.backend.memory_manager import get_allocated_size
from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass.backend.test.profiler import GpuTimer
from cutlass.backend.utils.datatypes import to_cutlass
from cutlass.backend.utils.software import SubstituteTemplate


def transpose(layout):
    if layout == cutlass_bindings.RowMajor:
        return cutlass_bindings.ColumnMajor
    elif layout == cutlass_bindings.ColumnMajor:
        return cutlass_bindings.RowMajor
    elif layout == cutlass_bindings.ColumnMajorInterleaved32:
        return cutlass_bindings.RowMajorInterleaved32
    elif layout == cutlass_bindings.RowMajorInterleaved32:
        return cutlass_bindings.ColumnMajorInterleaved32


def getTensorRef(
    tensor: np.ndarray,
    problem_size: cutlass_bindings.gemm.GemmCoord,
    operand: str,
    layout: cutlass_bindings.layout,
    batch_offset: int = 0,
):
    ptr = tensor.__array_interface__["data"][0]
    if operand == "a":
        tensor_coord = problem_size.mk()
        batch_stride = problem_size.m() * problem_size.k()
    elif operand == "b":
        tensor_coord = problem_size.kn()
        batch_stride = problem_size.k() * problem_size.n()
    elif operand in ["c", "d"]:
        tensor_coord = problem_size.mn()
        batch_stride = problem_size.m() * problem_size.n()
    else:
        raise ValueError("Unknown operand: " + operand)

    elt_size = DataTypeSizeBytes[to_cutlass(tensor.dtype)]
    ptr += batch_offset * batch_stride * elt_size

    if layout == cutlass_bindings.RowMajor:
        layout = cutlass_bindings.RowMajor.packed(tensor_coord)
        layout_tag = "RowMajor"
    elif layout == cutlass_bindings.ColumnMajor:
        layout = cutlass_bindings.ColumnMajor.packed(tensor_coord)
        layout_tag = "ColumnMajor"
    elif layout == cutlass_bindings.ColumnMajorInterleaved32:
        layout = cutlass_bindings.ColumnMajorInterleaved32.packed(tensor_coord)
        layout_tag = "ColumnMajorInterleaved32"
    elif layout == cutlass_bindings.RowMajorInterleaved32:
        layout = cutlass_bindings.RowMajorInterleaved32.packed(tensor_coord)
        layout_tag = "RowMajorInterleaved32"
    else:
        raise ValueError("unsupported layout")

    if tensor.dtype == np.float32:
        ref_name = "TensorRefF32" + layout_tag
    elif tensor.dtype == np.float64:
        ref_name = "TensorRefF64" + layout_tag
    elif tensor.dtype == np.float16:
        ref_name = "TensorRefF16" + layout_tag
    elif tensor.dtype == bfloat16:
        ref_name = "TensorRefBF16" + layout_tag
    elif tensor.dtype == np.int8:
        ref_name = "TensorRefS8" + layout_tag
    elif tensor.dtype == np.int32:
        ref_name = "TensorRefS32" + layout_tag
    else:
        raise ValueError("unsupported datatype %s" % ShortDataTypeNames[tensor.dtype])

    return getattr(cutlass_bindings, ref_name)(ptr, layout)


def getTensorView(
    tensor: np.ndarray,
    problem_size: cutlass_bindings.gemm.GemmCoord,
    operand: str,
    layout: str,
    batch_offset: int = 0,
):
    tensor_ref = getTensorRef(tensor, problem_size, operand, layout, batch_offset)

    if operand == "a":
        tensor_coord = problem_size.mk()
    elif operand == "b":
        tensor_coord = problem_size.kn()
    elif operand in ["c", "d"]:
        tensor_coord = problem_size.mn()
    else:
        raise ValueError("Unknown operand: " + operand)

    if layout == cutlass_bindings.RowMajor:
        layout_tag = "RowMajor"
    elif layout == cutlass_bindings.ColumnMajor:
        layout_tag = "ColumnMajor"
    elif layout == cutlass_bindings.ColumnMajorInterleaved32:
        layout_tag = "ColumnMajorInterleaved32"
    elif layout == cutlass_bindings.RowMajorInterleaved32:
        layout_tag = "RowMajorInterleaved32"
    else:
        raise ValueError("unsupported layout")

    if tensor.dtype == np.float32:
        ref_name = "TensorViewF32" + layout_tag
    elif tensor.dtype == np.float64:
        ref_name = "TensorViewF64" + layout_tag
    elif tensor.dtype == np.float16:
        ref_name = "TensorViewF16" + layout_tag
    elif tensor.dtype == bfloat16:
        ref_name = "TensorViewBF16" + layout_tag
    elif tensor.dtype == np.int32:
        ref_name = "TensorViewS32" + layout_tag
    elif tensor.dtype == np.int8:
        ref_name = "TensorViewS8" + layout_tag
    else:
        raise ValueError("unsupported datatype")

    return getattr(cutlass_bindings, ref_name)(tensor_ref, tensor_coord)


class GemmUniversalLauncher:
    def __init__(
        self,
        operation: "GemmOperationUniversal",
        seed: int = 2080,
        interleaved=False,
        verification=True,
        profiling=False,
        warmup_iterations=500,
        iterations=500,
        compiler_mode: str = "nvcc",
        **kwargs,
    ) -> None:
        # create the reduction kernel
        self.reduction_operation: ReductionOperation = ReductionOperation(
            shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment),
            C=operation.C,
            element_accumulator=operation.tile_description.math_instruction.element_accumulator,
            element_compute=operation.epilogue_functor.element_epilogue,
            epilogue_functor=operation.epilogue_functor,
            count=operation.C.alignment,
        )

        self.math_operation = operation.tile_description.math_instruction.math_operation

        #: verify the output result
        self.verification = verification
        #: profile the kernel's runtime
        self.profiling = profiling

        self.timer = GpuTimer()

        self.warmup_iterations = warmup_iterations
        self.iterations = iterations

        self.sleep_time = kwargs.get("sleep", 0)

        #
        # Compile the operator
        #
        if compiler_mode == "nvcc":
            compiler.nvcc()
        elif compiler_mode == "nvrtc":
            compiler.nvrtc()
        else:
            raise Exception(f"Unexpected compiler string {compiler_mode}")

        op_list = [operation]
        if operation.arch < 90:
            # Split K via Python is currently only supported for pre-SM90 kernels
            op_list.append(self.reduction_operation)

        compiler.add_module(op_list, bypass_cache=True)

        self.operation = operation

        self.dtype_A = GemmUniversalLauncher.numpy_type(operation.A.element)
        self.dtype_B = GemmUniversalLauncher.numpy_type(operation.B.element)
        self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element)
        self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element)

        accumulator_size = DataTypeSize[
            operation.tile_description.math_instruction.element_accumulator
        ]
        element_size = DataTypeSize[operation.A.element]

        if element_size == 1:
            self.scope_max = 1
            self.scope_min = 0
        elif element_size <= 8:
            self.scope_max = 1
            self.scope_min = -1
        elif element_size == 16:
            self.scope_max = 4
            self.scope_min = -4
        else:
            self.scope_max = 8
            self.scope_min = -8

        #: seed
        self.seed: int = seed

        #: whether the layout is interleaved
        self.interleaved = interleaved

        #: compute type
        self.compute_type = operation.epilogue_functor.element_epilogue
        self.accumulator_type = (
            operation.tile_description.math_instruction.element_accumulator
        )

    def print_problem_size(self, p, mode, batch_count):
        if mode == cutlass_bindings.gemm.Mode.Gemm:
            mode = "Gemm"
        elif mode == cutlass_bindings.gemm.Mode.Batched:
            mode = "GemmBatched"
        elif mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
            mode = "GemmSplitKParallel"
        problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % (
            p.m(),
            p.n(),
            p.k(),
            batch_count,
            mode,
        )
        print(problem_size)

    @staticmethod
    def numpy_type(type):
        if type == cutlass_bindings.float64:
            return np.float64
        elif type == cutlass_bindings.float32:
            return np.float32
        elif type == cutlass_bindings.float16:
            return np.float16
        elif type == cutlass_bindings.bfloat16:
            return bfloat16
        elif type == cutlass_bindings.int32:
            return np.int32
        elif type == cutlass_bindings.int8:
            return np.int8
        else:
            raise ValueError("unsupported type: %s" % ShortDataTypeNames[type])

    def uniform_init(self, size, dtype):
        if dtype in [np.float32, np.float16, bfloat16, np.float64]:
            return np.ceil(
                np.random.uniform(
                    low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size
                ).astype(dtype)
            )
        else:
            return np.random.uniform(
                low=self.scope_min - 1, high=self.scope_max + 1, size=size
            ).astype(dtype)

    def reorder_tensor_B(self, tensor_B, problem_size):
        reordered_tensor_B = np.empty_like(tensor_B)
        tensor_ref_B = getTensorRef(
            tensor_B, problem_size, "b", self.operation.B.layout
        )
        reordered_tensor_ref_B = getTensorRef(
            reordered_tensor_B, problem_size, "b", self.operation.B.layout
        )
        cutlass_bindings.gemm.host.reorder_column(
            tensor_ref_B, reordered_tensor_ref_B, problem_size
        )
        return reordered_tensor_B

    def host_reference(self, problem_size, batch_count, tensor_A, tensor_B, tensor_C, alpha, beta):
        tensor_D_ref = np.ones_like(tensor_C)
        alpha = self.numpy_type(self.compute_type)(alpha)
        beta = self.numpy_type(self.compute_type)(beta)
        init_acc = 0

        alpha = self.compute_type(alpha).value()
        beta = self.compute_type(beta).value()
        init_acc = self.accumulator_type(init_acc).value()

        for i in range(batch_count):
            if self.operation.switched:
                tensor_ref_A = getTensorRef(
                    tensor_A,
                    problem_size,
                    "a",
                    transpose(self.operation.B.layout),
                    batch_offset=i,
                )
                tensor_ref_B = getTensorRef(
                    tensor_B,
                    problem_size,
                    "b",
                    transpose(self.operation.A.layout),
                    batch_offset=i,
                )
                tensor_ref_C = getTensorRef(
                    tensor_C,
                    problem_size,
                    "c",
                    transpose(self.operation.C.layout),
                    batch_offset=i,
                )
                tensor_ref_D_ref = getTensorRef(
                    tensor_D_ref,
                    problem_size,
                    "d",
                    transpose(self.operation.C.layout),
                    batch_offset=i,
                )
            else:
                tensor_ref_A = getTensorRef(
                    tensor_A, problem_size, "a", self.operation.A.layout, batch_offset=i
                )
                tensor_ref_B = getTensorRef(
                    tensor_B, problem_size, "b", self.operation.B.layout, batch_offset=i
                )
                tensor_ref_C = getTensorRef(
                    tensor_C, problem_size, "c", self.operation.C.layout, batch_offset=i
                )
                tensor_ref_D_ref = getTensorRef(
                    tensor_D_ref,
                    problem_size,
                    "d",
                    self.operation.C.layout,
                    batch_offset=i,
                )

            if self.math_operation in [MathOperation.multiply_add_saturate]:
                cutlass_bindings.test.gemm.host.gemm_saturate(
                    problem_size,
                    alpha,
                    tensor_ref_A,
                    tensor_ref_B,
                    beta,
                    tensor_ref_C,
                    tensor_ref_D_ref,
                    init_acc,
                )
            else:
                cutlass_bindings.test.gemm.host.gemm(
                    problem_size,
                    alpha,
                    tensor_ref_A,
                    tensor_ref_B,
                    beta,
                    tensor_ref_C,
                    tensor_ref_D_ref,
                    init_acc,
                )

        return tensor_D_ref

    def equal(self, tensor_D, tensor_D_ref, problem_size, batch_count):
        for i in range(batch_count):
            tensor_view_D = getTensorView(
                tensor_D, problem_size, "d", self.operation.C.layout, batch_offset=i
            )
            tensor_view_D_ref = getTensorView(
                tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i
            )

            if not cutlass_bindings.test.gemm.host.equals(
                tensor_view_D, tensor_view_D_ref
            ):
                return False

        return True

    def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0):
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()

        bytes = (
            (DataTypeSize[self.operation.A.element] * m // 8) * k
            + (DataTypeSize[self.operation.B.element] * n // 8) * k
            + (DataTypeSize[self.operation.C.element] * m // 8) * n
        )

        if beta != 0:
            bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n

        bytes *= batch_count

        return bytes

    def flops(self, problem_size, batch_count=1):
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()

        flops_ = (m * n * k) * 2 * batch_count

        return flops_

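`bytes` and `flops` above implement the standard GEMM cost model: a GEMM reads A (m by k) and B (k by n), writes D (m by n), additionally reads C (m by n) only when beta is nonzero, and performs 2·m·n·k floating-point operations. A standalone sketch of the same arithmetic (pure Python; function names and the fp16 element sizes are illustrative):

```python
def gemm_bytes(m, n, k, elt_bits_a=16, elt_bits_b=16, elt_bits_c=16, beta=0.0):
    # Traffic: read A (m*k) and B (k*n), write D (m*n); read C only when beta != 0.
    total = (
        (elt_bits_a * m // 8) * k
        + (elt_bits_b * n // 8) * k
        + (elt_bits_c * m // 8) * n
    )
    if beta != 0:
        total += (elt_bits_c * m // 8) * n
    return total

def gemm_flops(m, n, k):
    # Each of the m*n outputs accumulates k multiply-adds = 2k flops.
    return 2 * m * n * k

# fp16 GEMM with m = n = k = 512 and beta = 0:
print(gemm_bytes(512, 512, 512))  # 3 * 512 * 512 * 2 bytes = 1572864
print(gemm_flops(512, 512, 512))  # 2 * 512**3 = 268435456
```

Dividing the two gives the arithmetic intensity (flops per byte), which is what makes this model useful for roofline-style sanity checks of profiler output.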
    def run_cutlass_profiler(
        self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0
    ):
        cutlass_path = os.getenv("CUTLASS_PATH")
        assert (
            cutlass_path is not None
        ), "Environment variable 'CUTLASS_PATH' is not defined."

        values = {
            "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler",
            "kernel_name": self.operation.procedural_name(),
            "verification_providers": "device",
            "provider": "cutlass",
            "m": str(problem_size.m()),
            "n": str(problem_size.n()),
            "k": str(problem_size.k()),
            "split_k_slices": str(batch_count),
            "alpha": str(alpha),
            "beta": str(beta),
            "warmup": str(self.warmup_iterations),
            "profile": str(self.iterations),
        }

        cmd_template = (
            "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}"
            " --providers=${provider} --m=${m} --n=${n} --k=${k}"
        )

        cmd = SubstituteTemplate(cmd_template, values)
        result = subprocess.getoutput(cmd)

        m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
        runtime = float(m.group("runtime"))

        m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
        bytes = int(m.group("bytes"))

        m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
        flops = int(m.group("flops"))

        # check that the problem size matches; batch_count is passed explicitly
        # so that alpha and beta bind to the correct parameters
        assert bytes == self.bytes(problem_size, 1, alpha, beta)
        assert flops == self.flops(problem_size)

        return runtime
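The profiler output above is parsed with named-group regexes. A self-contained sketch of the same parsing against a made-up sample of the output format (the sample text is illustrative, not captured from a real `cutlass_profiler` run):

```python
import re

# Hypothetical excerpt of cutlass_profiler output.
sample = """
  Runtime: 0.245 ms
  Bytes: 1572864
  FLOPs: 268435456
"""

# Named groups keep the extraction readable; the dot in the runtime
# pattern is escaped here, unlike the looser pattern in the code above.
runtime = float(re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", sample).group("runtime"))
nbytes = int(re.search(r"Bytes:\s+(?P<bytes>\d+)", sample).group("bytes"))
flops = int(re.search(r"FLOPs:\s+(?P<flops>\d+)", sample).group("flops"))

print(runtime, nbytes, flops)  # 0.245 1572864 268435456
```

A production parser would also handle `re.search` returning `None` when a field is missing, rather than dereferencing the match unconditionally.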
    def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
        assert get_allocated_size() == 0, (
            "%d bytes of pool memory were not released in the previous run"
            % get_allocated_size()
        )

        np.random.seed(self.seed)

        # Assign an actual batch count in cases where we are not running in batched mode.
        # This is to differentiate between the number of split K slices and the batch count,
        # which are overloaded within the single `batch_count` variable.
        true_batch_count = (
            batch_count if mode == cutlass_bindings.gemm.Mode.Batched else 1
        )

        tensor_A = self.uniform_init(
            size=(problem_size.m() * problem_size.k() * true_batch_count,),
            dtype=self.dtype_A,
        )
        tensor_B = self.uniform_init(
            size=(problem_size.n() * problem_size.k() * true_batch_count,),
            dtype=self.dtype_B,
        )
        tensor_C = self.uniform_init(
            size=(problem_size.m() * problem_size.n() * true_batch_count,),
            dtype=self.dtype_C,
        )
        tensor_D = np.zeros(
            shape=(problem_size.m() * problem_size.n() * true_batch_count,),
            dtype=self.dtype_D,
        )

        #
        # Launch kernel
        #

        arguments = GemmArguments(
            operation=self.operation,
            problem_size=problem_size,
            A=tensor_A,
            B=tensor_B,
            C=tensor_C,
            D=tensor_D,
            output_op=self.operation.epilogue_type(alpha, beta),
            gemm_mode=mode,
            split_k_slices=split_k_slices,
            batch=batch_count,
        )

        if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
            reduction_arguments = ReductionArguments(
                self.reduction_operation,
                problem_size=[problem_size.m(), problem_size.n()],
                partitions=split_k_slices,
                workspace=arguments.ptr_D,
                destination=tensor_D,
                source=tensor_C,
                output_op=self.reduction_operation.epilogue_type(alpha, beta),
            )

        self.operation.run(arguments)

        if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
            self.reduction_operation.run(reduction_arguments)

        passed = True

        if self.verification:
            if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
                reduction_arguments.sync()
            else:
                arguments.sync()
            tensor_D_ref = self.host_reference(
                problem_size,
                true_batch_count,
                tensor_A,
                tensor_B,
                tensor_C,
                alpha,
                beta,
            )
            passed = self.equal(tensor_D, tensor_D_ref, problem_size, true_batch_count)

            try:
                assert passed
            except AssertionError:
                self.print_problem_size(problem_size, mode, batch_count)

        if self.profiling:
            sleep(self.sleep_time)
            for _ in range(self.warmup_iterations):
                self.operation.run(arguments)
                if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
                    self.reduction_operation.run(reduction_arguments)

            self.timer.start()
            for _ in range(self.iterations):
                self.operation.run(arguments)
                if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
                    self.reduction_operation.run(reduction_arguments)
            self.timer.stop_and_wait()

            runtime = self.timer.duration(self.iterations)

        # free memory and clear buffers
        del arguments
        if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel:
            del reduction_arguments

        assert get_allocated_size() == 0, (
            "%d bytes of pool memory were not released after the current run"
            % get_allocated_size()
        )

        if self.profiling:
            return runtime
        return passed

def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
    passed = True

    minimum_operand_element_size = min(
        DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]
    )
    opcode_class = operation.tile_description.math_instruction.opcode_class

    if opcode_class == cutlass_bindings.OpClass.Simt:
        alignment = 1
    else:
        alignment = 128 // minimum_operand_element_size

    # int8_t gemm alignment constraints
    if (
        opcode_class == cutlass_bindings.OpClass.Simt
        and operation.A.element == cutlass_bindings.int8
        and operation.A.layout == cutlass_bindings.ColumnMajor
    ):
        alignment_m = 4
    else:
        alignment_m = alignment

    if (
        opcode_class == cutlass_bindings.OpClass.Simt
        and operation.B.element == cutlass_bindings.int8
        and operation.A.layout == cutlass_bindings.RowMajor
    ):
        alignment_n = 4
    else:
        alignment_n = alignment

    if (
        opcode_class == cutlass_bindings.OpClass.Simt
        and operation.A.element == cutlass_bindings.int8
        and operation.B.element == cutlass_bindings.int8
        and (
            operation.A.layout == cutlass_bindings.RowMajor
            or operation.B.layout == cutlass_bindings.ColumnMajor
        )
    ):
        alignment_k = 4
    else:
        alignment_k = alignment

threadblock_k = operation.tile_description.threadblock_shape[2]
|
||||
|
||||
if testcase == "interleaved":
|
||||
if operation.A.layout in [
|
||||
cutlass_bindings.ColumnMajorInterleaved32,
|
||||
cutlass_bindings.RowMajorInterleaved32,
|
||||
]:
|
||||
interleavedk = 32
|
||||
else:
|
||||
raise ValueError("Unknown layout")
|
||||
|
||||
# Split K mode via Python is currently only supported pre-SM90, and when stream K is not used.
|
||||
# Stream K enables split-k functionality with mode `Gemm` and a non-unit batch count.
|
||||
supports_split_k = operation.arch < 90 and not isinstance(
|
||||
operation.swizzling_functor, cutlass_bindings.ThreadblockSwizzleStreamK
|
||||
)
|
||||
if testcase == "interleaved":
|
||||
modes = [
|
||||
cutlass_bindings.gemm.Mode.Gemm,
|
||||
]
|
||||
problem_size_m = [interleavedk, 512 + interleavedk]
|
||||
problem_size_n = [interleavedk, 512 + interleavedk]
|
||||
problem_size_k = [
|
||||
interleavedk,
|
||||
threadblock_k * operation.tile_description.stages + interleavedk,
|
||||
]
|
||||
problem_alpha = [1.0]
|
||||
problem_beta = [0.0]
|
||||
batch_counts = [
|
||||
1,
|
||||
]
|
||||
elif testcase == "multistage":
|
||||
modes = [
|
||||
cutlass_bindings.gemm.Mode.Gemm,
|
||||
]
|
||||
problem_size_m = [16, 528]
|
||||
problem_size_n = [16, 528]
|
||||
problem_size_k = [
|
||||
threadblock_k,
|
||||
threadblock_k * operation.tile_description.stages
|
||||
+ operation.tile_description.math_instruction.instruction_shape[2],
|
||||
]
|
||||
problem_alpha = [1.0]
|
||||
problem_beta = [0.0]
|
||||
batch_counts = [
|
||||
1,
|
||||
]
|
||||
else: # universal
|
||||
modes = [cutlass_bindings.gemm.Mode.Gemm]
|
||||
batch_counts = [1, 2, 3, 5, 7]
|
||||
if supports_split_k:
|
||||
modes.append(cutlass_bindings.gemm.Mode.GemmSplitKParallel)
|
||||
|
||||
problem_size_m = [alignment_m, 512 - 3 * alignment_m]
|
||||
problem_size_n = [alignment_n, 512 - 2 * alignment_n]
|
||||
if operation.tile_description.stages is None:
|
||||
stages_for_k_calc = 7
|
||||
else:
|
||||
stages_for_k_calc = operation.tile_description.stages
|
||||
problem_size_k = [
|
||||
alignment_k,
|
||||
threadblock_k * stages_for_k_calc - alignment_k,
|
||||
threadblock_k * stages_for_k_calc * 3 - alignment_k,
|
||||
]
|
||||
problem_alpha = [1.0]
|
||||
problem_beta = [2.0]
|
||||
|
||||
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"), compiler_mode=compilation_mode)
|
||||
|
||||
for mode in modes:
|
||||
for m in problem_size_m:
|
||||
for n in problem_size_n:
|
||||
for k in problem_size_k:
|
||||
for batch_count in batch_counts:
|
||||
for alpha in problem_alpha:
|
||||
for beta in problem_beta:
|
||||
# skip very small K problems
|
||||
if testcase == "universal":
|
||||
if k // batch_count < 2 * threadblock_k:
|
||||
continue
|
||||
|
||||
problem_size = cutlass_bindings.gemm.GemmCoord(m, n, k)
|
||||
|
||||
if supports_split_k:
|
||||
split_k_slices = batch_count
|
||||
else:
|
||||
split_k_slices = 1
|
||||
|
||||
overridden_mode = mode
|
||||
if (
|
||||
mode == cutlass_bindings.gemm.Mode.Gemm
|
||||
and batch_count > 1
|
||||
):
|
||||
overridden_mode = cutlass_bindings.gemm.Mode.Batched
|
||||
|
||||
passed = testbed.run(
|
||||
overridden_mode,
|
||||
problem_size,
|
||||
batch_count,
|
||||
split_k_slices,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
|
||||
(err,) = cudart.cudaDeviceSynchronize()
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("CUDA Error %s" % str(err))
|
||||
|
||||
if not passed:
|
||||
return False
|
||||
|
||||
return passed
|
||||
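As a side note on the alignment computation above: for tensor-op kernels, the maximum alignment is simply the 128-bit vectorized access width divided by the element size in bits. A standalone sketch (the `data_type_size` entries here are an illustrative subset, not the library's `DataTypeSize` table):

```python
# Illustrative subset of a DataTypeSize-style table (bits per element)
data_type_size = {"f16": 16, "s8": 8, "f32": 32}

def max_alignment(bits):
    # 128-bit vectorized loads/stores bound the alignment, mirroring
    # `alignment = 128 // minimum_operand_element_size` above
    return 128 // bits

print(max_alignment(data_type_size["f16"]))  # 8 elements per 128-bit access
```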
@ -1,305 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import cutlass
import cutlass_bindings

from cutlass import EpilogueScheduleSuffixes, KernelScheduleSuffixes
from cutlass.utils.datatypes import binding_opclass, binding_type
from cutlass.backend import library
from cutlass.backend.test.gemm_testbed import test_all_gemm
from cutlass.backend.utils.software import SubstituteTemplate


class Layout:
    """
    Utility class mapping transpose/non-transpose terminology to row- and column-major terminology
    """

    T = cutlass_bindings.RowMajor
    N = cutlass_bindings.ColumnMajor


class LayoutCombination:
    """
    Utility class defining all combinations of row- and column-major layouts for the operands to a GEMM
    """

    NNN = (Layout.N, Layout.N, Layout.N)
    NNT = (Layout.N, Layout.N, Layout.T)
    NTN = (Layout.N, Layout.T, Layout.N)
    NTT = (Layout.N, Layout.T, Layout.T)
    TNN = (Layout.T, Layout.N, Layout.N)
    TNT = (Layout.T, Layout.N, Layout.T)
    TTN = (Layout.T, Layout.T, Layout.N)
    TTT = (Layout.T, Layout.T, Layout.T)

def get_name(
    layouts,
    alignments,
    element_output,
    element_accumulator,
    element_epilogue,
    cluster_shape,
    threadblock_shape,
    stages,
    element_a,
    element_b,
    arch,
    opclass,
    kernel_schedule=None,
    epilogue_schedule=None,
    suffix="",
):
    """
    Generates a procedural name for a test case.

    :param layouts: indexable container of layouts of the A, B, and C operands
    :param alignments: indexable container of alignments of the A, B, and C operands
    :param element_output: data type of the output element
    :param element_accumulator: data type used in accumulation
    :param element_epilogue: data type used in computing the epilogue
    :param cluster_shape: indexable container of dimensions of the threadblock cluster to be launched
    :param threadblock_shape: indexable container of dimensions of threadblock tiles
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param element_a: data type of operand A
    :param element_b: data type of operand B
    :param arch: compute capability of the kernel being generated
    :type arch: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass_bindings.OpClass
    :param kernel_schedule: kernel schedule type
    :type kernel_schedule: cutlass.KernelScheduleType
    :param epilogue_schedule: epilogue schedule type
    :type epilogue_schedule: cutlass.EpilogueScheduleType
    :param suffix: additional string to append to the name
    :type suffix: str

    :return: str
    """
    name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"
    return SubstituteTemplate(
        name_format,
        {
            "arch": str(arch),
            "eA": library.DataTypeNames[binding_type(element_a)],
            "eB": library.DataTypeNames[binding_type(element_b)],
            "eC": library.DataTypeNames[binding_type(element_output)],
            "lA": library.ShortLayoutTypeNames[layouts[0]],
            "lB": library.ShortLayoutTypeNames[layouts[1]],
            "lC": library.ShortLayoutTypeNames[layouts[2]],
            "opclass": library.OpcodeClassNames[binding_opclass(opclass)],
            "acc": library.DataTypeNames[binding_type(element_accumulator)],
            "cM": str(cluster_shape[0]),
            "cN": str(cluster_shape[1]),
            "cK": str(cluster_shape[2]),
            "tbM": str(threadblock_shape[0]),
            "tbN": str(threadblock_shape[1]),
            "tbK": str(threadblock_shape[2]),
            "stages": str(stages) if stages is not None else "auto",
            "aA": str(alignments[0]),
            "aB": str(alignments[1]),
            "aC": str(alignments[2]),
            "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
            "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule],
            "suffix": "" if suffix is None else suffix,
        },
    )

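The `name_format` string above is expanded by `SubstituteTemplate`; a minimal stand-in showing the `${key}` replacement it performs (a hypothetical re-implementation for illustration, not the library's own code):

```python
def substitute_template(template, values):
    # Replace each ${key} placeholder with its string value
    # (sketch of SubstituteTemplate-style expansion)
    text = template
    for key, value in values.items():
        text = text.replace("${%s}" % key, value)
    return text

name = substitute_template(
    "test_SM${arch}_Device_Gemm_${eA}${lA}",
    {"arch": "80", "eA": "f16", "lA": "t"},
)
print(name)  # test_SM80_Device_Gemm_f16t
```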
def get_name_conv2d(
    arch,
    conv_kind,
    element,
    element_accumulator,
    element_output,
    opclass,
    threadblock_shape,
    warp_count,
    instruction_shape,
    stages,
    iterator_algorithm,
    swizzle,
    split_k_mode,
    split_k_slices,
    activation,
):
    """
    Generates a procedural name for a conv2d test case.

    :param arch: compute capability of the kernel being generated
    :type arch: int
    :param conv_kind: the convolution type (i.e., fprop, dgrad, wgrad)
    :type conv_kind: str
    :param element: data type of the A and B operands
    :param element_accumulator: data type used in accumulation
    :param element_output: data type of the output element
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :param threadblock_shape: indexable container of dimensions of threadblock tiles
    :param warp_count: warps launched per threadblock dimension
    :param instruction_shape: dimensions of the instruction tile
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param iterator_algorithm: the iterator algorithm applied
    :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
    :param swizzle: threadblock swizzling functor
    :param split_k_mode: mode of split-K execution
    :param split_k_slices: number of split-K slices
    :param activation: activation function applied in the epilogue

    :return: str
    """
    if iterator_algorithm is None:
        iterator_algorithm = "AUTO"
    if swizzle is None:
        swizzle = 1
    name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}"

    return SubstituteTemplate(
        name_format,
        {
            "arch": str(arch),
            "conv_kind": conv_kind,
            "iter_alg": iterator_algorithm,
            "eA": library.DataTypeNames[binding_type(element)],
            "eB": library.DataTypeNames[binding_type(element)],
            "eC": library.DataTypeNames[binding_type(element_output)],
            "opclass": opclass,
            "acc": library.DataTypeNames[binding_type(element_accumulator)],
            "tbM": str(threadblock_shape[0]),
            "tbN": str(threadblock_shape[1]),
            "tbK": str(threadblock_shape[2]),
            "wM": str(threadblock_shape[0] // warp_count[0]),
            "wN": str(threadblock_shape[1] // warp_count[1]),
            "wK": str(threadblock_shape[2] // warp_count[2]),
            "IM": str(instruction_shape[0]),
            "IN": str(instruction_shape[1]),
            "IK": str(instruction_shape[2]),
            "stages": str(stages),
            "swizzle": str(swizzle),
            "split_k_mode": split_k_mode,
            "split_k_slices": str(split_k_slices),
            "activation": activation,
        },
    )

def add_test_gemm(
    cls=None,
    cc=None,
    element=None,
    layouts=None,
    alignments=None,
    element_output=None,
    element_accumulator=None,
    cluster_shape=None,
    threadblock_shape=None,
    warp_count=None,
    stages=None,
    opclass=None,
    swizzle=None,
    kernel_schedule=None,
    epilogue_schedule=None,
    compilation_modes=['nvcc', 'nvrtc'],
):
    """
    Creates test-running functions with the given specification and sets them as methods of ``cls``.

    :param cls: class to which the generated methods will be added
    :type cls: type
    :param cc: compute capability to compile for
    :type cc: int
    :param element: data type of the A and B operands
    :type element: cutlass.DataType
    :param layouts: layouts of the A, B, and C operands
    :type layouts: list or tuple
    :param alignments: alignments of the A, B, and C operands
    :type alignments: list or tuple
    :param element_output: data type of the output element
    :type element_output: cutlass.DataType
    :param element_accumulator: data type used in accumulation
    :type element_accumulator: cutlass.DataType
    :param cluster_shape: dimensions of clusters
    :type cluster_shape: list or tuple
    :param threadblock_shape: dimensions of threadblock tiles
    :type threadblock_shape: list or tuple
    :param warp_count: warps to be launched per threadblock dimension
    :type warp_count: list or tuple
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass.OpClass
    :param swizzle: threadblock swizzling functor
    :param kernel_schedule: kernel schedule to use
    :type kernel_schedule: cutlass.KernelScheduleType
    :param epilogue_schedule: epilogue schedule to use
    :type epilogue_schedule: cutlass.EpilogueScheduleType
    :param compilation_modes: list of compilers to use in testing the kernel (options: 'nvrtc', 'nvcc')
    :type compilation_modes: list
    """

    for compilation_mode in compilation_modes:
        def run(self):
            """
            Dynamically generated function that constructs a GEMM operation and verifies it against
            multiple test cases.
            """
            element_A = element
            element_B = element
            layout_A, layout_B, layout_C = layouts
            alignment_A, alignment_B, alignment_C = alignments

            plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
                                   element_C=element_output, element_D=element_output,
                                   layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
                                   element_accumulator=element_accumulator,
                                   kernel_cc=cc)

            plan.opclass = opclass
            if swizzle is not None:
                plan.swizzling_functor = swizzle
            td = plan.tile_descriptions()[0]
            td.threadblock_shape = threadblock_shape
            td.stages = stages
            if warp_count is not None:
                td.warp_count = warp_count
            td.cluster_shape = cluster_shape
            op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
            self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))

        element_epilogue = element_accumulator
        name = get_name(
            layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
            element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
            stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
            kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')

        setattr(cls, name, run)
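The registration above hinges on `setattr(cls, name, run)` attaching each dynamically generated test as a method. A minimal standalone sketch of the same pattern (hypothetical names; this sketch pins per-iteration values through a default argument):

```python
class Suite:
    pass

def add_case(cls, name, value):
    # Bind `value` through a default argument so each generated
    # method keeps its own copy (the usual closure-in-loop caveat)
    def run(self, value=value):
        return value * 2

    setattr(cls, name, run)

for i, case_name in enumerate(["test_a", "test_b"]):
    add_case(Suite, case_name, i)

s = Suite()
print(s.test_a(), s.test_b())  # 0 2
```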
@ -32,7 +32,6 @@
from cutlass.backend.utils.datatypes import *
from cutlass.backend.utils.device import check_cuda_errors, device_cc
from cutlass.backend.utils.reference_model import ReferenceModule
from cutlass.backend.utils.software import (
    CheckPackages,
    SubstituteTemplate,
@ -34,8 +34,9 @@
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""

import cutlass_bindings
from cuda import cuda

from cutlass import DataType
from cutlass.backend.utils.software import CheckPackages

numpy_available = CheckPackages().check_numpy()
@ -43,16 +44,16 @@ if numpy_available:
    import numpy as np

    numpy_to_cutlass_dict = {
        np.float16: cutlass_bindings.float16,
        np.float32: cutlass_bindings.float32,
        np.float64: cutlass_bindings.float64,
        np.int8: cutlass_bindings.int8,
        np.int32: cutlass_bindings.int32,
        np.dtype('float16'): cutlass_bindings.float16,
        np.dtype('float32'): cutlass_bindings.float32,
        np.dtype('float64'): cutlass_bindings.float64,
        np.dtype('int8'): cutlass_bindings.int8,
        np.dtype('int32'): cutlass_bindings.int32,
        np.float16: DataType.f16,
        np.float32: DataType.f32,
        np.float64: DataType.f64,
        np.int8: DataType.s8,
        np.int32: DataType.s32,
        np.dtype('float16'): DataType.f16,
        np.dtype('float32'): DataType.f32,
        np.dtype('float64'): DataType.f64,
        np.dtype('int8'): DataType.s8,
        np.dtype('int32'): DataType.s32,
    }
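Both the NumPy scalar types and their `np.dtype` instances appear as keys above because lookups can arrive in either form (e.g., `arr.dtype` is a `dtype` instance, while a user may pass `np.float16` directly). A quick standalone illustration (the mapped values here are stand-in strings, not the library's enums):

```python
import numpy as np

# Keys in both forms, mirroring the table above (values are illustrative)
table = {np.float16: "f16", np.dtype("float16"): "f16"}

arr = np.zeros(2, dtype=np.float16)
lookup_by_dtype = table[arr.dtype]    # dtype-instance key
lookup_by_scalar = table[np.float16]  # scalar-type key
```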
@ -67,9 +68,9 @@ if cupy_available:
    import cupy as cp

    cupy_to_cutlass_dict = {
        cp.float16: cutlass_bindings.float16,
        cp.float32: cutlass_bindings.float32,
        cp.float64: cutlass_bindings.float64,
        cp.float16: DataType.f16,
        cp.float32: DataType.f32,
        cp.float64: DataType.f64,
    }
@ -84,12 +85,12 @@ if torch_available:
    import torch

    torch_to_cutlass_dict = {
        torch.half: cutlass_bindings.float16,
        torch.float16: cutlass_bindings.float16,
        torch.float: cutlass_bindings.float32,
        torch.float32: cutlass_bindings.float32,
        torch.double: cutlass_bindings.float64,
        torch.float64: cutlass_bindings.float64,
        torch.half: DataType.f16,
        torch.float16: DataType.f16,
        torch.float: DataType.f32,
        torch.float32: DataType.f32,
        torch.double: DataType.f64,
        torch.float64: DataType.f64,
    }
@ -102,7 +103,7 @@ try:
    import bfloat16

    bfloat16_available = True
    numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = cutlass_bindings.bfloat16
    numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = DataType.bf16
except ImportError:
    bfloat16_available = False
@ -110,7 +111,7 @@ except ImportError:
def bfloat16_to_cutlass(inp):
    if bfloat16_available:
        if inp == bfloat16.bfloat16:
            return cutlass_bindings.bfloat16
            return DataType.bf16


def to_cutlass(inp):
@ -127,3 +128,29 @@ def to_cutlass(inp):
    raise Exception(
        "No available conversion from type {} to a CUTLASS type.".format(inp)
    )


def to_device_ptr(tensor) -> cuda.CUdeviceptr:
    """
    Converts a tensor to a CUdeviceptr.

    :param tensor: tensor to convert
    :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int

    :return: device pointer
    :rtype: cuda.CUdeviceptr
    """
    if isinstance(tensor, np.ndarray):
        ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
    elif torch_available and isinstance(tensor, torch.Tensor):
        ptr = cuda.CUdeviceptr(tensor.data_ptr())
    elif cupy_available and isinstance(tensor, cp.ndarray):
        ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
    elif isinstance(tensor, cuda.CUdeviceptr):
        ptr = tensor
    elif isinstance(tensor, int):
        ptr = cuda.CUdeviceptr(tensor)
    else:
        raise NotImplementedError(tensor)

    return ptr
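The NumPy branch of `to_device_ptr` reads the raw buffer address from `__array_interface__["data"][0]` (a host address for a plain `np.ndarray`). A standalone check of that mechanism using `ctypes`:

```python
import ctypes

import numpy as np

a = np.arange(4, dtype=np.float32)
addr = a.__array_interface__["data"][0]  # raw address of the underlying buffer

# Reading back through the raw address recovers the first element,
# confirming the address really points at the array's storage
first = ctypes.cast(addr, ctypes.POINTER(ctypes.c_float))[0]
```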
@ -1,317 +0,0 @@
from typing import Union

from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np

from cutlass.backend.library import TensorDescription
from cutlass.backend.utils.software import CheckPackages

torch_available = CheckPackages().check_torch()
if torch_available:
    import torch


class ReferenceModule:
    def __init__(
        self, A: TensorDescription, B: TensorDescription, C: TensorDescription
    ) -> None:
        self.layout_A = A.layout
        self.layout_B = B.layout
        self.layout_C = C.layout

    def run(
        self,
        A: np.ndarray,
        B: np.ndarray,
        C: np.ndarray,
        problem_size: cutlass_bindings.gemm.GemmCoord,
        alpha: float = 1.0,
        beta: float = 0.0,
        bias=False,
        batch=1,
    ):
        """
        Computes the reference result on the CPU.

        Args:
            A: dense operand with shape (M, K) in row-major and (K, M) in column-major
            B: dense operand with shape (K, N) in row-major and (N, K) in column-major
            C: dense operand with shape (M, N) in row-major and (N, M) in column-major
        """
        M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
        if isinstance(A, np.ndarray):
            if self.layout_A == cutlass_bindings.RowMajor:
                A_row = np.reshape(A, newshape=(batch, M, K))
            else:
                A_col = np.reshape(A, newshape=(batch, K, M))
                A_row = np.transpose(A_col, axes=(0, 2, 1))

            if self.layout_B == cutlass_bindings.RowMajor:
                B_row = np.reshape(B, newshape=(batch, K, N))
            else:
                B_col = np.reshape(B, newshape=(batch, N, K))
                B_row = np.transpose(B_col, axes=(0, 2, 1))

            if self.layout_C == cutlass_bindings.RowMajor:
                if bias:
                    C_row = np.reshape(C, newshape=(batch, 1, N))
                else:
                    C_row = np.reshape(C, newshape=(batch, M, N))
            else:
                if bias:
                    C_row = np.reshape(C, newshape=(batch, M, 1))
                else:
                    C_col = np.reshape(C, newshape=(batch, N, M))
                    C_row = np.transpose(C_col, axes=(0, 2, 1))

            if A_row.dtype == bfloat16:
                # numpy's einsum doesn't support bfloat16
                out_row = (
                    np.einsum(
                        "bik,bkj->bij",
                        A_row.astype(np.float32),
                        B_row.astype(np.float32),
                    )
                    * alpha
                    + C_row * beta
                )
                out_row = out_row.astype(C_row.dtype)
            else:
                out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta

            if self.layout_C == cutlass_bindings.ColumnMajor:
                out = np.transpose(out_row, axes=(0, 2, 1))
            else:
                out = out_row

            return out.ravel()

        elif isinstance(A, torch.Tensor):
            if self.layout_A == cutlass_bindings.RowMajor:
                A_row = A.view((M, K))
            else:
                A_col = A.view((K, M))
                A_row = torch.permute(A_col, (1, 0))

            if self.layout_B == cutlass_bindings.RowMajor:
                B_row = B.view((K, N))
            else:
                B_col = B.view((N, K))
                B_row = torch.permute(B_col, (1, 0))

            if self.layout_C == cutlass_bindings.RowMajor:
                C_row = C.view((M, N))
            else:
                C_col = C.view((N, M))
                C_row = torch.permute(C_col, (1, 0))

            out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta

            if self.layout_C == cutlass_bindings.ColumnMajor:
                out = torch.permute(out_row, (1, 0))
            else:
                out = out_row

            return torch.flatten(out)

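The row-major NumPy path above reduces to a batched einsum, `D = alpha * einsum("bik,bkj->bij", A, B) + beta * C`. A tiny self-contained sanity check of that formula against per-batch matmul:

```python
import numpy as np

batch, M, N, K = 2, 3, 4, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((batch, M, K)).astype(np.float32)
B = rng.standard_normal((batch, K, N)).astype(np.float32)
C = rng.standard_normal((batch, M, N)).astype(np.float32)
alpha, beta = 1.5, 2.0

# Batched GEMM epilogue, as in the reference path above
D = np.einsum("bik,bkj->bij", A, B) * alpha + C * beta

# Per-batch matmul gives the same result
ref = np.stack([A[b] @ B[b] for b in range(batch)]) * alpha + C * beta
```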
#####################################################################################################
|
||||
# Conv2d
|
||||
#####################################################################################################
|
||||
|
||||
if torch_available:
|
||||
import torch
|
||||
|
||||
class Conv2dReferenceModule:
|
||||
def __init__(
|
||||
self,
|
||||
A: TensorDescription,
|
||||
B: TensorDescription,
|
||||
C: TensorDescription,
|
||||
kind: cutlass_bindings.conv.Operator.fprop,
|
||||
) -> None:
|
||||
self.layout_A = A.layout
|
||||
self.layout_B = B.layout
|
||||
self.layout_C = C.layout
|
||||
self.kind = kind
|
||||
|
||||
def run(
|
||||
self,
|
||||
A: Union[np.ndarray, torch.Tensor],
|
||||
B: Union[np.ndarray, torch.Tensor],
|
||||
C: Union[np.ndarray, torch.Tensor],
|
||||
problem_size,
|
||||
alpha=1.0,
|
||||
beta=0.0,
|
||||
bias=False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute the reference result on CPU
|
||||
"""
|
||||
n = problem_size.N
|
||||
h = problem_size.H
|
||||
w = problem_size.W
|
||||
c = problem_size.C
|
||||
|
||||
k = problem_size.K
|
||||
r = problem_size.R
|
||||
s = problem_size.S
|
||||
|
||||
p = problem_size.P
|
||||
q = problem_size.Q
|
||||
|
||||
stride_h = problem_size.stride_h
|
||||
stride_w = problem_size.stride_w
|
||||
|
||||
pad_h = problem_size.pad_h
|
||||
pad_w = problem_size.pad_w
|
||||
|
||||
dilation_h = problem_size.dilation_h
|
||||
dilation_w = problem_size.dilation_w
|
||||
|
||||
groups = problem_size.groups
|
||||
|
||||
if isinstance(A, np.ndarray):
|
||||
# the pytorch activation layout is NCHW
|
||||
# weight layout is Cout Cin Kh Kw (also NCHW)
|
||||
if self.layout_A == cutlass_bindings.TensorNHWC:
|
||||
A_nhwc = np.reshape(A, newshape=(n, h, w, c))
|
||||
A_torch_nhwc = torch.from_numpy(A_nhwc).to("cuda")
|
||||
A_torch_nchw = torch.permute(A_torch_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_B == cutlass_bindings.TensorNHWC:
|
||||
B_nhwc = np.reshape(B, newshape=(k, r, s, c))
|
||||
B_torch_nhwc = torch.from_numpy(B_nhwc).to("cuda")
|
||||
B_torch_nchw = torch.permute(B_torch_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_C == cutlass_bindings.TensorNHWC:
|
||||
C_nhwc = np.reshape(C, newshape=(n, p, q, k))
|
||||
C_torch_nhwc = torch.from_numpy(C_nhwc).to("cuda")
|
||||
C_torch_nchw = torch.permute(C_torch_nhwc, (0, 3, 1, 2))
|
||||
|
||||
elif isinstance(A, torch.Tensor):
|
||||
if self.kind == cutlass_bindings.conv.Operator.wgrad:
|
||||
if self.layout_A == cutlass_bindings.TensorNHWC:
|
||||
A_nhwc = A.view((n, p, q, k))
|
||||
A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_B == cutlass_bindings.TensorNHWC:
|
||||
B_nhwc = B.view((n, h, w, c))
|
||||
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_C == cutlass_bindings.TensorNHWC:
|
||||
if bias:
|
||||
C_nhwc = C.view((1, 1, 1, c))
|
||||
else:
|
||||
C_nhwc = C.view((k, r, s, c))
|
||||
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
|
||||
elif self.kind == cutlass_bindings.conv.Operator.dgrad:
|
||||
if self.layout_A == cutlass_bindings.TensorNHWC:
|
||||
A_nhwc = A.view((n, p, q, k))
|
||||
A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))
|
||||
|
||||
    if self.layout_B == cutlass_bindings.TensorNHWC:
        B_nhwc = B.view((k, r, s, c))
        B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

    if self.layout_C == cutlass_bindings.TensorNHWC:
        if bias:
            C_nhwc = C.view((1, 1, 1, c))
        else:
            C_nhwc = C.view((n, h, w, c))
        C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
else:
    if self.layout_A == cutlass_bindings.TensorNHWC:
        A_nhwc = A.view((n, h, w, c))
        A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2))

    if self.layout_B == cutlass_bindings.TensorNHWC:
        B_nhwc = B.view((k, r, s, c))
        B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

    if self.layout_C == cutlass_bindings.TensorNHWC:
        if bias:
            C_nhwc = C.view((1, 1, 1, k))
        else:
            C_nhwc = C.view((n, p, q, k))
        C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))

if self.kind == cutlass_bindings.conv.Operator.fprop:
    D_torch_nchw = (
        alpha
        * torch.nn.functional.conv2d(
            A_torch_nchw,
            B_torch_nchw,
            stride=(stride_h, stride_w),
            padding=(pad_h, pad_w),
            dilation=(dilation_h, dilation_w),
            groups=groups,
        )
        + beta * C_torch_nchw
    )
elif self.kind == cutlass_bindings.conv.Operator.dgrad:
    D_torch_nchw = (
        alpha
        * torch.nn.grad.conv2d_input(
            (n, c, h, w),
            B_torch_nchw,
            A_torch_nchw,
            padding=(pad_h, pad_w),
            stride=(stride_h, stride_w),
        ).to(torch.float32)
        + beta * C_torch_nchw
    )
elif self.kind == cutlass_bindings.conv.Operator.wgrad:
    D_torch_nchw = (
        alpha
        * torch.nn.grad.conv2d_weight(
            B_torch_nchw,
            (k, c, r, s),
            A_torch_nchw,
            padding=(pad_h, pad_w),
            stride=(stride_h, stride_w),
        ).to(torch.float32)
        + beta * C_torch_nchw
    )

if self.layout_C == cutlass_bindings.TensorNHWC:
    if isinstance(A, np.ndarray):
        D_torch_out = (
            torch.permute(D_torch_nchw, (0, 2, 3, 1)).detach().cpu().numpy()
        )
    elif isinstance(A, torch.Tensor):
        D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1))

return D_torch_out.flatten()
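The reference path above converts between NHWC and NCHW with `torch.permute(..., (0, 3, 1, 2))` and undoes it with `(0, 2, 3, 1)`. A torch-free sketch of that layout permute on nested lists (hypothetical helper, for illustration only):

```python
def nhwc_to_nchw(t):
    """Apply the (0, 3, 1, 2) permute used above: NHWC nested lists -> NCHW."""
    n, h, w, c = len(t), len(t[0]), len(t[0][0]), len(t[0][0][0])
    return [[[[t[ni][hi][wi][ci] for wi in range(w)]
              for hi in range(h)]
             for ci in range(c)]
            for ni in range(n)]
```

Element `(ni, hi, wi, ci)` of the input lands at `(ni, ci, hi, wi)` of the output, matching what `torch.permute` does on a contiguous tensor.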
@ -1,3 +1,6 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -85,10 +88,7 @@ def SubstituteTemplate(template, values):
    return text


# this._device_sm_count = None
def device_sm_count():
    # Query the number of SMs, if needed
    # if this._device_sm_count is None:
    from cuda import cuda

    _device = 0
@ -1,75 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief In-memory compiled artifact cache
*/

#include <pybind11/pybind11.h>
#include <string>
#include <unordered_map>

namespace py = pybind11;

namespace cutlass {

struct CompileCache {
public:
  CompileCache() = default;
  ~CompileCache() = default;

  using Cache = std::unordered_map<std::string, py::object>;

  /// Return the kernel if it has already been compiled; py::none() otherwise
  py::object at(const std::string &kernel) {
    auto item = cache_.find(kernel);

    if (item != cache_.end()) {
      return item->second;
    }
    return py::none();
  }

  /// Insert a newly compiled kernel for a new configuration
  void insert(const std::string &kernel, const py::object &compiled_kernel) {
    cache_.emplace(kernel, compiled_kernel);
  }

  /// Number of cached kernels
  int64_t size() const { return cache_.size(); }

  /// Clear the cache
  void clear() { cache_.clear(); }

private:
  Cache cache_;
};

} // namespace cutlass
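The struct above is a thin wrapper over `std::unordered_map<std::string, py::object>` keyed by kernel source. Its observable semantics can be sketched in plain Python (hypothetical class, not part of either package):

```python
class CompileCacheSketch:
    """Mirrors the pybind11 CompileCache semantics: at() returns None on a miss."""

    def __init__(self):
        self._cache = {}

    def at(self, kernel: str):
        # py::none() on a miss, the cached artifact on a hit
        return self._cache.get(kernel)

    def insert(self, kernel: str, compiled_kernel) -> None:
        # emplace: keeps the existing artifact if the key is already present
        self._cache.setdefault(kernel, compiled_kernel)

    def size(self) -> int:
        return len(self._cache)

    def clear(self) -> None:
        self._cache.clear()
```

Note that `unordered_map::emplace` does nothing when the key already exists, which is why `insert` here uses `setdefault` rather than plain assignment.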
@ -1,182 +0,0 @@
/*! \file
    \brief Binding CUTLASS C++ APIs to Python
*/

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "builtin_types.h"
#include "device_launch_parameters.h"
#include "stddef.h"
#include "cutlass/cutlass.h"

#include "include/conv/convolution.h"
#include "include/gemm/gemm.h"
#include "include/types.h"
#include "include/layout/layout.h"
#include "include/tensor_coord.h"
#include "include/arch.h"
#include "include/tensor_ref_view.h"
#include "include/swizzling.h"
#include "test/conv/convolution.h"
#include "test/gemm/gemm.h"

// Data types
#include "library.h"

// Compiler
#include "compiler.h"

namespace py = pybind11;

PYBIND11_MODULE(cutlass_bindings, m) {

  // Module doc
  m.doc() = "CUTLASS C++ binding";

  //
  // Bind data types
  //
  bind_cutlass_types(m);

  //
  // Bind layouts
  //
  bind_layout(m);

  //
  // Bind tensor coord
  //
  bind_tensor_coord(m);

  //
  // Bind tensor refs and views
  //
  bind_tensor_refs_and_views(m);

  //
  // Bind opcode
  //
  bind_opcode(m);

  //
  // Bind convolution
  //
  py::module_ conv_submodule = m.def_submodule("conv");
  bind_convolution(conv_submodule);

  //
  // Bind GEMM
  //
  py::module_ gemm_submodule = m.def_submodule("gemm");
  bind_gemm(gemm_submodule);

  //
  // Bind swizzling
  //
  bind_threadblock_swizzle(m);

  //
  // Bind test units
  //
  py::module_ test = m.def_submodule("test");
  py::module_ test_conv = test.def_submodule("conv");
  bind_convolution_test(test_conv);

  py::module_ test_gemm = test.def_submodule("gemm");
  bind_gemm_test(test_gemm);

  // Data types
  py::enum_<cutlass::DataType>(m, "dtype")
    .value("b1", cutlass::DataType::kB1)
    .value("u2", cutlass::DataType::kU2)
    .value("u4", cutlass::DataType::kU4)
    .value("u8", cutlass::DataType::kU8)
    .value("u16", cutlass::DataType::kU16)
    .value("u32", cutlass::DataType::kU32)
    .value("u64", cutlass::DataType::kU64)
    .value("s2", cutlass::DataType::kS2)
    .value("s4", cutlass::DataType::kS4)
    .value("s16", cutlass::DataType::kS16)
    .value("s64", cutlass::DataType::kS64)
    .value("cf16", cutlass::DataType::kCF16)
    .value("cbf16", cutlass::DataType::kCBF16)
    .value("cf32", cutlass::DataType::kCF32)
    .value("ctf32", cutlass::DataType::kCTF32)
    .value("cf64", cutlass::DataType::kCF64)
    .value("cs2", cutlass::DataType::kCS2)
    .value("cs4", cutlass::DataType::kCS4)
    .value("cs8", cutlass::DataType::kCS8)
    .value("cs16", cutlass::DataType::kCS16)
    .value("cs32", cutlass::DataType::kCS32)
    .value("cs64", cutlass::DataType::kCS64)
    .value("cu2", cutlass::DataType::kCU2)
    .value("cu4", cutlass::DataType::kCU4)
    .value("cu8", cutlass::DataType::kCU8)
    .value("cu16", cutlass::DataType::kCU16)
    .value("cu32", cutlass::DataType::kCU32)
    .value("cu64", cutlass::DataType::kCU64)
    .value("invalid", cutlass::DataType::kInvalid);

  // Layout types
  py::enum_<cutlass::LayoutType>(m, "layout")
    .value("ColumnMajorInterleaved2", cutlass::LayoutType::kColumnMajorInterleaved2)
    .value("RowMajorInterleaved2", cutlass::LayoutType::kRowMajorInterleaved2)
    .value("ColumnMajorInterleaved64", cutlass::LayoutType::kColumnMajorInterleaved64)
    .value("RowMajorInterleaved64", cutlass::LayoutType::kRowMajorInterleaved64)
    .value("TensorNDHWC", cutlass::LayoutType::kTensorNDHWC)
    .value("TensorNCHW", cutlass::LayoutType::kTensorNCHW)
    .value("TensorNGHWC", cutlass::LayoutType::kTensorNGHWC)
    .value("TensorNC64HW64", cutlass::LayoutType::kTensorNC64HW64)
    .value("TensorC64RSK64", cutlass::LayoutType::kTensorC64RSK64);

  // Complex transform types
  py::enum_<cutlass::ComplexTransform>(m, "complex_transform")
    .value("none", cutlass::ComplexTransform::kNone)
    .value("conj", cutlass::ComplexTransform::kConjugate);

  //
  // Compiler
  //
  py::class_<cutlass::CompileCache>(m, "CompileCache")
    .def(py::init<>())
    .def("at", &cutlass::CompileCache::at)
    .def("insert", &cutlass::CompileCache::insert)
    .def("size", &cutlass::CompileCache::size)
    .def("clear", &cutlass::CompileCache::clear);
}
@ -1,59 +0,0 @@
/*! \file
    \brief Bind opcode classes to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/arch/mma.h"

namespace py = pybind11;

namespace cutlass {
enum class OpcodeClass {
  kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp
};
}

void bind_opcode(py::module &m) {
  py::enum_<cutlass::OpcodeClass>(m, "OpClass",
    R"pbdoc(Classification of math operators)pbdoc")
    .value("Simt", cutlass::OpcodeClass::kSimt,
      R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc")
    .value("TensorOp", cutlass::OpcodeClass::kTensorOp,
      R"pbdoc(Tag classifying operators as Tensor Core operations)pbdoc")
    .value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp,
      R"pbdoc(Tag classifying operators as WMMA Tensor Core operations)pbdoc")
    .value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp,
      R"pbdoc(Tag classifying operators as sparse Tensor Core operations)pbdoc");
}
@ -1,102 +0,0 @@
/*! \file
    \brief Bind convolution problem sizes to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/conv/conv2d_problem_size.h"

namespace py = pybind11;

void bind_conv_problem_size(py::module &m) {
  //
  // Conv2d problem size:
  // include/cutlass/conv/conv2d_problem_size.h
  //
  py::class_<cutlass::conv::Conv2dProblemSize>(m, "Conv2dProblemSize")
    // Constructors
    .def(py::init<int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, cutlass::conv::Mode, int, int>())
    .def(py::init<cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::Tensor4DCoord, cutlass::MatrixCoord, cutlass::MatrixCoord, cutlass::conv::Mode, int, int>())
    // Attribute accessors
    .def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N)
    .def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H)
    .def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W)
    .def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C)
    .def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P)
    .def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q)
    .def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K)
    .def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R)
    .def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S)
    .def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h)
    .def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w)
    .def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h)
    .def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w)
    .def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h)
    .def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w)
    .def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode)
    .def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices)
    .def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups)
    // Member functions
    .def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices)
    .def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent)
    .def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent)
    .def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent)
    .def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size)
    .def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size)
    .def("output_size", &cutlass::conv::Conv2dProblemSize::output_size);

  // Get tensor sizes
  m.def("implicit_gemm_tensor_a_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_a_size));
  m.def("implicit_gemm_tensor_b_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_b_size));
  m.def("implicit_gemm_tensor_c_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&>(&cutlass::conv::implicit_gemm_tensor_c_size));

  // Get tensor extents
  m.def("implicit_gemm_tensor_a_extent",
    py::overload_cast<
      cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
    >(&cutlass::conv::implicit_gemm_tensor_a_extent));

  m.def("implicit_gemm_tensor_b_extent",
    py::overload_cast<
      cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
    >(&cutlass::conv::implicit_gemm_tensor_b_extent));

  m.def("implicit_gemm_tensor_c_extent",
    py::overload_cast<
      cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&
    >(&cutlass::conv::implicit_gemm_tensor_c_extent));

  m.def("implicit_gemm_problem_size", py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize &>(&cutlass::conv::implicit_gemm_problem_size));
}
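`Conv2dProblemSize` stores both the input extent (H, W) and the output extent (P, Q). Under the rounding convention used by `torch.nn.functional.conv2d` (which the reference code earlier in this change relies on), the output extent can be sketched as follows (hypothetical helper, not the bound C++ function):

```python
def conv2d_output_extent(h, w, r, s, pad_h, pad_w,
                         stride_h, stride_w, dilation_h, dilation_w):
    """Output height/width (P, Q) of a 2-D convolution, torch rounding convention."""
    p = (h + 2 * pad_h - dilation_h * (r - 1) - 1) // stride_h + 1
    q = (w + 2 * pad_w - dilation_w * (s - 1) - 1) // stride_w + 1
    return p, q
```

For example, a 3x3 filter with unit padding and unit stride preserves the spatial extent, while stride 2 halves it.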
@ -1,91 +0,0 @@
/*! \file
    \brief Bind convolution-related enum types to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "conv_problem_size.h"
#include "host.h"
#include "cutlass/conv/convolution.h"

namespace py = pybind11;

void bind_convolution(py::module &m) {
  //
  // Enum types
  // cutlass/include/cutlass/conv/convolution.h
  //

  /// Convolutional operator
  py::enum_<cutlass::conv::Operator>(m, "Operator", R"pbdoc(Convolutional operator)pbdoc")
    .value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation")
    .value("dgrad", cutlass::conv::Operator::kDgrad, "Activation gradient")
    .value("wgrad", cutlass::conv::Operator::kWgrad, "Weight gradient");

  /// Distinguishes convolution from cross-correlation
  py::enum_<cutlass::conv::Mode>(m, "Mode")
    .value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation)
    .value("convolution", cutlass::conv::Mode::kConvolution);

  /// Selects among several implementation variants trading off performance for simplicity
  py::enum_<cutlass::conv::IteratorAlgorithm>(m, "IteratorAlgorithm",
    R"pbdoc(Selects among several implementation variants trading off performance for simplicity)pbdoc")
    .value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(Functionally correct in all cases but lower performance)pbdoc")
    .value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(Optimized for R <= 32, S <= 32, and unity-stride dgrad)pbdoc")
    .value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for a fixed channel count (C == AccessSize))pbdoc")
    .value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc");

  /// Distinguishes among partial specializations that accelerate certain problems where the
  /// convolution stride is unity
  py::enum_<cutlass::conv::StrideSupport>(m, "StrideSupport",
    R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where the
convolution stride is unity)pbdoc")
    .value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(Arbitrary convolution stride)pbdoc")
    .value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(Unit convolution stride)pbdoc");

  /// Identifies the split-K mode
  py::enum_<cutlass::conv::SplitKMode>(m, "SplitKMode")
    .value("None", cutlass::conv::SplitKMode::kNone)
    .value("Serial", cutlass::conv::SplitKMode::kSerial)
    .value("Parallel", cutlass::conv::SplitKMode::kParallel);

  // Conv problem sizes
  bind_conv_problem_size(m);

  //
  // Host helper functions
  //
  py::module_ host_submodule = m.def_submodule("host");
  bind_conv_host_helper(host_submodule);
}
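The `Mode` enum distinguishes cross-correlation (what most deep-learning frameworks actually compute) from true convolution; the two differ only in whether the filter is flipped before the sliding dot product. A 1-D plain-Python sketch of the distinction (hypothetical helpers, illustration only):

```python
def correlate1d(x, w):
    """Valid-mode cross-correlation: slide w over x without flipping it."""
    k = len(w)
    return [sum(x[i + j] * w[j] for j in range(k))
            for i in range(len(x) - k + 1)]

def convolve1d(x, w):
    """True convolution: cross-correlation with the filter reversed."""
    return correlate1d(x, w[::-1])
```

For a symmetric filter the two modes coincide; for an asymmetric one they generally differ.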
@ -1,54 +0,0 @@
/*! \file
    \brief Bind conv host helpers to Python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"

namespace py = pybind11;

void bind_conv_host_helper(py::module &m) {

  /// Reorder operand B for interleaved layouts
  m.def("reorder_convK", [](
    cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> dest,
    cutlass::TensorRef<int8_t, cutlass::layout::TensorCxRSKx<32>> src,
    cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize &problem_size) {
      cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size);
      cutlass::reorder_convK<32>(dest, src, implicit_problem_size);
  });
}
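`implicit_gemm_problem_size` maps a convolution onto GEMM dimensions before the reorder. For fprop the usual mapping is GEMM_M = N*P*Q, GEMM_N = K, GEMM_K = R*S*C; a sketch under that assumption (hypothetical helper, not the bound C++ function):

```python
def implicit_gemm_fprop_size(n, p, q, k, r, s, c):
    """GEMM (M, N, K) for implicit-GEMM forward convolution (fprop)."""
    gemm_m = n * p * q   # one GEMM row per output pixel
    gemm_n = k           # one GEMM column per output channel (filter)
    gemm_k = r * s * c   # reduction over the filter footprint
    return gemm_m, gemm_n, gemm_k
```

dgrad and wgrad use analogous mappings with the roles of activation, filter, and output permuted.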
@ -1,222 +0,0 @@
/*! \file
    \brief A generic wrapper around an epilogue visitor operation
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"

#include "epilogue_visitor_op/visitor_op_linear_combination.h"
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
#include "epilogue_visitor_op/visitor_op_accumulator.h"
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
#include "epilogue_visitor_op/visitor_op_unary.h"
#include "epilogue_visitor_op/visitor_op_binary.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Generic epilogue visitor
template <
  typename OutputOp_
>
class EpilogueVisitorGeneric {
public:

  using OutputOp = OutputOp_;
  using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
  static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
  using ElementOutput = typename OutputOp::ElementOutput;
  using OutputTileIterator = typename OutputOp::OutputTileIterator;

  static int const kIterations = OutputTileIterator::kIterations;

  ///
  /// End Epilogue Tree
  ///

  /// No additional SMEM buffer is required in the broadcast epilogue visitor
  struct SharedStorage {

    typename OutputOp::SharedStorage output_smem;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

public:

  /// Argument structure
  struct Arguments {
    typename OutputOp::Arguments output_op_args;

    //
    // Methods
    //
    Arguments() { }

    Arguments(
      typename OutputOp::Arguments output_op_args
    ):
      output_op_args(output_op_args)
    { }
  };

  struct Params {
    typename OutputOp::Params output_op_params;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      output_op_params(args.output_op_args)
    { }
  };

private:

  OutputOp output_op;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueVisitorGeneric(
    Params const &params,             ///< Parameters routed to the epilogue
    SharedStorage &shared_storage,    ///< Shared storage needed by the functors here
    MatrixCoord threadblock_offset,
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    MatrixCoord problem_size
|
||||
):
|
||||
output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
output_op.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
output_op.begin_epilogue();
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
output_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
output_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum) {
|
||||
output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
output_op.end_row(row_idx);
|
||||
|
||||
}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
output_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
output_op.end_epilogue();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
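The generic wrapper above does nothing but forward a fixed sequence of callbacks (`begin_epilogue`, `begin_step`, `visit`, `end_step`, `end_epilogue`) from the epilogue to its `OutputOp`. A minimal standalone sketch of that calling protocol, with an illustrative recording visitor and a toy driver (no CUTLASS types; `RecordingVisitor` and `run_epilogue` are names invented here for illustration, not part of CUTLASS):

```cpp
#include <string>
#include <vector>
#include <cassert>

// Records the order in which the epilogue-visitor hooks are invoked.
struct RecordingVisitor {
  std::vector<std::string> log;
  void begin_epilogue() { log.push_back("begin_epilogue"); }
  void begin_step(int)  { log.push_back("begin_step"); }
  void visit(int)       { log.push_back("visit"); }
  void end_step(int)    { log.push_back("end_step"); }
  void end_epilogue()   { log.push_back("end_epilogue"); }
};

// Toy driver: one epilogue of `steps` steps, `frags` accumulator fragments per
// step. The real epilogue interleaves shared-memory exchange and global stores
// between these hooks; only the call order is modeled here.
inline void run_epilogue(RecordingVisitor &v, int steps, int frags) {
  v.begin_epilogue();
  for (int s = 0; s < steps; ++s) {
    v.begin_step(s);
    for (int f = 0; f < frags; ++f) {
      v.visit(f);
    }
    v.end_step(s);
  }
  v.end_epilogue();
}
```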
@ -1,84 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief This file contains the elementwise binary ops used by the epilogue visitors
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Elementwise vector addition
template <typename T, int N>
struct VectorAdd {

  struct Arguments {
    // a placeholder argument to ensure correctness of ctypes
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments(): tmp(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { }
  };

  struct Params {

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args) { }
  };

  CUTLASS_HOST_DEVICE
  VectorAdd(
    Params const &params
  ) { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    cutlass::plus<Array<T, N>> add_op;
    return add_op(lhs, rhs);
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
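The `VectorAdd` op above follows the Arguments-to-Params-to-functor plumbing used throughout these visitors: a host-constructable `Arguments`, a `Params` built from it, and a stateless elementwise functor. A standalone sketch of the same shape over `std::array` (no CUTLASS dependency; `VectorAddSketch` is an illustrative name, not a CUTLASS type):

```cpp
#include <array>
#include <cassert>

// Standalone mirror of the VectorAdd pattern: Arguments -> Params -> functor.
template <typename T, int N>
struct VectorAddSketch {
  struct Arguments {
    int tmp = 0;  // placeholder, mirroring the ctypes workaround above
  };

  struct Params {
    Params() = default;
    explicit Params(Arguments const &) { }  // nothing to copy for a stateless op
  };

  explicit VectorAddSketch(Params const &) { }

  // Elementwise add, standing in for cutlass::plus<Array<T, N>>.
  std::array<T, N> operator()(std::array<T, N> const &lhs,
                              std::array<T, N> const &rhs) const {
    std::array<T, N> out{};
    for (int i = 0; i < N; ++i) {
      out[i] = lhs[i] + rhs[i];
    }
    return out;
  }
};
```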
@ -1,233 +0,0 @@
/*! \file
  \brief This file contains the elementwise unary ops used by the epilogue visitors
*/

#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Scalar multiplication
template <typename T, int N>
struct Mult {

  struct Arguments {
    T alpha;

    CUTLASS_HOST_DEVICE
    Arguments(): alpha(T(1.0)) { }

    CUTLASS_HOST_DEVICE
    Arguments(T alpha): alpha(alpha) { }
  };

  struct Params {
    T alpha;                      ///< scales accumulators

    CUTLASS_HOST_DEVICE
    Params(): alpha(T(1.0)) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args): alpha(args.alpha) { }
  };

  T alpha_;

  CUTLASS_HOST_DEVICE
  Mult(
    Params const &params
  ):
    alpha_(params.alpha)
  { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &source) const {
    cutlass::multiplies<Array<T, N>> multiply_op;
    return multiply_op(source, alpha_);
  }

  /// The visit can be skipped entirely when alpha == 0
  CUTLASS_HOST_DEVICE
  bool guard() {
    return alpha_ != T(0);
  }

};


/// ReLU
template <typename T, int N>
struct ReLUVisitor {
  struct Arguments {
    T threshold;

    CUTLASS_HOST_DEVICE
    Arguments(): threshold(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Arguments(T threshold): threshold(threshold) { }
  };

  struct Params {
    T threshold;

    CUTLASS_HOST_DEVICE
    Params(): threshold(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args): threshold(args.threshold) { }
  };

  T threshold_;

  CUTLASS_HOST_DEVICE
  ReLUVisitor(Params const &params):
    threshold_(params.threshold) { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    maximum<Array<T, N>> mx;
    return mx(frag, threshold_);
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }
};

/// Leaky ReLU
template <typename T, int N>
struct LeakyReLUVisitor {
  struct Arguments {
    T leaky_alpha;

    CUTLASS_HOST_DEVICE
    Arguments(): leaky_alpha(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
  };

  struct Params {
    T leaky_alpha;

    CUTLASS_HOST_DEVICE
    Params(): leaky_alpha(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
  };

  T leaky_alpha_;

  CUTLASS_HOST_DEVICE
  LeakyReLUVisitor(Params const &params):
    leaky_alpha_(params.leaky_alpha) { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
    return leaky_op(frag, leaky_alpha_);
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }

};

/// Tanh
template <typename T, int N>
struct TanhVisitor {
  /// Argument
  struct Arguments {
    // a placeholder argument to ensure correctness of ctypes
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments(): tmp(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { }
  };

  /// Param
  struct Params {
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args) { }
  };

  /// Constructor
  CUTLASS_HOST_DEVICE
  TanhVisitor(Params const &params) { }

  /// Scalar operator
  CUTLASS_HOST_DEVICE
  T tanh_op(T const &scalar) const {
    return fast_tanh(scalar);
  }

  /// Vector operator
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    Array<T, N> y;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = tanh_op(frag[i]);
    }

    return y;
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,148 +0,0 @@
/*! \file
  \brief This file contains the epilogue visitor op that forwards the accumulator
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation
///
///  ElementAccumulator accum;
///  return accum;
///
/// It can only be the leaf node of the epilogue tree

template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  int kElementsPerAccess_        ///< Number of elements computed per operation
>
class VisitorOpAccumulator {
public:
  using ElementAccumulator = ElementAccumulator_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  /// Fragment type for Accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type returned by this visitor
  using VisitAccessType = AccumulatorAccessType;

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    // a placeholder argument: ctypes raises an issue with empty argument structures
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments(): tmp(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { }
  };

  /// Parameter structure
  struct Params {

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args) { }
  };

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpAccumulator(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) { }

  CUTLASS_DEVICE
  void begin_epilogue() { }

  CUTLASS_DEVICE
  void begin_step(int step_idx) { }

  CUTLASS_DEVICE
  void begin_row(int row_idx) { }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    return accum;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) { }

  CUTLASS_DEVICE
  void end_epilogue() { }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,245 +0,0 @@
/*! \file
  \brief This file contains the epilogue visitor op for binary operations
*/

#pragma once
#include "cutlass/cutlass.h"
#include "binary_ops.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementCompute alpha;
///  ElementCompute beta;
///  ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
///  return C;
///
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementCompute_,      ///< Data type used to compute linear combination
  int kElementsPerAccess_,       ///< Number of elements computed per operation
  typename VisitorA_,            ///< Child node A
  typename VisitorB_,            ///< Child node B
  template<typename T, int N> typename BinaryOp_
>
class VisitorOpBinary {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  using VisitorA = VisitorA_;
  using VisitorB = VisitorB_;

  /// Fragment type returned from VisitorA.visit
  using VisitAccessTypeA = typename VisitorA::VisitAccessType;
  using ElementA = typename VisitAccessTypeA::Element;

  /// Fragment type returned from VisitorB.visit
  using VisitAccessTypeB = typename VisitorB::VisitAccessType;
  using ElementB = typename VisitAccessTypeB::Element;

  /// Fragment type returned by this visitor
  using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;

  static_assert(kElementsPerAccess == VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
  static_assert(kElementsPerAccess == VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    typename VisitorA::SharedStorage storage_a;
    typename VisitorB::SharedStorage storage_b;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };


  /// Host-constructable Arguments structure
  struct Arguments {
    typename BinaryOp::Arguments binary_arg;
    typename VisitorA::Arguments visitor_a_arg;  ///< Arguments for visitor_a
    typename VisitorB::Arguments visitor_b_arg;  ///< Arguments for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Arguments(): binary_arg() { }

    CUTLASS_HOST_DEVICE
    Arguments(
      typename BinaryOp::Arguments binary_arg,
      typename VisitorA::Arguments visitor_a_arg,
      typename VisitorB::Arguments visitor_b_arg
    ):
      binary_arg(binary_arg),
      visitor_a_arg(visitor_a_arg),
      visitor_b_arg(visitor_b_arg)
    { }
  };

  /// Parameter structure
  struct Params {
    typename BinaryOp::Params binary_param;
    typename VisitorA::Params visitor_a_param;  ///< Params for visitor_a
    typename VisitorB::Params visitor_b_param;  ///< Params for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      binary_param(args.binary_arg),
      visitor_a_param(args.visitor_a_arg),
      visitor_b_param(args.visitor_b_arg)
    { }
  };

private:
  //
  // Data members
  //

  BinaryOp binary_op;

  VisitorA visitor_a_op;
  VisitorB visitor_b_op;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpBinary(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    binary_op(params.binary_param),
    visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
    visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
  { }


  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_a_op.begin_epilogue();
    visitor_b_op.begin_epilogue();
  }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    visitor_a_op.set_batch_index(batch_idx);
    visitor_b_op.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    visitor_a_op.begin_step(step_idx);
    visitor_b_op.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_a_op.begin_row(row_idx);
    visitor_b_op.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get results from visitor A and visitor B
    VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    // Type conversion
    NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
    NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;

    return binary_op(
      source_converter_A(result_A),
      source_converter_B(result_B)
    );
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_a_op.end_row(row_idx);
    visitor_b_op.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_a_op.end_step(step_idx);
    visitor_b_op.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_a_op.end_epilogue();
    visitor_b_op.end_epilogue();
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
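The binary visitor above is an interior node of the epilogue tree: its `visit()` first visits both children, converts their fragments to the compute type, and then combines them with the binary op. A standalone sketch of that composition over `std::array` (no CUTLASS types or type converters; `BinaryNodeSketch` and `AccumLeaf` are illustrative names):

```cpp
#include <array>
#include <functional>
#include <cassert>

// Leaf visitor that simply forwards the accumulator (like VisitorOpAccumulator).
template <typename T, int N>
struct AccumLeaf {
  std::array<T, N> visit(std::array<T, N> const &accum) const { return accum; }
};

// Interior node: visit both children, then combine elementwise with BinaryOp.
template <typename T, int N, typename VisitorA, typename VisitorB, typename BinaryOp>
struct BinaryNodeSketch {
  VisitorA a;
  VisitorB b;
  BinaryOp op;

  std::array<T, N> visit(std::array<T, N> const &accum) const {
    auto result_a = a.visit(accum);  // child A's fragment
    auto result_b = b.visit(accum);  // child B's fragment
    std::array<T, N> out{};
    for (int i = 0; i < N; ++i) {
      out[i] = op(result_a[i], result_b[i]);
    }
    return out;
  }
};
```

With both children as accumulator leaves and `std::plus`, the node computes `accum + accum`, mirroring how `VisitorOpBinary` would combine two `VisitorOpAccumulator` children.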
@ -1,250 +0,0 @@
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with broadcasting vector to all columns
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
///  ElementVector T[i][j] <- device-memory Td[i]
///
/// It can only be a leaf node in the epilogue tree
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementFragment_,     ///< Data type used to cache vector in register
  typename InputTileIterator_    ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpColumnBroadcast {
public:
  using InputTileIterator = InputTileIterator_;

  static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
  using ElementAccumulator = ElementAccumulator_;
  using ElementVector = typename InputTileIterator::Element;
  using ElementFragment = ElementFragment_;

  using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;

  /// Thread map used by input tile iterators
  using ThreadMap = typename InputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<ElementFragment, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Used for the broadcast
  struct BroadcastDetail {
    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    // /// Number of iterations (accesses) the threadblock takes to reduce a row
    // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
  };

  // using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;

  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementVector *broadcast_ptr;  ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments():
      broadcast_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementVector *broadcast_ptr,
      int64_t batch_stride
    ):
      broadcast_ptr(broadcast_ptr),
      batch_stride(batch_stride) { }
  };

  /// Param structure
  struct Params {
    ElementVector *broadcast_ptr;  ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      broadcast_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      broadcast_ptr(args.broadcast_ptr),
      batch_stride(args.batch_stride) { }
  };

private:
  ElementVector *broadcast_ptr;
  BroadcastFragment broadcast_fragment;  ///< Array that holds the loaded broadcast fragment
  MatrixCoord threadblock_offset_;
  int thread_idx_;
  MatrixCoord problem_size;

  int thread_start_row_;
  int state_[3];
  int thread_offset_row_;

  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpColumnBroadcast(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    broadcast_ptr(params.broadcast_ptr),
    threadblock_offset_(threadblock_offset),
    thread_idx_(thread_idx),
    problem_size(problem_size),
    thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
    batch_stride_(params.batch_stride)
  {
    state_[0] = state_[1] = state_[2] = 0;
  }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    broadcast_ptr += batch_idx * batch_stride_;
  }

  CUTLASS_DEVICE
  void begin_epilogue() { }

  CUTLASS_DEVICE
  void begin_step(int step_idx) { }

  CUTLASS_DEVICE
  void begin_row(int row_idx) { }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Compute the row offset of the broadcast value loaded by this thread
    thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();

    ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));

    broadcast_fragment.fill(broadcast_data);

    return broadcast_fragment;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    // Advance the row state machine: state_[0..2] count rows, groups, and clusters
    ++state_[0];

    thread_start_row_ += ThreadMap::Shape::kRow;
    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
        ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        thread_start_row_ += ThreadMap::Count::kGroup *
          ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
        }
      }
    }
  }

  CUTLASS_DEVICE
  void end_epilogue() { }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor Op that reduces over the columns in the CTA
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
///  ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
///  device memory <- ElementReduction(R[j])
///
template <
  typename ThreadblockShape_,             ///< Threadblock shape
  typename ElementAccumulator_,           ///< Data type of the Accumulator
  typename ElementReduction_,             ///< Data type of the output reduction in device memory
  typename ElementReductionAccumulator_,  ///< Data type to accumulate reduction in smem and register
  typename OutputTileIterator_,           ///< Tile Iterator type
  typename Visitor_                       ///< Preceding visitor op
>
class VisitorOpColumnReduction {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementReductionAccumulator = ElementReductionAccumulator_;
  using ElementReduction = ElementReduction_;
  using OutputTileIterator = OutputTileIterator_;
  using ThreadblockShape = ThreadblockShape_;
  using Visitor = Visitor_;

  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
  using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
  using ElementOutput = typename OutputTileIterator::Element;

  /// Fragment type returned from Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of reduction
  using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;
  /// Used for the reduction
  struct ReductionDetail {
    /// Number of threads per warp
    static int const kWarpSize = 32;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    /// Number of iterations (accesses) the threadblock takes to reduce a row
    static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);

    using StorageShape = MatrixShape<
      kThreadRows,
      ThreadblockShape::kN
    >;
  };

  using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;

  /// Shared storage
  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;
    AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementReduction *reduction_ptr;          ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Arguments visitor_arg;  ///< Arguments of the preceding visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): reduction_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementReduction *reduction_ptr,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      reduction_ptr(reduction_ptr),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Param structure
  struct Params {
    ElementReduction *reduction_ptr;         ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Params visitor_param;  ///< Params of the preceding visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Params(): reduction_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      reduction_ptr(args.reduction_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  ElementReduction *reduction_output_ptr_;           ///< Pointer to the reduction tensor in device memory
  ElementReductionAccumulator *reduction_smem_ptr_;  ///< Pointer to the partial reductions in shared memory
  ReductionFragment reduction_fragment;              ///< Register fragment that holds the partial reduction
  Visitor visitor_;                                  ///< Preceding visitor op
  int thread_idx_;
  MatrixCoord threadblock_offset;
  MatrixCoord problem_size_;
  int64_t batch_stride_;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpColumnReduction(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    visitor_(params.visitor_param, shared_storage.storage_visitor,
      thread_idx, threadblock_offset, problem_size),
    reduction_smem_ptr_(shared_storage.reduction.data()),
    reduction_output_ptr_(params.reduction_ptr),
    thread_idx_(thread_idx),
    threadblock_offset(threadblock_offset),
    problem_size_(problem_size),
    batch_stride_(params.batch_stride)
  { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    reduction_output_ptr_ += batch_idx * batch_stride_;
    visitor_.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_.begin_epilogue();

    // Clear the reduction fragment
    reduction_fragment.clear();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    visitor_.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get the result from the preceding visitor
    VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    // Accumulate the partial column reduction held in registers
    NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
    ReductionOp reduction_op;
    ReductionAccumulatorAccessType *reduction_fragment_ptr = reinterpret_cast<ReductionAccumulatorAccessType *>(&reduction_fragment);
    reduction_fragment_ptr[column_idx] = reduction_op(reduction_fragment_ptr[column_idx], reduction_converter(result));

    return result;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_.end_epilogue();

    //
    // Store the partially reduced value to SMEM
    //

    // Guard against uses of the existing SMEM tile
    __syncthreads();

    using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;

    //
    // Determine a compact thread arrangement to store to SMEM
    //

    MatrixCoord thread_offset(
      thread_idx_ / ReductionDetail::kThreadsPerRow,
      (thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
    );

    //
    // Each thread stores its fragment to SMEM
    //
    AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
      &reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
    );

    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
      &reduction_fragment
    );

    CUTLASS_PRAGMA_UNROLL
    for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
      int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;

      aligned_reduction_ptr[col_idx] = frag_ptr[column];
    }

    __syncthreads();

    //
    // Now, threads are assigned several columns of the output. They fetch over all rows from
    // the compacted SMEM tile and perform a reduction.
    //

    NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
      int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;

      ReductionOpScalar reduction_op;
      ElementReductionAccumulator reduction_element = ElementReductionAccumulator();

      int output_column_idx = threadblock_offset.column() + column_idx;

      if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {

        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
          if (row) {
            auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
            reduction_element = reduction_op(reduction_element, frag);
          }
          else {
            reduction_element = reduction_smem_ptr_[column_idx];
          }
        }

        // Store the reduced value to device memory
        reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor Op for Linear Combination
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementCompute alpha;
///  ElementCompute beta;
///  ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
///  Return C;
///
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementCompute_,      ///< Data type used to compute linear combination
  int kElementsPerAccess_,       ///< Number of elements computed per operation
  typename VisitorA_,            ///< Child node A
  typename VisitorB_             ///< Child node B
>
class VisitorOpLinearCombination {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  using VisitorA = VisitorA_;
  using VisitorB = VisitorB_;

  /// Fragment type returned from VisitorA.visit
  using VisitAccessTypeA = typename VisitorA::VisitAccessType;
  using ElementA = typename VisitAccessTypeA::Element;

  /// Fragment type returned from VisitorB.visit
  using VisitAccessTypeB = typename VisitorB::VisitAccessType;
  using ElementB = typename VisitAccessTypeB::Element;

  /// Fragment type returned by this visitor
  using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Combination Op
  using CombinationOp = cutlass::plus<VisitAccessType>;

  static_assert(kElementsPerAccess == VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
  static_assert(kElementsPerAccess == VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    typename VisitorA::SharedStorage storage_a;
    typename VisitorB::SharedStorage storage_b;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementCompute alpha;                        ///< Scales accumulators
    ElementCompute beta;                         ///< Scales source tensor
    typename VisitorA::Arguments visitor_a_arg;  ///< Arguments for visitor_a
    typename VisitorB::Arguments visitor_b_arg;  ///< Arguments for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Arguments():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0))
    { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementCompute alpha,
      ElementCompute beta,
      typename VisitorA::Arguments visitor_a_arg,
      typename VisitorB::Arguments visitor_b_arg
    ):
      alpha(alpha),
      beta(beta),
      visitor_a_arg(visitor_a_arg),
      visitor_b_arg(visitor_b_arg)
    { }
  };

  /// Parameter structure
  struct Params {
    ElementCompute alpha;                       ///< Scales accumulators
    ElementCompute beta;                        ///< Scales source tensor
    typename VisitorA::Params visitor_a_param;  ///< Params for visitor_a
    typename VisitorB::Params visitor_b_param;  ///< Params for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      alpha(args.alpha),
      beta(args.beta),
      visitor_a_param(args.visitor_a_arg),
      visitor_b_param(args.visitor_b_arg)
    { }
  };

private:
  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

  VisitorA visitor_a_op;
  VisitorB visitor_b_op;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpLinearCombination(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    alpha_(params.alpha),
    beta_(params.beta),
    visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
    visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
  { }

  CUTLASS_DEVICE
  void begin_epilogue() {
    if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
    if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get results from visitor A and visitor B
    VisitAccessTypeA result_A;
    VisitAccessTypeB result_B;

    if (alpha_ != ElementCompute(0)) {
      result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    } else {
      // Fill result A with zeros
      result_A.clear();
    }

    if (beta_ != ElementCompute(0)) {
      result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    } else {
      // Fill result B with zeros
      result_B.clear();
    }

    // Type conversion
    NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
    NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;

    CombinationOp combination_op;

    cutlass::multiplies<VisitAccessType> multiply_op;

    return combination_op(
      multiply_op(alpha_, source_converter_A(result_A)),
      multiply_op(beta_, source_converter_B(result_B))
    );
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
    if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor Op that broadcasts a vector to all rows
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementVector T[i][j] <- device-memory Td[j]
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementFragment_, ///< Data type used to cache vector in register
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
|
||||
>
|
||||
class VisitorOpRowBroadcast {
|
||||
public:
  using InputTileIterator = InputTileIterator_;

  static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
  using ElementAccumulator = ElementAccumulator_;
  using ElementVector = typename InputTileIterator::Element;
  using ElementFragment = ElementFragment_;

  using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;

  /// Thread map used by input tile iterators
  using ThreadMap = typename InputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<
    ElementFragment,
    ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Used for the broadcast
  struct BroadcastDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of the output tile
    static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    // /// Number of iterations (accesses) the threadblock takes to reduce a row
    // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
  };

  // using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;

  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementVector *broadcast_ptr;  ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments():
      broadcast_ptr(nullptr),
      batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementVector *broadcast_ptr,
      int64_t batch_stride
    ):
      broadcast_ptr(broadcast_ptr),
      batch_stride(batch_stride) { }
  };

  /// Params structure
  struct Params {
    ElementVector *broadcast_ptr;  ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      broadcast_ptr(nullptr),
      batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      broadcast_ptr(args.broadcast_ptr),
      batch_stride(args.batch_stride) { }
  };

private:
  ElementVector *broadcast_ptr;
  BroadcastFragment broadcast_fragment;  ///< Array that holds the loaded broadcast fragment
  MatrixCoord threadblock_offset_;
  int thread_idx_;
  MatrixCoord problem_size;
  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpRowBroadcast(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
    threadblock_offset_(threadblock_offset),
    thread_idx_(thread_idx),
    problem_size(problem_size),
    batch_stride_(params.batch_stride) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    broadcast_ptr += batch_idx * batch_stride_;
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    // Load the broadcast fragment
    load_broadcast_fragment_();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) { }

  CUTLASS_DEVICE
  void begin_row(int row_idx) { }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    VisitAccessType *broadcast_fragment_ = reinterpret_cast<VisitAccessType *>(&broadcast_fragment);
    return broadcast_fragment_[column_idx];
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) { }

  CUTLASS_DEVICE
  void end_epilogue() { }

private:

  CUTLASS_DEVICE
  void load_broadcast_fragment_() {

    broadcast_fragment.clear();

    // If no pointer is supplied, leave the fragment zeroed and avoid memory accesses
    if (!broadcast_ptr) {
      return;
    }

    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;

    AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      AccessFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,319 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief Epilogue visitor op that reduces over the rows of a CTA tile
*/

#pragma once

#include "cutlass/cutlass.h"
#include <cstdio>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue visitor operator for the following computation:
///
///   ElementReductionAccumulator R[i] = \sum_j ElementReductionAccumulator(T[i][j])
///   device memory <- ElementReduction(R[i])
///
template <
  typename ThreadblockShape_,             ///< Threadblock shape
  typename ElementAccumulator_,           ///< Data type of the accumulator
  typename ElementReduction_,             ///< Data type of the reduction output in device memory
  typename ElementReductionAccumulator_,  ///< Data type used to accumulate the reduction in smem and registers
  typename OutputTileIterator_,           ///< Tile iterator type
  typename Visitor_                       ///< Preceding visitor op
>
class VisitorOpRowReduction {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementReductionAccumulator = ElementReductionAccumulator_;
  using ElementReduction = ElementReduction_;
  using OutputTileIterator = OutputTileIterator_;
  using ThreadblockShape = ThreadblockShape_;
  using Visitor = Visitor_;

  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
  using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
  using ElementOutput = typename OutputTileIterator::Element;

  /// Fragment type returned from the Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of reduction
  using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;

  /// Used for the reduction
  struct ReductionDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of the output tile
    static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;

    /// Half the number of threads per row, used for the cross-thread reduction
    static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;
  };

  /// Shared storage
  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementReduction *reduction_ptr;          ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Arguments visitor_arg;  ///< Arguments of the visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): reduction_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementReduction *reduction_ptr,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      reduction_ptr(reduction_ptr),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Params structure
  struct Params {
    ElementReduction *reduction_ptr;          ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Params visitor_param;   ///< Params of the visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Params(): reduction_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      reduction_ptr(args.reduction_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  ElementReduction *reduction_output_ptr_;  ///< Pointer to the reduction tensor in device memory
  ElementReductionAccumulator reduction_accum;
  Visitor visitor_;                         ///< The preceding visitor
  int thread_idx_;
  MatrixCoord threadblock_offset;
  MatrixCoord problem_size_;

  int thread_start_row_;  ///< Starting row index of this thread
  int state_[3];          ///< Used to track the row iterator state
  int thread_offset_row_;
  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpRowReduction(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    visitor_(params.visitor_param, shared_storage.storage_visitor,
             thread_idx, threadblock_offset, problem_size),
    reduction_output_ptr_(params.reduction_ptr),
    thread_idx_(thread_idx),
    threadblock_offset(threadblock_offset),
    problem_size_(problem_size),
    thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
    batch_stride_(params.batch_stride)
  {
    state_[0] = state_[1] = state_[2] = 0;
  }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    reduction_output_ptr_ += batch_idx * batch_stride_;
    visitor_.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    visitor_.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_.begin_row(row_idx);

    reduction_accum = ElementReductionAccumulator(0);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get the result from the preceding visitor
    VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();

    ReductionOpScalar reduction_op;

    ElementReductionAccumulator reduction_accum_ = reduction(result);

    // After performing the in-thread reduction, perform the cross-thread / in-warp reduction
    CUTLASS_PRAGMA_UNROLL
    for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
      reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
    }
    reduction_accum = reduction_op(reduction_accum, reduction_accum_);

    return result;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_.end_row(row_idx);
    NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;

    bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
    int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();

    ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;

    arch::global_store<ElementReduction, sizeof(ElementReduction)>(
      output_converter(reduction_accum),
      (void *)curr_ptr_reduction,
      is_write_thread);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_.end_step(step_idx);

    // Advance the row iterator state
    ++state_[0];

    thread_start_row_ += ThreadMap::Shape::kRow;
    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
        ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        thread_start_row_ += ThreadMap::Count::kGroup *
          ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
        }
      }
    }
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_.end_epilogue();
  }

private:

  CUTLASS_DEVICE
  ElementReductionAccumulator reduction(VisitAccessTypeVisitor const &result) {
    ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);

    ReductionOpScalar reduction_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
      sum_ = reduction_op(sum_, result[i]);
    }

    return sum_;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,188 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief Epilogue visitor op that reads an input tensor from device memory
*/

#pragma once

#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue visitor operator for the following computation:
///
///   ElementInput C <- device memory
///
/// It can only be a leaf node in the epilogue tree.
template <
  typename ElementAccumulator_,  ///< Data type of the accumulator
  typename InputTileIterator_    ///< Tile iterator type used to read the tensor
>
class VisitorOpTensorInput {
public:
  using ElementAccumulator = ElementAccumulator_;
  using InputTileIterator = InputTileIterator_;

  static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
  using ElementInput = typename InputTileIterator::Element;

  using VisitAccessType = Array<ElementInput, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementInput *input_ptr;  ///< Pointer to the input tensor in device memory
    int ldt;                  ///< Leading dimension of the input tensor operand
    int64_t batch_stride;     ///< Batch stride for batched GEMM

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): input_ptr(nullptr), ldt(0), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementInput *input_ptr,
      int ldt, int64_t batch_stride
    ):
      input_ptr(input_ptr),
      ldt(ldt),
      batch_stride(batch_stride)
    { }
  };

  /// Params structure
  struct Params {
    typename InputTileIterator::Params params_input;
    ElementInput *input_ptr;
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      input_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      params_input(args.ldt),
      input_ptr(args.input_ptr),
      batch_stride(args.batch_stride)
    { }
  };

private:
  InputTileIterator iterator_T_;
  typename InputTileIterator::Fragment fragment_T_;
  MatrixCoord problem_size;
  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpTensorInput(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    iterator_T_(
      InputTileIterator(
        params.params_input,
        params.input_ptr,
        problem_size,
        thread_idx,
        threadblock_offset
      )
    ),
    problem_size(problem_size),
    batch_stride_(params.batch_stride) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
  }

  CUTLASS_DEVICE
  void begin_epilogue() { }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_T_.clear();
    iterator_T_.load(fragment_T_);
    ++iterator_T_;
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) { }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
    return source;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) { }

  CUTLASS_DEVICE
  void end_epilogue() { }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,240 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief Epilogue visitor op that writes an output tensor to device memory
*/

#pragma once

#include "cutlass/cutlass.h"
#include <cstdio>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue visitor operator for the following computation:
///
///   ElementOutput T = ElementOutput(Visitor)
///   T -> device memory
///
template <
  typename ElementAccumulator_,  ///< Data type of the accumulator
  typename OutputTileIterator_,  ///< Tile iterator type used to write the tensor
  typename Visitor_              ///< Child visitor that produces the output tensor
>
class VisitorOpTensorOutput {
public:
  using ElementAccumulator = ElementAccumulator_;
  using OutputTileIterator = OutputTileIterator_;

  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
  using ElementOutput = typename OutputTileIterator::Element;

  using Visitor = Visitor_;

  /// Fragment type returned from the Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of output
  using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;

  static_assert(kElementsPerAccess == VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");

  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementOutput *output_ptr;                ///< Pointer to the output tensor in device memory
    int ldt;                                  ///< Leading dimension of the output tensor operand
    int64_t batch_stride;                     ///< Batch stride
    typename Visitor::Arguments visitor_arg;  ///< Arguments of the visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): output_ptr(nullptr), ldt(0), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementOutput *output_ptr,
      int ldt,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      output_ptr(output_ptr),
      ldt(ldt),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Params structure
  struct Params {
    typename OutputTileIterator::Params params_output;
    ElementOutput *output_ptr;
    int64_t batch_stride;
    typename Visitor::Params visitor_param;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      output_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      params_output(args.ldt),
      output_ptr(args.output_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };
|
||||
|
||||
private:
|
||||
OutputTileIterator iterator_T_;
|
||||
typename OutputTileIterator::Fragment fragment_T_;
|
||||
MatrixCoord problem_size;
|
||||
Visitor visitor_;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpTensorOutput(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
|
||||
iterator_T_(
|
||||
OutputTileIterator(
|
||||
params.params_output,
|
||||
params.output_ptr,
|
||||
problem_size,
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
)
|
||||
),
|
||||
problem_size(problem_size),
|
||||
batch_stride_(params.batch_stride) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
|
||||
visitor_.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_T_.clear();
|
||||
visitor_.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor
|
||||
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
// Column guard
|
||||
MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
|
||||
bool column_guard = (thread_offset_.column() < problem_size.column());
|
||||
|
||||
if (column_guard) {
|
||||
NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
|
||||
OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_.end_step(step_idx);
|
||||
iterator_T_.store(fragment_T_);
|
||||
++iterator_T_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_.end_epilogue();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,226 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief This file contains the epilogue visitor Op with a Unary operation
*/

#pragma once
#include "cutlass/cutlass.h"
#include "unary_ops.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementCompute alpha;
///  ElementCompute beta;
///  ElementCompute C = UnaryOp(ElementCompute(Visitor))
///  Return C;
///
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementCompute_,      ///< Data type used to compute linear combination
  int      kElementsPerAccess_,  ///< Number of elements computed per operation
  typename Visitor_,             ///< Child node
  template<typename T, int N> typename UnaryOp_
>
class VisitorOpUnary {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  using Visitor = Visitor_;

  /// Fragment type returned from Visitor.visit
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisit = typename VisitAccessTypeVisitor::Element;

  /// Fragment type returned by this visitor
  using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Combination Op
  using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;

  static_assert(kElementsPerAccess == VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    typename UnaryOp::Arguments unary_arg;
    typename Visitor::Arguments visitor_arg;  ///< Argument type for visitor

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Arguments(): unary_arg() { }

    CUTLASS_HOST_DEVICE
    Arguments(
      typename UnaryOp::Arguments unary_arg,
      typename Visitor::Arguments visitor_arg
    ):
      unary_arg(unary_arg),
      visitor_arg(visitor_arg)
    { }
  };

  /// Parameter structure
  struct Params {
    typename UnaryOp::Params unary_param;
    typename Visitor::Params visitor_param;  ///< Parameter type for visitor

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params(): unary_param() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      unary_param(args.unary_arg),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  //
  // Data members
  //
  UnaryOp unary_op;

  Visitor visitor_op;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpUnary(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    unary_op(params.unary_param),
    visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
  { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    visitor_op.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    if (unary_op.guard()) visitor_op.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    if (unary_op.guard()) visitor_op.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    if (unary_op.guard()) visitor_op.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    /// Get result from the child visitor
    VisitAccessTypeVisitor result;

    if (unary_op.guard()) {
      result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    } else {
      result.clear();
    }

    /// Type conversion
    NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;

    return unary_op(source_converter(result));
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    if (unary_op.guard()) visitor_op.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    if (unary_op.guard()) visitor_op.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    if (unary_op.guard()) visitor_op.end_epilogue();
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
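The guarded pass-through pattern above (each lifecycle hook forwards to the child visitor only when `unary_op.guard()` is true, and `visit` applies the unary op to the child's cleared-or-computed fragment) can be sketched in plain Python. All names below are illustrative stand-ins, not CUTLASS APIs:

```python
# Minimal sketch of the VisitorOpUnary composition: a unary node wraps a
# child visitor and applies an elementwise op to whatever the child emits.
# AccumSource and these class/parameter names are hypothetical.

class AccumSource:
    """Leaf visitor: simply returns the accumulator fragment."""
    def visit(self, frag):
        return list(frag)

class VisitorOpUnary:
    def __init__(self, child, unary_op, guard=True):
        self.child = child
        self.unary_op = unary_op
        self.guard = guard  # mirrors unary_op.guard() in the C++ code

    def visit(self, frag):
        # When the guard is off, the child is skipped and a cleared fragment
        # is used instead, matching result.clear() in the C++ visit().
        result = self.child.visit(frag) if self.guard else [0.0] * len(frag)
        return [self.unary_op(x) for x in result]

visitor = VisitorOpUnary(AccumSource(), unary_op=lambda x: -x)
print(visitor.visit([1.0, -2.0, 3.0]))  # elementwise-negated fragment
```

Composing such nodes into a tree is what lets a single epilogue fuse a chain of elementwise operations over the accumulator tile.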
@ -1,480 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief Epilogue visitor type used for partial computation of a layernorm operation

  GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
                        + lightweight full reduction kernel (ApplyFinalReduction)
                        + GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename ThreadblockShape_,
  int ThreadCount,
  typename OutputTileIterator_,
  typename AccumulatorTile_,
  typename ElementAccumulator_,
  typename ElementVariance_,
  typename ElementMean_,
  typename ElementLayernormCompute_,
  typename ElementwiseFunctor_,
  bool IsShiftedVariance_ = false
>
class EpilogueVisitorLayerNorm {
public:

  using ElementVariance = ElementVariance_;
  using ElementMean = ElementMean_;
  using ElementLayernormCompute = ElementLayernormCompute_;

  using AccumulatorTile = AccumulatorTile_;

  using ThreadblockShape = ThreadblockShape_;
  static int const kThreadCount = ThreadCount;

  using OutputTileIterator = OutputTileIterator_;
  using ElementwiseFunctor = ElementwiseFunctor_;

  static int const kIterations = OutputTileIterator::kIterations;
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
  static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;

  static int const kThreads = OutputTileIterator::ThreadMap::kThreads;

  static bool const kIsShiftedVariance = IsShiftedVariance_;

  using ElementOutput = typename OutputTileIterator::Element;

  static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;

  /// Array type used in Shift-K Layernorm
  static int const kRowAccessCount = kIterations * kRowIterations;

  using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;

  // Conducts manual transpose externally (already supported) for column major
  using LayoutOutput = cutlass::layout::RowMajor;

  using ElementAccumulator = ElementAccumulator_;

  using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
  using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
  using OutputVector = Array<ElementOutput, kElementsPerAccess>;
  using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;

  static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
  static int const kThreadsInColumn = kThreads / kThreadsPerRow;
  static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);

  /// Argument structure
  struct Arguments {

    typename ElementwiseFunctor::Params elementwise;
    ElementVariance *ptr_Variance;
    ElementMean *ptr_Mean;
    ElementOutput *ptr_Shifted_K;
    MatrixCoord extent;

    //
    // Methods
    //
    Arguments():
      ptr_Variance(nullptr),
      ptr_Mean(nullptr),
      ptr_Shifted_K(nullptr)
    { }

    Arguments(
      typename ElementwiseFunctor::Params elementwise_,
      ElementVariance *ptr_Variance,
      ElementMean *ptr_Mean_,
      ElementOutput *ptr_Shifted_K_ = nullptr,
      MatrixCoord extent = MatrixCoord(0, 0)
    ):
      elementwise(elementwise_),
      ptr_Variance(ptr_Variance),
      ptr_Mean(ptr_Mean_),
      ptr_Shifted_K(ptr_Shifted_K_),
      extent(extent)
    { }
  };

  struct Params {

    typename ElementwiseFunctor::Params elementwise;
    ElementVariance *ptr_Variance;
    ElementMean *ptr_Mean;
    ElementOutput *ptr_Shifted_K;
    MatrixCoord extent;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params():
      ptr_Variance(nullptr),
      ptr_Mean(nullptr)
    { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      elementwise(args.elementwise),
      ptr_Variance(args.ptr_Variance),
      ptr_Mean(args.ptr_Mean),
      ptr_Shifted_K(args.ptr_Shifted_K),
      extent(args.extent)
    { }
  };

  /// Shared storage
  struct SharedStorage {

  };

private:

  Params const &params_;
  SharedStorage &shared_storage_;
  MatrixCoord extent_;
  ElementwiseFunctor elementwise_;

  OutputTileIterator iterator_C_;
  OutputTileIterator iterator_D_;
  typename OutputTileIterator::Fragment fragment_C_;
  typename OutputTileIterator::Fragment fragment_D_;

  ElementAccumulator alpha_;
  ElementAccumulator beta_;
  ConvertedShiftFragment shift_k_frag_;

  ElementLayernormCompute accum_sum_square_;
  ElementLayernormCompute accum_sum_element_;
  int thread_idx_;

  MatrixCoord thread_offset_;

  gemm::GemmCoord threadblock_tile_offset_;

public:

  CUTLASS_DEVICE
  EpilogueVisitorLayerNorm(
    Params const &params,                     ///< Parameters routed to the epilogue
    SharedStorage &shared_storage,            ///< Shared storage needed by the functors here
    MatrixCoord threadblock_offset,
    gemm::GemmCoord threadblock_tile_offset,  ///< Threadblock tile coordinate in GEMM
    int thread_idx,
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    OutputTileIterator source_iterator        ///< Tile iterator for source
  ):
    params_(params),
    shared_storage_(shared_storage),
    elementwise_(params.elementwise),
    extent_(params.extent),
    iterator_C_(source_iterator),
    iterator_D_(destination_iterator),
    threadblock_tile_offset_(threadblock_tile_offset),
    thread_idx_(thread_idx)
  {
    alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
    beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

    if (beta_ == ElementAccumulator()) {
      iterator_C_.clear_mask();
    }
  }

  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
    int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices) { ///< Total number of split-K slices

  }

  /// Called to set the batch index
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {

  }

  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() {

    // If the shift-K feature is enabled, load the shift-k fragment
    // at the very beginning of the epilogue
    if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
      shift_k_frag_.clear();
      int thread_offset_row_base = iterator_D_.thread_start_row();

      CUTLASS_PRAGMA_UNROLL
      for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
        int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
        CUTLASS_PRAGMA_UNROLL
        for (int rid = 0; rid < kRowIterations; ++rid) {
          int row_step_offset = rid * kDeltaRow;
          int row_offset = thread_offset_row_base + step_offset + row_step_offset;
          bool is_load = (row_offset < extent_.row());
          shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
        }
      }
    }
  }

  /// Called at the start of one step before starting accumulator exchange
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_D_.clear();

    if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      fragment_C_.clear();
      iterator_C_.load(fragment_C_);
      ++iterator_C_;
    }
  }

  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    /// Set the accumulators to 0
    accum_sum_element_ = ElementLayernormCompute(0);
    accum_sum_square_ = ElementLayernormCompute(0);
  }

  /// Called after accumulators have been exchanged for each accumulator vector
  CUTLASS_DEVICE
  void visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorFragment const &accum) {

    using Mul = cutlass::multiplies<ElementLayernormCompute>;
    using Minus = cutlass::minus<ElementLayernormCompute>;
    using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;

    Minus minus;
    Mul mul;
    Exp exponential;

    LayernormFragment result;

    thread_offset_ =
      iterator_D_.thread_start() +
      OutputTileIterator::ThreadMap::iteration_offset(frag_idx);

    NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
    OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];

    bool column_guard = (thread_offset_.column() < extent_.column());

    if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      result = source_converter(elementwise_(accum));
    } else {
      result = source_converter(elementwise_(accum, source_vector));
    }

    ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());

    // The fragment is cleared for non-reachable columns, so there is no need to check the column guard
    ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);

    // The square sum is different: non-reachable columns must not contribute under shift-K,
    // otherwise some extra k^2 would be incorrectly added into the square sum.
    ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);

    if (column_guard) {
      accum_sum_square_tmp = (kIsShiftedVariance) ? \
                             square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
                             square_sum_accumulator_(result);
    }

    accum_sum_element_tmp *= inv_scalar;
    accum_sum_square_tmp *= inv_scalar;

    // After performing the in-thread reduction, perform the cross-thread / in-warp reduction
    CUTLASS_PRAGMA_UNROLL
    for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
      accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
      accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
    }
    accum_sum_element_ += accum_sum_element_tmp;
    accum_sum_square_ += accum_sum_square_tmp;

    // Convert to the output
    NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
    OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
    output = output_converter(result);
  }

  /// Called at the end of a row
  CUTLASS_DEVICE
  void end_row(int row_idx) {

    using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
    using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;

    ConvertVarianceOutput convert_variance_output;
    ConvertMeanOutput convert_mean_output;

    bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
    int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();

    ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
    ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;

    arch::global_store<ElementVariance, sizeof(ElementVariance)>(
              convert_variance_output(accum_sum_square_),
              (void *)curr_ptr_sum_square,
              is_write_thread);

    arch::global_store<ElementMean, sizeof(ElementMean)>(
              convert_mean_output(accum_sum_element_),
              (void *)curr_ptr_element_sum,
              is_write_thread);
  }

  /// Called after all accumulator elements have been visited
  CUTLASS_DEVICE
  void end_step(int step_idx) {

    iterator_D_.store(fragment_D_);
    ++iterator_D_;
  }

  /// Called after all steps have been completed
  CUTLASS_DEVICE
  void end_epilogue() {

  }

private:

  CUTLASS_DEVICE
  ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
    using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
    ConvertShiftK convert_shift_k;
    ElementOutput shift_k_val;

    // Compute the address from which to load the shift_k element
    ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset;
    // Conditionally load from global memory
    arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
    // Convert the data type before returning
    ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);

    return converted_shift_k_val;
  }

  CUTLASS_DEVICE
  ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
    ElementLayernormCompute sum_ = ElementLayernormCompute(0);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < LayernormFragment::kElements; ++i) {
      auto accum_ = accum[i];
      sum_ += accum_ * accum_;
    }

    return sum_;
  }

  CUTLASS_DEVICE
  ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
    ElementLayernormCompute sum_ = ElementLayernormCompute(0);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < LayernormFragment::kElements; ++i) {
      auto accum_ = accum[i] - shift_k_val;
      sum_ += accum_ * accum_;
    }

    return sum_;
  }

  CUTLASS_DEVICE
  ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
    ElementLayernormCompute sum_ = ElementLayernormCompute(0);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < LayernormFragment::kElements; ++i) {
      sum_ += accum[i];
    }

    return sum_;
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
||||
@ -1,77 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Bind gemm-related enum types to Python
*/
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/gemm/gemm.h"
#include "host.h"

namespace py = pybind11;

void bind_gemm(py::module &m) {
  //
  // Enumerate types
  // cutlass/gemm/gemm.h
  //

  py::enum_<cutlass::gemm::GemmUniversalMode>(m, "Mode")
    .value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial")
    .value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel")
    .value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM")
    .value("Array", cutlass::gemm::GemmUniversalMode::kArray)
    .value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid);

  /// GemmCoord is a structure that specifies a location within the coordinate space of a GEMM problem
  py::class_<cutlass::gemm::GemmCoord>(m, "GemmCoord")
    .def(py::init<int, int, int>())
    .def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m))
    .def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n))
    .def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k))
    // get tensor coords
    .def("mk",
      [](const cutlass::gemm::GemmCoord &problem_size) {
        return cutlass::MatrixCoord(problem_size.mk());
      })
    .def("kn",
      [](const cutlass::gemm::GemmCoord &problem_size) {
        return cutlass::MatrixCoord(problem_size.kn());
      })
    .def("mn",
      [](const cutlass::gemm::GemmCoord &problem_size) {
        return cutlass::MatrixCoord(problem_size.mn());
      });

  py::module_ host_submodule = m.def_submodule("host");
  bind_gemm_host_helper(host_submodule);
}
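These bindings expose a `Mode` enum and a `GemmCoord` class on the Python side. The following is a rough pure-Python sketch of that surface for illustration only; `Mode` and `GemmCoordStub` here are hand-written stand-ins, not the compiled extension module, and the enum values are illustrative rather than taken from `cutlass/gemm/gemm.h`:

```python
import enum

# Stand-in for the bound cutlass::gemm::GemmUniversalMode enum ("Mode" above).
# Numeric values are illustrative only.
class Mode(enum.Enum):
    Gemm = 0                # Ordinary GEMM & GEMM Split-K serial
    GemmSplitKParallel = 1  # GEMM Split-K parallel
    Batched = 2             # Batched GEMM
    Array = 3
    Invalid = 4

# Stand-in for the bound cutlass::gemm::GemmCoord class.
class GemmCoordStub:
    def __init__(self, m, n, k):
        self._m, self._n, self._k = m, n, k

    def m(self): return self._m
    def n(self): return self._n
    def k(self): return self._k

    # mk/kn/mn return 2-D projections of the 3-D problem coordinate,
    # mirroring the MatrixCoord conversions in the lambdas above.
    def mk(self): return (self._m, self._k)
    def kn(self): return (self._k, self._n)
    def mn(self): return (self._m, self._n)

problem = GemmCoordStub(128, 256, 64)
print(problem.mk(), problem.kn(), problem.mn())
```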
@ -1,638 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

#include "cutlass/layout/matrix.h"

#include "cutlass/trace.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,           ///! Epilogue
  typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
struct GemmUniversalwithEpilogueVisitor {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueVisitor = typename Epilogue::Visitor;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename EpilogueVisitor::ElementOutput;
  using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
  using Operator = typename Mma::Operator;

  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(
    128 / sizeof_bits<ElementA>::value,
    128 / sizeof_bits<ElementB>::value
  );

  //
  // Structures
  //

  /// Argument structure
  struct Arguments : UniversalArgumentsBase {

    //
    // Data members
    //

    typename EpilogueVisitor::Arguments epilogue_visitor;

    void const * ptr_A;
    void const * ptr_B;
    void const * ptr_C;
    void * ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C;

    typename LayoutA::Stride stride_a;
    typename LayoutB::Stride stride_b;
    typename LayoutC::Stride stride_c;
    typename LayoutC::Stride stride_d;

    typename LayoutA::Stride::LongIndex lda;
    typename LayoutB::Stride::LongIndex ldb;
    typename LayoutC::Stride::LongIndex ldc;
    typename LayoutC::Stride::LongIndex ldd;

    int const * ptr_gather_A_indices;
    int const * ptr_gather_B_indices;
    int const * ptr_scatter_D_indices;

    //
    // Methods
    //

    Arguments():
      ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
      ptr_gather_A_indices(nullptr),
      ptr_gather_B_indices(nullptr),
      ptr_scatter_D_indices(nullptr) {}

    /// constructs an arguments structure
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueVisitor::Arguments epilogue_visitor,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C,
      void * ptr_D,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C,
      int64_t batch_stride_D,
      typename LayoutA::Stride stride_a,
      typename LayoutB::Stride stride_b,
      typename LayoutC::Stride stride_c,
      typename LayoutC::Stride stride_d,
      int const *ptr_gather_A_indices = nullptr,
      int const *ptr_gather_B_indices = nullptr,
      int const *ptr_scatter_D_indices = nullptr
    ):
      UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
      epilogue_visitor(epilogue_visitor),
      ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
      batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
      stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
      ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
      ptr_scatter_D_indices(ptr_scatter_D_indices) {
      lda = 0;
      ldb = 0;
      ldc = 0;
      ldd = 0;
      CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
    }

    /// constructs an arguments structure
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueVisitor::Arguments epilogue_visitor,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C,
      void * ptr_D,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C,
      int64_t batch_stride_D,
      typename LayoutA::Stride::LongIndex lda,
      typename LayoutB::Stride::LongIndex ldb,
      typename LayoutC::Stride::LongIndex ldc,
      typename LayoutC::Stride::LongIndex ldd,
      int const *ptr_gather_A_indices = nullptr,
      int const *ptr_gather_B_indices = nullptr,
      int const *ptr_scatter_D_indices = nullptr
    ):
      UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
      epilogue_visitor(epilogue_visitor),
      ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
      batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
      lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
      ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
      ptr_scatter_D_indices(ptr_scatter_D_indices) {
      stride_a = make_Coord(lda);
      stride_b = make_Coord(ldb);
      stride_c = make_Coord(ldc);
      stride_d = make_Coord(ldd);
      CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
    }

    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_A, args.ptr_B);
      std::swap(args.lda, args.ldb);
      std::swap(args.stride_a, args.stride_b);
      std::swap(args.batch_stride_A, args.batch_stride_B);
      std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);

      return args;
    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params : UniversalParamsBase<
    ThreadblockSwizzle,
    ThreadblockShape,
    ElementA,
    ElementB,
    ElementC> {

    using ParamsBase = UniversalParamsBase<
      ThreadblockSwizzle,
      ThreadblockShape,
      ElementA,
      ElementB,
      ElementC>;

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename EpilogueVisitor::OutputTileIterator::Params params_C;
    typename EpilogueVisitor::OutputTileIterator::Params params_D;

    typename EpilogueVisitor::Params epilogue_visitor;

    void * ptr_A;
    void * ptr_B;
    void * ptr_C;
    void * ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C;

    int * ptr_gather_A_indices;
    int * ptr_gather_B_indices;
    int * ptr_scatter_D_indices;

    int *semaphore;

    //
    // Methods
    //

    /// Default constructor
    Params() = default;

    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      int device_sms,
      int sm_occupancy
    ):
      ParamsBase(args, device_sms, sm_occupancy),
      params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
      params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
      params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
      params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
      epilogue_visitor(args.epilogue_visitor),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      ptr_C(const_cast<void *>(args.ptr_C)),
      ptr_D(args.ptr_D),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      batch_stride_C(args.batch_stride_C),
      ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
      ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
      ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)) {

    }

    CUTLASS_HOST_DEVICE
    void update(
      Arguments const &args,
      void *workspace = nullptr) {

      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);
      ptr_C = const_cast<void *>(args.ptr_C);
      ptr_D = args.ptr_D;

      ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
      ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
      ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);

      batch_stride_A = args.batch_stride_A;
      batch_stride_B = args.batch_stride_B;
      batch_stride_C = args.batch_stride_C;

      epilogue_visitor = args.epilogue_visitor;

      semaphore = static_cast<int *>(workspace);
      CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
    typename EpilogueVisitor::SharedStorage visitor;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmUniversalwithEpilogueVisitor() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size) {

    CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");

    static int const kAlignmentA = (platform::is_same<LayoutA,
                                                      layout::ColumnMajorInterleaved<32>>::value)
      ? 32
      : (platform::is_same<LayoutA,
                           layout::ColumnMajorInterleaved<64>>::value)
        ? 64
        : Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = (platform::is_same<LayoutB,
                                                      layout::RowMajorInterleaved<32>>::value)
      ? 32
      : (platform::is_same<LayoutB,
                           layout::RowMajorInterleaved<64>>::value)
        ? 64
        : Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = (platform::is_same<LayoutC,
                                                      layout::ColumnMajorInterleaved<32>>::value)
      ? 32
      : (platform::is_same<LayoutC,
                           layout::ColumnMajorInterleaved<64>>::value)
        ? 64
        : Epilogue::OutputTileIterator::kElementsPerAccess;

    bool isAMisaligned = false;
    bool isBMisaligned = false;
    bool isCMisaligned = false;

    if (platform::is_same<LayoutA, layout::RowMajor>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }

    if (platform::is_same<LayoutB, layout::RowMajor>::value) {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
            || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }

    if (platform::is_same<LayoutC, layout::RowMajor>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }

    if (isAMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isBMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isCMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
      return Status::kErrorMisalignedOperand;
    }

    CUTLASS_TRACE_HOST("  returning kSuccess");

    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return can_implement(args.problem_size);
  }

  // Factory invocation
  CUTLASS_DEVICE
  static void invoke(
    Params const &params,
    SharedStorage &shared_storage)
  {
    GemmUniversalwithEpilogueVisitor op;
    op(params, shared_storage);
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm ||
        params.mode == GemmUniversalMode::kGemmSplitKParallel) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {

        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
      ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
    }

    __syncthreads();

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A,
      params.ptr_gather_A_indices);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B,
      params.ptr_gather_B_indices);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      accumulators);

    //
    // Epilogue
    //

    // EpilogueOutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

    //
    // Fetch pointers based on mode.
    //

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    // Tile iterator loading from source tensor.

    EpilogueVisitor epilogue_visitor(
      params.epilogue_visitor,
      shared_storage.visitor,
      threadblock_offset,
      threadblock_tile_offset,
      thread_idx,
      params.problem_size.mn()
    );

    if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
      epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
    }

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      semaphore.wait(threadblock_tile_offset.k());
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(epilogue_visitor, accumulators);

    //
    // Release the semaphore
    //

    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_offset.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
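The serial split-K path of the kernel above slices the K dimension across threadblocks and serializes the epilogue with a per-tile semaphore. A pure-Python sketch of that bookkeeping (an illustrative model, not the device code; function names are invented for this example):

```python
def k_slice(tile_k, grid_k, gemm_k_size, problem_k):
    """Mirror of the kGemm path above: K-tile `tile_k` covers the
    half-open range [offset_k, problem_size_k) of the K dimension.
    The last tile extends to the full problem K."""
    offset_k = tile_k * gemm_k_size
    problem_size_k = problem_k
    if tile_k + 1 < grid_k:
        problem_size_k = (tile_k + 1) * gemm_k_size
    return offset_k, problem_size_k

def release_value(tile_k, grid_k):
    """Mirror of the semaphore release logic: the final K-tile resets
    the lock to 0 for subsequent grids; earlier tiles set it to
    tile_k + 1 so the next tile's wait() succeeds."""
    return 0 if tile_k + 1 == grid_k else tile_k + 1

# Example: K = 300 split over a grid of 3 K-tiles with gemm_k_size = 128.
for t in range(3):
    print(t, k_slice(t, 3, 128, 300), release_value(t, 3))
```

Note that each tile's `gemm_k_iterations` then follows as the ceiling division of its K-range by `Mma::Shape::kK`, matching the expression in the mainloop.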
@ -1,47 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Bind gemm host helpers to Python
*/
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/util/host_reorder.h"
#include "cutlass/layout/tensor.h"

namespace py = pybind11;

void bind_gemm_host_helper(py::module &m) {
  m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>);
  m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>);
}
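The `reorder_column` helpers above rearrange `int8_t` operands into CUTLASS's interleaved layouts. As a rough, hedged sketch of what "interleaved" means here (my reading of `cutlass/layout/matrix.h`; `row_major_interleaved_offset` is a name invented for this illustration, and a small interleave factor is used instead of 32 for readability):

```python
def row_major_interleaved_offset(row, col, cols, interleave=32):
    """Hypothetical model of a RowMajorInterleaved<interleave> mapping:
    rows are grouped into panels of `interleave`, and within a panel
    elements are stored column-by-column. stride = cols * interleave."""
    stride = cols * interleave
    return (row // interleave) * stride + col * interleave + (row % interleave)

# Sanity check: for a matrix whose row count is a multiple of the
# interleave factor, the mapping is a bijection onto [0, rows*cols).
rows, cols, I = 8, 5, 4
offsets = {row_major_interleaved_offset(r, c, cols, I)
           for r in range(rows) for c in range(cols)}
print("bijective:", offsets == set(range(rows * cols)))
```

Consult `cutlass/layout/matrix.h` for the authoritative index mapping; this sketch only conveys the panel structure that makes the reorder worthwhile for int8 tensor-core kernels.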
|
||||
@ -1,47 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Bind CUTLASS layouts to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "tensor.h"
|
||||
#include "matrix.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void bind_layout(py::module &m) {
|
||||
bind_tensor_layout(m);
|
||||
bind_matrix_layout(m);
|
||||
}
|
||||
@ -1,87 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
   \brief Bind Matrix layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/layout/matrix.h"

namespace py = pybind11;

void bind_matrix_layout(py::module &m) {
  //
  // Matrix layouts
  // cutlass/layout/matrix.h
  //

  py::class_<cutlass::layout::RowMajor>(m, "RowMajor", R"pbdoc(
    Mapping function for row-major matrices.
  )pbdoc")
    .def_static("packed", &cutlass::layout::RowMajor::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed tensor)pbdoc")
    .def("stride", [](const cutlass::layout::RowMajor &layout) {
      return layout.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");

  py::class_<cutlass::layout::ColumnMajor>(m, "ColumnMajor", R"pbdoc(
    Mapping function for column-major matrices.
  )pbdoc")
    .def_static("packed", &cutlass::layout::ColumnMajor::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed tensor)pbdoc")
    .def("stride", [](const cutlass::layout::ColumnMajor &layout) {
      return layout.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");

  py::class_<cutlass::layout::RowMajorInterleaved<32>>(m, "RowMajorInterleaved32",
    R"pbdoc(Mapping function for interleaved matrices. The matrix is structured
    as a row-major arrangement of fixed-size columns of 32 elements)pbdoc")
    .def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed tensor)pbdoc")
    .def("stride", [](const cutlass::layout::RowMajorInterleaved<32> &layout) {
      return layout.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");

  py::class_<cutlass::layout::ColumnMajorInterleaved<32>>(m, "ColumnMajorInterleaved32",
    R"pbdoc(Mapping function for interleaved matrices. The matrix is structured
    as a column-major arrangement of fixed-size rows of 32 elements)pbdoc")
    .def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed tensor)pbdoc")
    .def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> &layout) {
      return layout.stride().at(0);
    }, R"pbdoc(Returns the stride of the layout)pbdoc");
}
@ -1,74 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
   \brief Bind Tensor layouts to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/layout/tensor.h"

namespace py = pybind11;

void bind_tensor_layout(py::module &m) {
  //
  // Tensor layouts
  // cutlass/include/cutlass/layout/tensor.h
  //

  /// Mapping function for 4-D NHWC tensors.
  py::class_<cutlass::layout::TensorNHWC>(m, "TensorNHWC",
    R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc")
    .def_static("packed", &cutlass::layout::TensorNHWC::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed NHWC tensor)pbdoc")
    .def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride),
      R"pbdoc(Returns the stride of the layout)pbdoc");

  /// Mapping function for 4-D NC/xHWx tensors.
  py::class_<cutlass::layout::TensorNCxHWx<32>>(m, "TensorNC32HW32",
    R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc")
    .def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed tensor)pbdoc")
    .def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride),
      R"pbdoc(Returns the stride of the layout)pbdoc");

  /// Mapping function for 4-D CxRSKx tensors.
  py::class_<cutlass::layout::TensorCxRSKx<32>>(m, "TensorC32RSK32",
    R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc")
    .def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed,
      py::arg("extent"),
      R"pbdoc(Helper that returns the layout of a tightly packed tensor)pbdoc")
    .def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride),
      R"pbdoc(Returns the stride of the layout)pbdoc");
}
@ -1,169 +0,0 @@
/* \file
   \brief Bind threadblock swizzling to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"

#include <cxxabi.h>
#include <cuda_runtime.h>
#include <cstdlib>
#include <memory>
#include <string>

namespace py = pybind11;

std::string demangle(const char* mangled_name) {
  std::size_t len = 0;
  int status = 0;
  // __cxa_demangle returns a buffer allocated with malloc, so release it with std::free
  std::unique_ptr<char, decltype(&std::free)> ptr(
    __cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status), std::free);
  return ptr ? std::string(ptr.get()) : std::string(mangled_name);
}

template<typename T>
void bind_identity_swizzle(py::module &m, std::string name) {
  py::class_<T>(m, name.c_str(),
    R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc")
    .def(py::init<>())
    .def("get_tiled_shape",
      py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
        &T::get_tiled_shape, py::const_
      ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
      R"pbdoc(Returns the shape of the problem in units of logical tiles

      :param problem_size: gemm(M, N, K)
      :type problem_size: :class:`cutlass.gemm.GemmCoord`
      )pbdoc")
    .def("get_tiled_shape",
      py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
        &T::get_tiled_shape, py::const_
      ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
      R"pbdoc(Returns the shape of the problem in units of logical tiles

      :param problem_size: implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
      :type problem_size: :class:`cutlass.gemm.GemmCoord`
      )pbdoc")
    .def("get_tiled_shape",
      py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv3dProblemSize&, cutlass::gemm::GemmCoord, int>(
        &T::get_tiled_shape, py::const_
      ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
      R"pbdoc(Returns the shape of the problem in units of logical tiles

      :param problem_size: implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC)
      :type problem_size: :class:`cutlass.gemm.GemmCoord`
      )pbdoc")
    .def("get_grid_shape", &T::get_grid_shape,
      py::arg("tiled_shape"),
      R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
    .def("tag", [](const T &swizzle) {
      return demangle(typeid(T).name());
    }, R"pbdoc(Returns the C++ name of the swizzle for code emission)pbdoc");
}

template<typename T>
void bind_swizzle(py::module &m, std::string name, std::string doc) {
  py::class_<T>(m, name.c_str(), doc.c_str())
    .def(py::init<>())
    .def("get_tiled_shape",
      py::overload_cast<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord, int>(
        &T::get_tiled_shape, py::const_
      ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
      R"pbdoc(Returns the shape of the problem in units of logical tiles

      :param problem_size: gemm(M, N, K)
      :type problem_size: :class:`cutlass.gemm.GemmCoord`
      )pbdoc")
    .def("get_grid_shape", &T::get_grid_shape,
      py::arg("tiled_shape"),
      R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
    .def("tag", [](const T &swizzle) {
      return demangle(typeid(T).name());
    }, R"pbdoc(Returns the C++ name of the swizzle for code emission)pbdoc");
}

template<typename T>
void bind_swizzle_streamk(py::module &m, std::string name, std::string doc) {
  py::class_<T>(m, name.c_str(), doc.c_str())
    .def(py::init<>())
    .def("tag", [](const T &swizzle) {
      return demangle(typeid(T).name());
    }, R"pbdoc(Returns the C++ name of the swizzle for code emission)pbdoc");
}

template<typename T>
void bind_dgrad_swizzle(py::module &m, std::string name) {
  py::class_<T>(m, name.c_str(),
    R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc")
    .def(py::init<>())
    .def("get_tiled_shape",
      py::overload_cast<cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize&, cutlass::gemm::GemmCoord, int>(
        &T::get_tiled_shape, py::const_
      ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"),
      R"pbdoc(Returns the shape of the problem in units of logical tiles

      :param problem_size: implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
      :type problem_size: :class:`cutlass.gemm.GemmCoord`
      )pbdoc")
    .def("get_grid_shape", &T::get_grid_shape,
      py::arg("tiled_shape"),
      R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
    .def("tag", [](const T &swizzle) {
      return demangle(typeid(T).name());
    }, R"pbdoc(Returns the C++ name of the swizzle for code emission)pbdoc");
}

void bind_threadblock_swizzle(py::module &m) {

  py::class_<dim3>(m, "dim3",
    R"pbdoc(CUDA dim3 type holding three integer components x, y, and z)pbdoc")
    .def(py::init<int, int, int>(),
      py::arg("x"), py::arg("y"), py::arg("z"))
    .def_readwrite("x", &dim3::x, R"pbdoc(Value of x)pbdoc")
    .def_readwrite("y", &dim3::y, R"pbdoc(Value of y)pbdoc")
    .def_readwrite("z", &dim3::z, R"pbdoc(Value of z)pbdoc");

  bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>>(m, "IdentitySwizzle1");
  bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>>(m, "IdentitySwizzle2");
  bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>>(m, "IdentitySwizzle4");
  bind_identity_swizzle<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>>(m, "IdentitySwizzle8");

  bind_swizzle<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle>(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc");
  bind_swizzle<cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc");

  bind_swizzle_streamk<cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>(m, "ThreadblockSwizzleStreamK", R"pbdoc(Threadblock swizzling function using the Stream K feature)pbdoc");

  bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>>(m, "StridedDgradIdentitySwizzle1");
  bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>>(m, "StridedDgradIdentitySwizzle4");
  bind_dgrad_swizzle<cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle>(m, "StridedDgradHorizontalSwizzle");
}
@ -1,78 +0,0 @@
/* \file
   \brief Bind Tensor Coord to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/tensor_coord.h"

namespace py = pybind11;

void bind_tensor_coord(py::module &m) {
  //
  // Tensor Coords
  // cutlass/include/cutlass/tensor_coord.h
  //

  /// Defines a canonical 4D coordinate used by tensor operations.
  py::class_<cutlass::Tensor4DCoord>(m, "Tensor4DCoord",
    R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
    .def(py::init<int, int, int, int>(),
      py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
      R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
    .def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
      py::arg("dim"),
      R"pbdoc(Gets the index of a given Coord element)pbdoc")
    .def("size", [](const cutlass::Tensor4DCoord &coord) {
      return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);
    }, R"pbdoc(The size of the tensor coord)pbdoc");

  py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
    R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
    .def("at", py::overload_cast<int>(&cutlass::Coord<3>::at),
      py::arg("dim"),
      R"pbdoc(Gets the index of a given Coord element)pbdoc");

  // Matrix Size
  py::class_<cutlass::MatrixCoord>(m, "MatrixCoord",
    R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
    expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc")
    .def(py::init<int, int>(),
      py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc")
    .def("row", py::overload_cast<>(&cutlass::MatrixCoord::row),
      R"pbdoc(Returns the row of the coordinate)pbdoc")
    .def("column", py::overload_cast<>(&cutlass::MatrixCoord::column),
      R"pbdoc(Returns the column of the coordinate)pbdoc");
}
@ -1,102 +0,0 @@
/* \file
   \brief Bind TensorRef and View to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "types.h"

template<typename T, typename L, typename TF>
void bind_tensor_ref_view(py::module &m, std::string name) {
  py::class_<cutlass::TensorRef<T, L>>(m, ("TensorRef" + name).c_str())
    .def(py::init([](int64_t address, const L &layout_) {
      T* ptr = reinterpret_cast<T*>(address);
      return new cutlass::TensorRef<T, L>(ptr, layout_);
    }))
    .def("data", [](cutlass::TensorRef<T, L> &tensor_ref) {
      T* ptr = tensor_ref.data();
      return int64_t(ptr);
    })
    .def("layout", py::overload_cast<>(&cutlass::TensorRef<T, L>::layout));

  m.def("get_tensor_ref", [](int64_t address, TF data, const L &layout_) {
    T* ptr = reinterpret_cast<T*>(address);
    cutlass::TensorRef<T, L> tensor_ref = cutlass::TensorRef<T, L>(ptr, layout_);
    return tensor_ref;
  });

  py::class_<cutlass::TensorView<T, L>>(m, ("TensorView" + name).c_str())
    .def(py::init<const cutlass::TensorRef<T, L>&, const typename L::TensorCoord &>());
}

void bind_tensor_refs_and_views(py::module &m) {

  /// float
  bind_tensor_ref_view<float, cutlass::layout::RowMajor, cutlass::float32>(m, "F32RowMajor");
  bind_tensor_ref_view<float, cutlass::layout::ColumnMajor, cutlass::float32>(m, "F32ColumnMajor");
  bind_tensor_ref_view<float, cutlass::layout::TensorNHWC, cutlass::float32>(m, "F32NHWC");

  /// double
  bind_tensor_ref_view<double, cutlass::layout::RowMajor, cutlass::float64>(m, "F64RowMajor");
  bind_tensor_ref_view<double, cutlass::layout::ColumnMajor, cutlass::float64>(m, "F64ColumnMajor");
  bind_tensor_ref_view<double, cutlass::layout::TensorNHWC, cutlass::float64>(m, "F64NHWC");

  // half_t
  bind_tensor_ref_view<cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t>(m, "F16RowMajor");
  bind_tensor_ref_view<cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t>(m, "F16ColumnMajor");
  bind_tensor_ref_view<cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::half_t>(m, "F16NHWC");

  // bfloat16
  bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t>(m, "BF16RowMajor");
  bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t>(m, "BF16ColumnMajor");
  bind_tensor_ref_view<cutlass::bfloat16_t, cutlass::layout::TensorNHWC, cutlass::bfloat16_t>(m, "BF16NHWC");

  // int8_t
  bind_tensor_ref_view<int8_t, cutlass::layout::RowMajorInterleaved<32>, cutlass::int8>(m, "S8RowMajorInterleaved32");
  bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajorInterleaved<32>, cutlass::int8>(m, "S8ColumnMajorInterleaved32");
  bind_tensor_ref_view<int8_t, cutlass::layout::RowMajor, cutlass::int8>(m, "S8RowMajor");
  bind_tensor_ref_view<int8_t, cutlass::layout::ColumnMajor, cutlass::int8>(m, "S8ColumnMajor");
  bind_tensor_ref_view<int8_t, cutlass::layout::TensorNHWC, cutlass::int8>(m, "S8NHWC");
  bind_tensor_ref_view<int8_t, cutlass::layout::TensorNCxHWx<32>, cutlass::int8>(m, "S8NC32HW32");
  bind_tensor_ref_view<int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::int8>(m, "S8C32RSK32");

  // int32_t
  bind_tensor_ref_view<int32_t, cutlass::layout::RowMajor, cutlass::int32>(m, "S32RowMajor");
  bind_tensor_ref_view<int32_t, cutlass::layout::ColumnMajor, cutlass::int32>(m, "S32ColumnMajor");
  bind_tensor_ref_view<int32_t, cutlass::layout::TensorNHWC, cutlass::int32>(m, "S32NHWC");
}
@ -1,146 +0,0 @@
/* \file
   \brief Bind CUTLASS types to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/half.h"
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"

namespace py = pybind11;

namespace cutlass {

/// 8-bit signed integer
struct alignas(1) int8 {
  int8_t storage;
  explicit int8(int x) {
    storage = int8_t(x);
  }
  explicit int8(float x) {
    storage = int8_t(x);
  }

  int8_t c_value() { return storage; }
};

/// 32-bit signed integer
struct alignas(4) int32 {
  int storage;
  explicit int32(int x) {
    storage = x;
  }
  explicit int32(float x) {
    storage = int(x);
  }

  int c_value() { return storage; }
};

/// IEEE single-precision floating-point type
struct alignas(4) float32 {
  float storage;
  explicit float32(float x) {
    storage = x;
  }
  explicit float32(int x) {
    storage = float(x);
  }
  float c_value() { return storage; }
};

/// IEEE double-precision floating-point type
struct alignas(8) float64 {
  double storage;
  explicit float64(float x) {
    storage = double(x);
  }
  explicit float64(int x) {
    storage = double(x);
  }
  double c_value() { return storage; }
};

}

void bind_cutlass_types(py::module &m) {

  // s8
  py::class_<cutlass::int8>(m, "int8")
    .def(py::init<float>())
    .def(py::init<int>())
    .def_readwrite("storage", &cutlass::int8::storage)
    .def("value", &cutlass::int8::c_value);

  // s32
  py::class_<cutlass::int32>(m, "int32")
    .def(py::init<float>())
    .def(py::init<int>())
    .def_readwrite("storage", &cutlass::int32::storage)
    .def("value", &cutlass::int32::c_value);

  // f16
  py::class_<cutlass::half_t>(m, "float16")
    .def(py::init<float>())
    .def(py::init<double>())
    .def(py::init<int>())
    .def(py::init<unsigned>())
    .def_readwrite("storage", &cutlass::half_t::storage)
    .def("value", [](const cutlass::half_t &value) { return value; });

  // bf16
  py::class_<cutlass::bfloat16_t>(m, "bfloat16")
    .def(py::init<float>())
    .def(py::init<int>())
    .def_readwrite("storage", &cutlass::bfloat16_t::storage)
    .def("value", [](const cutlass::bfloat16_t &value) { return value; });

  // f32
  py::class_<cutlass::float32>(m, "float32")
    .def(py::init<float>())
    .def(py::init<int>())
    .def_readwrite("storage", &cutlass::float32::storage)
    .def("value", &cutlass::float32::c_value);

  // tf32
  py::class_<cutlass::tfloat32_t>(m, "tfloat32")
    .def(py::init<float>())
    .def(py::init<int>())
    .def_readwrite("storage", &cutlass::tfloat32_t::storage)
    .def("value", [](const cutlass::tfloat32_t &value) { return value; });

  // f64
  py::class_<cutlass::float64>(m, "float64")
    .def(py::init<float>())
    .def(py::init<int>())
    .def_readwrite("storage", &cutlass::float64::storage)
    .def("value", &cutlass::float64::c_value);
}
@ -1,32 +0,0 @@
#include <cutlass/complex.h>

namespace cutlass {

/// Enum class for data types
enum class DataType {
  kB1, kU2, kU4, kU8,
  kU16, kU32, kU64, kS2,
  kS4, kS8, kS16, kS32,
  kS64, kF16, kBF16, kF32,
  kTF32, kF64, kCF16, kCBF16,
  kCF32, kCTF32, kCF64, kCS2,
  kCS4, kCS8, kCS16, kCS32,
  kCS64, kCU2, kCU4, kCU8,
  kCU16, kCU32, kCU64, kInvalid
};

/// Enum class for layout types
enum class LayoutType {
  kColumnMajor, kRowMajor,
  kColumnMajorInterleaved2, kRowMajorInterleaved2,
  kColumnMajorInterleaved32, kRowMajorInterleaved32,
  kColumnMajorInterleaved64, kRowMajorInterleaved64,
  kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC,
  kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32,
  kTensorC64RSK64
};

/// Enum class for opcode class

} // namespace cutlass
@ -1,54 +0,0 @@
/* \file
|
||||
\brief Bind convolution problems to python
|
||||
*/
|
||||
#pragma once
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
|
||||
#include "unit/conv/device/conv2d_problems.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<cutlass::conv::Conv2dProblemSize>);
|
||||
|
||||
void bind_conv_problem_size_test(py::module &m) {
|
||||
|
||||
py::bind_vector<std::vector<cutlass::conv::Conv2dProblemSize>>(m, "Conv2dProblemVector")
|
||||
.def("size", &std::vector<cutlass::conv::Conv2dProblemSize>::size);
|
||||
// Get Conv2d problem sizes
|
||||
py::class_<test::conv::device::TestbedConv2dProblemSizes>(m, "TestbedConv2dProblemSizes")
|
||||
.def(py::init<int>())
|
||||
.def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes);
|
||||
}
|
||||
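`Conv2dProblemSize` carries activation, filter, padding, stride, and dilation extents; the output extents follow the standard convolution formula P = ((H + 2*pad_h - dilation_h*(R - 1) - 1) // stride_h) + 1, and likewise for Q. A minimal Python sketch of that arithmetic (a hypothetical mirror of the C++ struct for illustration, not its actual API):

```python
from dataclasses import dataclass

@dataclass
class Conv2dProblemSize:
    """Conv2d Fprop extents: NHWC activation, KRSC filter (C shared)."""
    n: int; h: int; w: int; c: int
    k: int; r: int; s: int
    pad_h: int = 0; pad_w: int = 0
    stride_h: int = 1; stride_w: int = 1
    dilation_h: int = 1; dilation_w: int = 1

    @property
    def p(self) -> int:
        """Output height."""
        return (self.h + 2 * self.pad_h
                - self.dilation_h * (self.r - 1) - 1) // self.stride_h + 1

    @property
    def q(self) -> int:
        """Output width."""
        return (self.w + 2 * self.pad_w
                - self.dilation_w * (self.s - 1) - 1) // self.stride_w + 1
```

For instance, a 56x56 input with a 3x3 filter, padding 1, and stride 1 yields a 56x56 output; stride 2 yields 28x28.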
@ -1,49 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/* \file
   \brief Bind convolution-related types to Python
*/
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "conv_problems.h"
#include "host.h"

namespace py = pybind11;

void bind_convolution_test(py::module &m) {
  // Conv problem sizes
  bind_conv_problem_size_test(m);

  py::module_ host_submodule = m.def_submodule("host");
  bind_conv_host_references(host_submodule);
}
@ -1,181 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/* \file
   \brief Bind convolution host test helpers to Python
*/
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "unit/conv/device/cache_testbed_output.h"

#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/host/tensor_compare.h"

namespace py = pybind11;

template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host(py::module &m) {
  m.def("conv2d",
    &cutlass::reference::host::Conv2d<
      Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);

  m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}

template<typename Ta, typename La, typename Tb, typename Lb, typename Tc, typename Lc, typename Tacc, typename Te>
void bind_conv2d_host_sat(py::module &m) {
  m.def("conv2d",
    &cutlass::reference::host::Conv2d<
      Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>);

  m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey<Ta, La, Tb, Lb, Tc, Lc, Tacc, Te>);
}

template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nhwc(py::module &m) {
  bind_conv2d_host<
    Ta, cutlass::layout::TensorNHWC,
    Tb, cutlass::layout::TensorNHWC,
    Tc, cutlass::layout::TensorNHWC,
    Tacc, Te>(m);
}

template<typename Ta, typename Tb, typename Tc, typename Tacc, typename Te>
void bind_conv2d_host_nc32hw32(py::module &m) {
  bind_conv2d_host_sat<
    Ta, cutlass::layout::TensorNCxHWx<32>,
    Tb, cutlass::layout::TensorCxRSKx<32>,
    Tc, cutlass::layout::TensorNCxHWx<32>,
    Tacc, Te>(m);
}

template<typename T, typename Layout>
void bind_tensor_equals(py::module &m) {
  m.def("equals", py::overload_cast<
    const cutlass::TensorView<T, Layout>&, const cutlass::TensorView<T, Layout>&>(
      &cutlass::reference::host::TensorEquals<T, Layout>
  ));
}

#define BIND_TENSOR_HASH(Element, Layout) { \
  m.def("TensorHash", &test::conv::device::TensorHash<Element, Layout>, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc") = uint32_t()); \
}

void bind_conv_host_references(py::module &m) {
  //
  // Conv2d reference on host
  // tools/util/include/cutlass/util/reference/host/convolution.h
  //

  /// double
  bind_conv2d_host_nhwc<double, double, double, double, double>(m);
  /// float
  bind_conv2d_host_nhwc<float, float, float, float, float>(m);
  /// half
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, float>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, cutlass::half_t>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, float, float>(m);
  bind_conv2d_host_nhwc<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, float>(m);
  /// bfloat16
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t>(m);
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, cutlass::bfloat16_t>(m);
  bind_conv2d_host_nhwc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);
  /// s8
  bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_conv2d_host_nhwc<int8_t, int8_t, int32_t, int32_t, float>(m);

  bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_conv2d_host_nc32hw32<int8_t, int8_t, int32_t, int32_t, float>(m);

  //
  // Compare whether two tensors are equal
  //
  /// double
  bind_tensor_equals<double, cutlass::layout::TensorNHWC>(m);
  /// float
  bind_tensor_equals<float, cutlass::layout::TensorNHWC>(m);
  /// half
  bind_tensor_equals<cutlass::half_t, cutlass::layout::TensorNHWC>(m);
  /// bfloat16
  bind_tensor_equals<cutlass::bfloat16_t, cutlass::layout::TensorNHWC>(m);
  /// s32
  bind_tensor_equals<int32_t, cutlass::layout::TensorNHWC>(m);
  bind_tensor_equals<int32_t, cutlass::layout::TensorNCxHWx<32>>(m);
  /// s8
  bind_tensor_equals<int8_t, cutlass::layout::TensorNHWC>(m);
  bind_tensor_equals<int8_t, cutlass::layout::TensorNCxHWx<32>>(m);

  /// Cache
  py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
    .def(py::init<>())
    .def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>())
    .def_readwrite("problem", &test::conv::device::CachedTestKey::problem);

  py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
    .def(py::init<>())
    .def(py::init<uint32_t>())
    .def_readwrite("D", &test::conv::device::CachedTestResult::D);

  py::class_<test::conv::device::CachedTestResultListing>(m, "CachedTestResultListing")
    .def(py::init<const std::string &>())
    .def("find", &test::conv::device::CachedTestResultListing::find)
    .def("append", &test::conv::device::CachedTestResultListing::append)
    .def("write", &test::conv::device::CachedTestResultListing::write);

  py::class_<test::conv::device::CRC32>(m, "CRC32")
    .def(py::init<>());

  BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC);
  BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>);
}
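The `cutlass::reference::host::Conv2d` functions bound above are naive host-side convolutions used as ground truth for device kernels. A pure-Python sketch of the Fprop computation over NHWC activations and KRSC filters (illustrative only; the real reference also handles dilation and the other conv operators):

```python
def conv2d_nhwc_reference(x, w, pad=0, stride=1):
    """Naive direct Conv2d Fprop over nested lists.
    x: activations [N][H][W][C]; w: filters [K][R][S][C]; returns [N][P][Q][K]."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    K, R, S = len(w), len(w[0]), len(w[0][0])
    P = (H + 2 * pad - R) // stride + 1   # output height
    Q = (W + 2 * pad - S) // stride + 1   # output width
    y = [[[[0.0] * K for _ in range(Q)] for _ in range(P)] for _ in range(N)]
    for n in range(N):
        for p in range(P):
            for q in range(Q):
                for k in range(K):
                    acc = 0.0
                    for r in range(R):
                        for s in range(S):
                            h = p * stride - pad + r
                            w_col = q * stride - pad + s
                            # skip taps that fall into the zero-padding region
                            if 0 <= h < H and 0 <= w_col < W:
                                for c in range(C):
                                    acc += x[n][h][w_col][c] * w[k][r][s][c]
                    y[n][p][q][k] = acc
    return y
```

A 3x3 all-ones input convolved with a 2x2 all-ones filter (no padding, unit stride) gives a 2x2 output of 4.0 in every position.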
@ -1,45 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/* \file
   \brief Bind GEMM tests to Python
*/
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "host.h"

namespace py = pybind11;

void bind_gemm_test(py::module &m) {
  py::module_ host_submodule = m.def_submodule("host");
  bind_gemm_host_reference(host_submodule);
}
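`bind_gemm_host_reference`, called above, binds `cutlass::reference::host::compute_gemm`, which evaluates D = alpha * (A @ B) + beta * C; the `gemm_saturate` variants additionally clamp the converted result via `NumericConverterClamp`. A pure-Python sketch of those semantics (the int8 clamp range is assumed here for illustration):

```python
def gemm_reference(alpha, A, B, beta, C, saturate=False):
    """D = alpha * (A @ B) + beta * C over nested lists.
    A: MxK, B: KxN, C: MxN. With saturate=True, clamp to int8 range,
    mimicking what NumericConverterClamp does for an int8_t ElementC."""
    M, K, N = len(A), len(B), len(B[0])
    D = [[0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0
            for k in range(K):
                acc += A[i][k] * B[k][j]   # inner-product accumulation
            val = alpha * acc + beta * C[i][j]
            if saturate:
                val = max(-128, min(127, val))
            D[i][j] = val
    return D
```

With `alpha=1, beta=0` this reduces to a plain matrix product; with a large `alpha` and `saturate=True`, entries pin at 127 or -128.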
@ -1,431 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/* \file
   \brief Bind GEMM host test helpers to Python
*/
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

#include "cutlass/cutlass.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/host_reorder.h"

#include "cutlass/functional.h"

namespace py = pybind11;

template<
  typename ElementA, typename LayoutA,
  typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC,
  typename AccumulatorType, typename ComputeType,
  typename InnerProductOp>
void bind_host_gemm_saturate(py::module &m) {
  m.def("gemm_saturate", py::overload_cast<
    cutlass::gemm::GemmCoord, ComputeType,
    cutlass::TensorRef<ElementA, LayoutA>,
    cutlass::TensorRef<ElementB, LayoutB>,
    ComputeType,
    cutlass::TensorRef<ElementC, LayoutC>,
    cutlass::TensorRef<ElementC, LayoutC>,
    AccumulatorType>(
      &cutlass::reference::host::compute_gemm<
        ElementA, LayoutA,
        ElementB, LayoutB,
        ElementC, LayoutC,
        ComputeType,
        AccumulatorType,
        InnerProductOp,
        cutlass::NumericConverterClamp<ElementC, AccumulatorType>>
  ));
}

template<
  typename ElementA, typename LayoutA,
  typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC,
  typename AccumulatorType, typename ComputeType,
  typename InnerProductOp>
void bind_host_gemm(py::module &m) {
  m.def("gemm", py::overload_cast<
    cutlass::gemm::GemmCoord, ComputeType,
    cutlass::TensorRef<ElementA, LayoutA>,
    cutlass::TensorRef<ElementB, LayoutB>,
    ComputeType,
    cutlass::TensorRef<ElementC, LayoutC>,
    cutlass::TensorRef<ElementC, LayoutC>,
    AccumulatorType>(
      &cutlass::reference::host::compute_gemm<
        ElementA, LayoutA,
        ElementB, LayoutB,
        ElementC, LayoutC,
        ComputeType,
        AccumulatorType,
        InnerProductOp,
        cutlass::NumericConverter<ElementC, AccumulatorType>>
  ));
}

template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add(py::module &m) {
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    ComputeType, AccumulatorType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}

template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate(py::module &m) {
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    ComputeType, AccumulatorType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::ColumnMajor,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}

template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_interleaved(py::module &m) {
  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    ComputeType, AccumulatorType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}

template<
  typename ElementA, typename ElementB, typename ElementC,
  typename AccumulatorType, typename ComputeType>
void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) {
  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    ComputeType, AccumulatorType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::RowMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::RowMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::RowMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);

  bind_host_gemm_saturate<
    ElementA, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementB, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementC, cutlass::layout::ColumnMajorInterleaved<32>,
    AccumulatorType, ComputeType,
    cutlass::multiply_add<AccumulatorType>>(m);
}

#define BIND_TENSOR_EQUAL(Element, Layout) { \
  m.def("equals", py::overload_cast< \
    const cutlass::TensorView<Element, Layout>&, const cutlass::TensorView<Element, Layout>&>( \
      &cutlass::reference::host::TensorEquals<Element, Layout>)); \
}

void bind_gemm_host_reference(py::module &m) {

  /// double
  bind_host_gemm_multiply_add<double, double, double, double, double>(m);
  /// float
  bind_host_gemm_multiply_add<float, float, float, float, float>(m);
  /// half_t
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>(m);
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>(m);
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>(m);
  bind_host_gemm_multiply_add<cutlass::half_t, cutlass::half_t, float, float, float>(m);
  /// bfloat16
  bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, float, float>(m);
  bind_host_gemm_multiply_add<cutlass::bfloat16_t, cutlass::bfloat16_t, float, float, float>(m);

  /// s8
  bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_host_gemm_multiply_add<int8_t, int8_t, int32_t, int32_t, float>(m);

  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
  bind_host_gemm_multiply_add_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);

  bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
bind_host_gemm_multiply_add_saturate<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int32_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, int8_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int32_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, int8_t>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int8_t, int32_t, float>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
bind_host_gemm_multiply_add_saturate_interleaved<int8_t, int8_t, int32_t, int32_t, float>(m);
|
||||
|
||||
// float
|
||||
BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor);
|
||||
BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor);
|
||||
|
||||
// double
|
||||
BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor);
|
||||
BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor);
|
||||
|
||||
// half_t
|
||||
BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor);
|
||||
BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor);
|
||||
|
||||
// bfloat16
|
||||
BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor);
|
||||
BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor);
|
||||
|
||||
// int32_t
|
||||
BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor);
|
||||
BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor);
|
||||
|
||||
// int8_t
|
||||
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor);
|
||||
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor);
|
||||
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>);
|
||||
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>);
|
||||
|
||||
|
||||
}
|
||||
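The host reference bound above computes a GEMM with a multiply-add accumulator and then checks results with `TensorEquals`. As a rough sketch of the same reference semantics (the function name and the per-k accumulation loop are illustrative, not the CUTLASS implementation), the computation in NumPy looks like:

```python
import numpy as np

def host_gemm_multiply_add(A, B, C, alpha=1.0, beta=0.0, accum_dtype=np.float32):
    """Plain-NumPy analogue of a host GEMM reference:
    D = alpha * (A @ B) + beta * C, accumulated in accum_dtype."""
    acc = np.zeros((A.shape[0], B.shape[1]), dtype=accum_dtype)
    for k in range(A.shape[1]):
        # multiply-add: acc = a * b + acc, one k-slice at a time
        acc += np.outer(A[:, k].astype(accum_dtype), B[k, :].astype(accum_dtype))
    # epilogue and cast back to the output element type
    return (alpha * acc + beta * C.astype(accum_dtype)).astype(C.dtype)

A = np.arange(6, dtype=np.float16).reshape(2, 3)
B = np.ones((3, 2), dtype=np.float16)
C = np.zeros((2, 2), dtype=np.float16)
D = host_gemm_multiply_add(A, B, C)
```

This mirrors why the bindings carry separate `AccumulatorType` and element types: the accumulation dtype can be wider than the input and output elements.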
@ -214,12 +214,12 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
  cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );

  TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
  TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
  TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
  TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);

  cutlass::conv::SplitKMode mode;
  if (split_k_mode == "serial") {
    mode = cutlass::conv::SplitKMode::kSerial;
@ -228,7 +228,7 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
  } else {
    throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
  }

  typename DeviceKernel::Arguments arguments{
    *problem_size,
    tensor_ref_A,
@ -238,18 +238,18 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
    {alpha, beta},
    mode
  };

  DeviceKernel implicit_gemm_op;

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);

  void* workspace_ptr = device_memory_allocation(workspace_size, device_id);

  cutlass::Status status = implicit_gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
@ -260,6 +260,6 @@ cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_siz
  //
  status = implicit_gemm_op(stream);

  return status;
  return status;
}
"""
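The generated runner above follows a status-checked lifecycle: `can_implement`, then `initialize` with an allocated workspace, then launch, returning early on any non-success status. A toy plain-Python sketch of that control flow (the `MockConvOp` class and `Status` enum here are stand-ins, not a real CUTLASS API):

```python
from enum import Enum

class Status(Enum):
    kSuccess = 0
    kErrorNotSupported = 1

class MockConvOp:
    """Stand-in mirroring the can_implement/initialize/run sequence used
    by the generated C++ above (illustrative only)."""
    def can_implement(self, arguments):
        ok = arguments.get("split_k_mode") in ("serial", "parallel")
        return Status.kSuccess if ok else Status.kErrorNotSupported

    def initialize(self, arguments, workspace):
        self.args = arguments
        return Status.kSuccess

    def run(self):
        return Status.kSuccess

def kernel_run(op, arguments):
    # Mirror the early-return structure of ${name}_kernel_run
    status = op.can_implement(arguments)
    if status != Status.kSuccess:
        return status
    status = op.initialize(arguments, workspace=None)
    if status != Status.kSuccess:
        return status
    return op.run()

status = kernel_run(MockConvOp(), {"split_k_mode": "serial"})
```

The early returns matter: `initialize` must never run for an argument set that `can_implement` rejected.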
@ -81,12 +81,10 @@ The module can later be used in Python via:
import logging
import os

import cutlass_bindings

from cutlass import CUTLASS_PATH, logger, swizzle
from cutlass import CUTLASS_PATH, logger, swizzle, ConvKind, ConvKindNames, DataType
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass.backend.conv2d_operation import Conv2dOperation
from cutlass.backend.library import ApiVersion, ConvKindNames
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.emit import common

@ -165,26 +163,26 @@ _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """

// CUDA forward declarations
at::Tensor ${name}_kernel(
  const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
  const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1) {
  return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
  return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
    py::overload_cast<
      const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
      std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
    &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
    &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
    py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
    py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
    py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
@ -198,26 +196,26 @@ _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """

// CUDA forward declarations
at::Tensor ${name}_kernel(
  std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
  std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
  std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
  float alpha=1.f, float beta=0.f,
  float alpha=1.f, float beta=0.f,
  std::string split_k_mode="serial", int split_k_slices=1) {
  return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
  return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
    py::overload_cast<
      std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
      std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
      std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
    &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
    &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
    py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
    py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
    py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
@ -251,11 +249,11 @@ _PYTORCH_CONV2D_INCLUDES = """
"""

_CUTLASS_TYPE_TO_TORCH_TYPE = {
    cutlass_bindings.float16: "torch::kF16",
    cutlass_bindings.float32: "torch::kF32",
    cutlass_bindings.float64: "torch::kF64",
    cutlass_bindings.int8: "torch::I8",
    cutlass_bindings.int32: "torch::I32",
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}
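The mapping above is what the emitter substitutes for `${torch_type_C}` when generating the extension source. A minimal sketch of that lookup, using string keys in place of the `DataType` enum (the dict and helper here are illustrative, not the CUTLASS implementation):

```python
# Illustrative mirror of _CUTLASS_TYPE_TO_TORCH_TYPE: CUTLASS DataType name
# -> torch C++ scalar-type identifier emitted into the template.
cutlass_to_torch_type = {
    "f16": "torch::kF16",
    "f32": "torch::kF32",
    "f64": "torch::kF64",
    "s8": "torch::I8",
    "s32": "torch::I32",
}

def torch_type_for(dtype_name: str) -> str:
    # Fail loudly for types the emitter cannot express in torch C++
    try:
        return cutlass_to_torch_type[dtype_name]
    except KeyError:
        raise ValueError(f"No torch emission mapping for CUTLASS type {dtype_name}")
```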
|
||||
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
|
||||
@ -446,16 +444,16 @@ std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const s
|
||||
|
||||
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
|
||||
cutlass::Status status = ${name}_kernel_run(
|
||||
&problem_size,
|
||||
&problem_size,
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
|
||||
ptrC,
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
|
||||
alpha, beta,
|
||||
split_k_mode, stream, B.device().index());
|
||||
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
||||
return D;
|
||||
}
|
||||
@ -464,19 +462,19 @@ _PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
|
||||
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
int N, H, W, C_, K, R, S, P, Q;
|
||||
N = A.size(0);
|
||||
C_ = A.size(1);
|
||||
H = A.size(2);
|
||||
W = A.size(3);
|
||||
|
||||
|
||||
K = B.size(0);
|
||||
R = B.size(2);
|
||||
S = B.size(3);
|
||||
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::Tensor4DCoord(N, H, W, C_),
|
||||
cutlass::Tensor4DCoord(K, R, S, C_),
|
||||
@ -486,14 +484,14 @@ at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
|
||||
P = problem_size.P;
|
||||
Q = problem_size.Q;
|
||||
|
||||
|
||||
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
||||
|
||||
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
||||
at::Tensor D = torch::zeros({N, K, P, Q}, options);
|
||||
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
||||
@ -503,7 +501,7 @@ at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional
|
||||
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
int N, H, W, C_, K, R, S;
|
||||
@ -511,11 +509,11 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::T
|
||||
C_ = std::get<1>(input_size);
|
||||
H = std::get<2>(input_size);
|
||||
W = std::get<3>(input_size);
|
||||
|
||||
|
||||
K = B.size(0);
|
||||
R = B.size(2);
|
||||
S = B.size(3);
|
||||
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::Tensor4DCoord(N, H, W, C_),
|
||||
cutlass::Tensor4DCoord(K, R, S, C_),
|
||||
@ -525,11 +523,11 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::T
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
|
||||
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
||||
|
||||
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
||||
at::Tensor D = torch::empty({N, C_, H, W}, options);
|
||||
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
||||
@ -539,19 +537,19 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::T
|
||||
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
||||
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
int N, H, W, C_, K, R, S;
|
||||
K = std::get<0>(weight_size);
|
||||
C_ = std::get<1>(weight_size);
|
||||
R = std::get<2>(weight_size);
|
||||
S = std::get<3>(weight_size);
|
||||
|
||||
|
||||
N = B.size(0);
|
||||
H = B.size(2);
|
||||
W = B.size(3);
|
||||
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::Tensor4DCoord(N, H, W, C_),
|
||||
cutlass::Tensor4DCoord(K, R, S, C_),
|
||||
@ -561,11 +559,11 @@ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
|
||||
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
||||
|
||||
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
||||
at::Tensor D = torch::empty({K, C_, R, S}, options);
|
||||
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
||||
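In the fprop template above, `Conv2dProblemSize` derives the output extents `P` and `Q` from the activation and filter shapes. Under the standard cross-correlation convention (assumed here; the helper function is illustrative), that computation is:

```python
def conv_output_extent(h: int, pad: int, dilation: int, r: int, stride: int) -> int:
    """Output spatial extent for cross-correlation (assumed convention):
    P = floor((H + 2*pad - dilation*(R - 1) - 1) / stride) + 1"""
    return (h + 2 * pad - dilation * (r - 1) - 1) // stride + 1

# A 32x32 activation with a 3x3 filter, stride 1, dilation 1:
P = conv_output_extent(h=32, pad=0, dilation=1, r=3, stride=1)  # valid padding -> 30
Q = conv_output_extent(h=32, pad=1, dilation=1, r=3, stride=1)  # same padding  -> 32
```

This is why the generated fprop kernel can allocate `D` as `{N, K, P, Q}` without the caller supplying the output size.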
@ -726,7 +724,7 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
    else:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
        if isinstance(op.swizzling_functor, swizzle.ThreadblockSwizzleStreamK):
        if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
        else:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
@ -837,9 +835,9 @@ def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str =
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that when the conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
    weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
    weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
    for H/W/R/S given the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
@ -848,13 +846,13 @@ def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str =
        os.makedirs(sourcedir)
    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    if op.conv_kind == cutlass_bindings.conv.Operator.fprop:
    if op.conv_kind == ConvKind.Fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == cutlass_bindings.conv.Operator.dgrad:
    elif op.conv_kind == ConvKind.Dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == cutlass_bindings.conv.Operator.wgrad:
    elif op.conv_kind == ConvKind.Wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
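The docstring's point that "there are multiple valid solutions for H/W/R/S given the same P/Q" is easy to demonstrate: with a strided convolution, two different input heights can produce the same output height, so `dgrad`/`wgrad` cannot infer the input or weight shape from the output alone. A small check (the helper and the output-size convention are assumptions for illustration):

```python
def out_extent(h: int, pad: int, dilation: int, r: int, stride: int) -> int:
    """Fprop output extent (assumed convention):
    floor((H + 2*pad - dilation*(R - 1) - 1) / stride) + 1"""
    return (h + 2 * pad - dilation * (r - 1) - 1) // stride + 1

# With a 3x3 filter, padding 1, and stride 2, input heights 7 and 8
# both map to the same output height -- P alone cannot recover H.
p_a = out_extent(h=7, pad=1, dilation=1, r=3, stride=2)
p_b = out_extent(h=8, pad=1, dilation=1, r=3, stride=2)
```

This ambiguity is exactly why the grad templates take an explicit `input_size` / `weight_size` tuple.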
53 python/cutlass/epilogue/__init__.py Normal file
@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from cutlass.epilogue.epilogue import (
    get_activations,
    get_activation_epilogue,
    gelu,
    hardswish,
    identity,
    leaky_relu,
    relu,
    sigmoid,
    silu,
    tanh,
    trace
)

from cutlass.epilogue.evt_ops import (
    max,
    multiply_add,
    sum,
    permute,
    reshape
)
@ -45,6 +45,7 @@ code like the following for GEMM:

from cutlass.backend import epilogue


gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
@ -99,9 +100,59 @@ def get_activation_epilogue(
        )
    else:
        return epilogue.LinearCombinationGeneric(
            activation(element_compute),
            activation,
            element_output,
            elements_per_access,
            element_accumulator,
            element_compute,
        )


"""
Frontend for EVT that generates an epilogue functor by tracing the input function
"""
from cutlass.backend.evt.frontend import PythonASTFrontend


def trace(fn, example_tensors, **kwargs):
    """
    Trace `fn(**example_tensors)` and generate an epilogue visitor

    :param fn: Python callable
    :param example_tensors: example inputs for fn
    :type example_tensors: dict

    .. highlight:: python
    .. code-block:: python

        import cutlass.backend.evt

        # Define epilogue function as Python callable
        def example_fn(accum, C, alpha, beta, gamma):
            D = ((accum + C) * alpha - gamma) / beta
            return D

        # Define the example tensors
        example_inputs = {
            "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "alpha": 1.5,
            "beta": 0.5,
            "gamma": 2.5,
            "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
        }

        # Generate the epilogue functor
        epilogue_visitor = cutlass.epilogue.trace(example_fn, example_inputs)
    """
    if callable(fn):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        setattr(EpilogueFunctor, "__call__", staticmethod(fn))

        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    else:
        raise NotImplementedError("Expect a callable Python function")
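Because `trace` takes an ordinary Python callable, the epilogue function can be sanity-checked on host arrays before any tracing or compilation. A NumPy run of the docstring's example epilogue (no GPU required; the shapes and scalar values here are arbitrary):

```python
import numpy as np

# The same elementwise epilogue the docstring traces:
# D = ((accum + C) * alpha - gamma) / beta
def example_fn(accum, C, alpha, beta, gamma):
    return ((accum + C) * alpha - gamma) / beta

accum = np.full((2, 2), 2.0)  # stand-in for the GEMM accumulator
C = np.full((2, 2), 1.0)      # stand-in for the source tensor
D = example_fn(accum, C, alpha=1.5, beta=0.5, gamma=2.5)
```

Checking the callable this way catches shape and broadcasting mistakes before handing it to `cutlass.epilogue.trace`.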
@ -1,6 +1,6 @@
################################################################################
#################################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -28,42 +28,52 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
#################################################################################################

"""
Collection of builtin functions used for host reference in EVT
"""

from cuda import cuda
import cutlass_bindings
import numpy as np

from cutlass.backend.utils.software import CheckPackages

cupy_available = CheckPackages().check_cupy()
if cupy_available:
    import cupy as cp

torch_available = CheckPackages().check_torch()
if torch_available:
    import torch


class TensorRef:
    """
    Python Wrapper for cutlass_bindings.TensorRef
    """
def multiply_add(x, y, z):
    return x * y + z

    def __init__(self, tensor, dtype, layout) -> None:
        if isinstance(tensor, np.ndarray):
            ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
        elif torch_available and isinstance(tensor, torch.Tensor):
            ptr = cuda.CUdeviceptr(tensor.data_ptr())
        elif torch_available and isinstance(tensor, cp.ndarray):
            ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
        elif isinstance(tensor, cuda.CUdeviceptr):
            ptr = tensor
        elif isinstance(tensor, int):
            ptr = cuda.CUdeviceptr(tensor)
        else:
            raise NotImplementedError(tensor)

        # the dtype(0) is used to overload between different data types
        # with the same layout
        self.tensor_ref = cutlass_bindings.get_tensor_ref(int(ptr), dtype(0), layout)
def sum(x, dim):
    if isinstance(x, np.ndarray):
        return x.sum(axis=tuple(dim))
    elif torch_available and isinstance(x, torch.Tensor):
        return torch.sum(x, dim)


def max(x, dim):
    if isinstance(x, np.ndarray):
        return x.max(axis=tuple(dim))
    elif torch_available and isinstance(x, torch.Tensor):
        return torch.amax(x, dim)


##############################################################################
# Layout manipulate nodes
##############################################################################

def permute(x, indices: tuple):
    if isinstance(x, np.ndarray):
        return np.transpose(x, axes=indices)
    elif torch_available and isinstance(x, torch.Tensor):
        return x.permute(*indices)


def reshape(x, new_shape: tuple):
    if isinstance(x, np.ndarray):
        return np.reshape(x, newshape=new_shape)
    elif torch_available and isinstance(x, torch.Tensor):
        return x.view(new_shape)
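The host-reference ops in this file dispatch on the array type (NumPy vs. torch). A quick check of the NumPy branch, written directly against NumPy rather than through the module (so it runs without the CUTLASS package installed):

```python
import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)

# sum over dims (0, 1): mirrors x.sum(axis=tuple(dim)) above
s = x.sum(axis=(0, 1))

# max over dim 2: mirrors x.max(axis=tuple(dim)) above
m = x.max(axis=(2,))

# permute to (dim2, dim0, dim1): mirrors np.transpose(x, axes=indices)
p = np.transpose(x, axes=(2, 0, 1))

# reshape to (6, 4): mirrors np.reshape(x, new_shape)
r = np.reshape(x, (6, 4))
```

Note the torch branch uses `torch.amax` rather than `torch.max` so that reduction over multiple dimensions matches the NumPy semantics shown here.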
@ -35,25 +35,21 @@ Classes containing valid operations for a given compute capability and data type
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from cuda import __version__
|
||||
|
||||
# Strip any additional information from the CUDA version
|
||||
_cuda_version = __version__.split("rc")[0]
|
||||
|
||||
# Imports from CUTLASS profiler generator and manifest scripts
|
||||
import generator as prof_generator
|
||||
import manifest as prof_manifest
|
||||
from library import (
|
||||
ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
|
||||
)
|
||||
import cutlass_library
|
||||
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
|
||||
|
||||
import cutlass
|
||||
from cutlass.utils.check import valid_stage_count
|
||||
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op, has_binding_type
|
||||
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op
|
||||
|
||||
|
||||
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]
|
||||
|
||||
# Strip any additional information from the CUDA version
|
||||
_cuda_version = __version__.split("rc")[0]
|
||||
|
||||
|
||||
class KernelsForDataType:
|
||||
"""
|
||||
@@ -202,10 +198,10 @@ class ArchOptions:
        # Identify the method within CUTLASS generator script that generates kernel
        # descriptions for the target CC
        generate_function_name = "GenerateSM" + str(kernel_cc)
        if not hasattr(prof_generator, generate_function_name):
        if not hasattr(cutlass_library.generator, generate_function_name):
            cutlass.logger.warning(f"No generator found for architecture {kernel_cc}")
            return
        generate_function = getattr(prof_generator, generate_function_name)
        generate_function = getattr(cutlass_library.generator, generate_function_name)

        # Initialize a default manifest and populate it with valid kernel descriptions
        # for the target CC
@@ -213,8 +209,8 @@ class ArchOptions:
            "--kernels=all",
            f"--log-level={logging.getLevelName(cutlass.logger.level)}"
        ]
        manifest_args = prof_generator.define_parser().parse_args(args)
        manifest = prof_manifest.Manifest(manifest_args)
        manifest_args = cutlass_library.generator.define_parser().parse_args(args)
        manifest = cutlass_library.manifest.Manifest(manifest_args)
        generate_function(manifest, _cuda_version)

        if operation_kind not in manifest.operations:
@@ -223,9 +219,15 @@ class ArchOptions:
            cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
            return

        # Only one CC should be returned, given the setup above of calling only the generation scripts
        # for a given CC
        if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
            raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
                            "is sufficient for the architecture in question.")

        # Iterate through the available operations for this operation kind and
        # find available opclasses and data types
        for name, op_list in manifest.operations[operation_kind].items():
        for name, op_list in manifest.operations[operation_kind][kernel_cc].items():
            for op in op_list:
                if operation_kind == cutlass.OperationKind.Gemm:
                    if op.gemm_kind not in gemm_kinds:
@@ -235,15 +237,15 @@ class ArchOptions:
                if mi.math_operation not in self.allowed_math_operations:
                    continue

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)

                # Skip any data types that do not currently have conversions via cutlass_bindings
                if False in [has_binding_type(elt) for elt in datatype_comb]:
                if op.C.element == cutlass.DataType.void:
                    # The CUTLASS Python interface currently does not support void-C kernels
                    continue

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)

                # Prune operations that don't fit in shared memory
                td = td_from_profiler_op(op)
                if not valid_stage_count(target_cc, td)[0]:
                if not valid_stage_count(target_cc, kernel_cc, td)[0]:
                    continue

                if mi.opcode_class not in self.operations_by_opclass:
@@ -337,19 +339,19 @@ class ArchOptions:
                    [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)

                # Prune operations that don't fit in shared memory
                if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]:
                if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td))[0]:
                    continue

                new_kernels = KernelsForDataType(type_comb, layout_comb)

                if operation_kind == cutlass.OperationKind.Gemm:
                    new_operation = prof_manifest.GemmOperation(
                    new_operation = cutlass_library.manifest.GemmOperation(
                        cutlass.GemmKind.Universal, td.minimum_compute_capability,
                        td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
                    new_kernels.add(new_operation)
                elif operation_kind == cutlass.OperationKind.Conv2d:
                    for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
                        new_operation = prof_manifest.Conv2dOperation(
                        new_operation = cutlass_library.manifest.Conv2dOperation(
                            conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
                            A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
                            group_mode=GroupMode.SingleGroup
@@ -30,7 +30,7 @@
#
#################################################################################################

from cutlass.op.gemm import Gemm
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm import Gemm
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase

@@ -49,7 +49,7 @@
    # A, B, C, and D are torch/numpy/cupy tensor objects
    plan = cutlass.op.Conv(A, B, C, D)
    plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))

One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:

@@ -112,21 +112,29 @@
    args.sync()
"""

import cutlass_bindings
import cutlass
from cutlass import epilogue
from cutlass import (
    ConvKind,
    ConvMode,
    IteratorAlgorithm,
    SplitKMode,
    StrideSupport,
)
from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import Conv2DProblemSize, MatrixCoord
from cutlass.utils import check, datatypes


class Conv2d(OperationBase):
    """
    Constructs a ``Conv2d`` object.
    Constructs a ``Conv2d`` object.

    The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
    The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
    along with the data type of output D and that used for accumulation, are bound to the ``Conv``
    object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.

@@ -142,7 +150,7 @@ class Conv2d(OperationBase):
        Conv2d(kind="fprop", element=cutlass.DataType.f32)

        # Explicitly specify the data types to use for A, B, C, and D.
        Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
        Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
               element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32)

        # Set the data types and elements from existing tensors. Note that one can use different tensors when
@@ -151,7 +159,7 @@ class Conv2d(OperationBase):
        # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
        Conv2d(kind="fprop", A=A, B=B, C=C, D=D)

        # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
        # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
        # those passed in via the generic ``element``
        Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32,
               element=cutlass.DataType.f32)
@@ -187,9 +195,9 @@ class Conv2d(OperationBase):
    :type kernel_cc: int
    """
    def __init__(
        self, kind="fprop",
        self, kind="fprop",
        A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
        element=None,
        element=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None
@@ -202,18 +210,18 @@ class Conv2d(OperationBase):
            cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
            self.specified_kernel_cc = 80
            self._reset_options(80)

        # The arch is used in testing
        self.arch = self.current_cc
        self.name = "conv2d" + kind

        # The convolution kind. (concept: cutlass_bindings.conv.Operator)
        self.conv_kind = getattr(cutlass_bindings.conv.Operator, kind)

        # The convolution kind. (concept: cutlass_library.library.ConvKind)
        self.conv_kind = datatypes.getattr_enum(ConvKind, kind)

        # The element types (concept: cutlass library types) of A, B, C, and D
        elements = []
        layouts = []

        # Complete the data types based on user-provided arguments
        for elt, tens, name in zip([element_A, element_B, element_C, element_D],
                                   [A, B, C, D],
@@ -222,27 +230,27 @@ class Conv2d(OperationBase):
                raise Exception(f'Must not specify both element_{name} and tensor {name}')
            if elt is None and tens is None and element is None:
                raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')

            elt_to_set = None
            lay_to_set = None

            if tens is not None:
                elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
            else:
                elt_to_set = elt if elt is not None else element

            assert elt_to_set is not None

            # Currently we only support layout TensorNHWC
            lay_to_set = cutlass.LayoutType.TensorNHWC
            lay_to_set = cutlass.LayoutType.TensorNHWC
            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(lay_to_set)

        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
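The resolution order this loop implements can be sketched on its own: specifying both `element_X` and a tensor for the same operand is an error, a tensor's inferred data type is used when present, then the explicit `element_X`, then the generic `element`. The helper below is illustrative only and not part of the interface:

```python
def resolve_element(specific, tensor_dtype, generic):
    # specific:     the per-operand element_X argument (or None)
    # tensor_dtype: the dtype inferred from a passed tensor (or None)
    # generic:      the catch-all `element` argument (or None)
    if specific is not None and tensor_dtype is not None:
        raise ValueError("Must not specify both element_X and tensor X")
    if tensor_dtype is not None:
        return tensor_dtype
    if specific is not None:
        return specific
    if generic is not None:
        return generic
    raise ValueError("Must specify one of element_X, tensor X, or generic element")

print(resolve_element(None, "f16", "f32"))  # tensor dtype wins over generic
print(resolve_element("f64", None, "f32"))  # explicit element_X wins over generic
```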

        self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta

        if element_accumulator is None:
            self._element_accumulator = self._element_c
        else:
@@ -253,38 +261,38 @@ class Conv2d(OperationBase):
        self.B = B
        self.C = C
        self.D = D

        self.alpha = alpha
        self.beta = beta

        self.beta = beta

        # We only specify the stride of the swizzling functor here
        # The actual swizzling functor is determined in run based on conv_kind and stride
        self._swizzling_stride = 1

        # Arguments that will be set to default value in _reset_operations
        # The default tile_description and op_class are fetched from manifest of cutlass library
        self._tile_description = None
        self.op_class = None
        # The default identity epilogue will be created
        self.epilogue_functor = None

        self._reset_operations()

        # Arguments that will be determined online based on arguments of "run"
        # based on stride, input/output channels, alignment, and conv_kind
        self._iterator_algorithm = None
        self._stride_support = None

    def _reset_operations(self, reset_epilogue: bool = True):
        # Set the default op class
        datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
        layout_comb = (self._layout_a, self._layout_b)

        self.possible_op_classes = self.options.supporting_opclasses(
            self._element_a, self._element_b, self._element_accumulator,
            self._layout_a, self._layout_b
        )

        if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
            self.opclass = cutlass.OpcodeClass.TensorOp
        elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
@@ -292,34 +300,34 @@ class Conv2d(OperationBase):
        else:
            raise Exception(f'No kernel configuration found for supported data type and layout '
                            f'combination {datatype_comb}x{layout_comb}')

        if reset_epilogue:
            self._reset_epilogue_functor_activation(epilogue.identity)

        self.alignment_pref_A = min(
            128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments))
        self.alignment_pref_B = min(
            128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments))
        self.alignment_pref_C = min(
            128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments))
            128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments))
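The preference computed above is the number of elements that fit in one 128-bit access, capped by the widest alignment among the kernels enumerated for this data type; a standalone sketch (the bit widths and 8-element cap below are illustrative):

```python
def alignment_pref(element_bits: int, max_kernel_alignment: int) -> int:
    # Elements per 128-bit (16-byte) vectorized access, capped by the
    # widest alignment any available kernel actually supports.
    return min(128 // element_bits, max_kernel_alignment)

print(alignment_pref(16, 8))  # f16: 8 elements fill a 128-bit access -> 8
print(alignment_pref(32, 8))  # f32: only 4 elements fit -> 4
```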

    #
    # Tile description Related
    #

    @property
    def tile_description(self) -> TileDescription:
        """
        Returns the tile description
        """
        return self._tile_description

    @tile_description.setter
    def tile_description(
        self, td=None):
        """
        Set the tile description

        :param td: tile description
        :type td: cutlass.backend.TileDescription, or a dict with keys
            {
@@ -340,7 +348,7 @@ class Conv2d(OperationBase):
            if "cluster_shape" in td.keys():
                if td["cluster_shape"] != [1, 1, 1]:
                    cutlass.logger.warning("Conv2d currently only supports 'cluster_shape' = [1, 1, 1].")
                    td["cluster_shape"] = [1, 1, 1]
                    td["cluster_shape"] = [1, 1, 1]
            td = self._tile_description.clone_and_update(td)

        valid, msg = self._valid_tile_description(td)
@@ -348,7 +356,7 @@ class Conv2d(OperationBase):
            self._tile_description = td
        else:
            raise Exception(msg)

    def _valid_tile_description(self, td: TileDescription) -> tuple:
        """
        Checks whether the provided tile description is valid for the given compute capability. At present,
@@ -366,9 +374,7 @@ class Conv2d(OperationBase):
        and the second element is a string providing an optional error message.
        :rtype: tuple
        """
        # Check stage count based on the CC to which we are compiling (self.cc), rather
        # than the CC from which we find kernels (self.current_cc)
        valid, msg = check.valid_stage_count(self.cc, td)
        valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
        if not valid:
            return (valid, msg)

@@ -393,11 +399,11 @@ class Conv2d(OperationBase):
            description_str.append(str(td))
            descriptions.append(td)
        return descriptions

    #
    # Swizzling functor Related
    #

    @property
    def swizzling_stride(self):
        """

@@ -420,108 +426,110 @@ class Conv2d(OperationBase):
        """
        Automatically propose the swizzling functor based on the stride
        """
        if self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        if self.conv_kind == ConvKind.Dgrad:
            if stride[0] != 1 or stride[1] != 1:
                return getattr(cutlass.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")

        return getattr(cutlass.swizzle, f"IdentitySwizzle{self._swizzling_stride}")

    #
    # Iterator Algorithm Related
    #

    @property
    def iterator_algorithm(self) -> cutlass_bindings.conv.IteratorAlgorithm:
    def iterator_algorithm(self) -> IteratorAlgorithm:
        """
        Returns the iterator algorithm
        """
        return self._iterator_algorithm

    @iterator_algorithm.setter
    def iterator_algorithm(self, alg: str):
        """
        Sets the iterator algorithm

        :param alg: the iterator algorithm
        :type alg: str, options: "analytic", "optimized", "few_channels", and "fixed_channels"
        """
        # Check if the iterator algorithm is valid
        if alg in ["few_channels", "fixed_channels"] and self.conv_kind != cutlass_bindings.conv.Operator.fprop:
            raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")

        self._iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, alg)
        iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)

    def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> cutlass_bindings.conv.IteratorAlgorithm:
        # Check if the iterator algorithm is valid
        if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
            raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")

        self._iterator_algorithm = iterator_alg

    def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
        """
        Propose a valid iterator algorithm based on problem size and alignment
        """
        if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        if self.conv_kind == ConvKind.Fprop:
            # Check whether the fixed channel is applicable
            if problem_size.C == alignment_a:
                return cutlass_bindings.conv.IteratorAlgorithm.fixed_channels
                return IteratorAlgorithm.FixedChannels
            elif (problem_size.C % alignment_a == 0 and
                  problem_size.R <= 32 and problem_size.S <= 32):
                return cutlass_bindings.conv.IteratorAlgorithm.optimized
                return IteratorAlgorithm.Optimized
            else:
                return cutlass_bindings.conv.IteratorAlgorithm.analytic
        elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
                return IteratorAlgorithm.Analytic
        elif self.conv_kind == ConvKind.Dgrad:
            if (problem_size.K % alignment_a == 0 and
                problem_size.R <= 32 and problem_size.S <= 32 and
                problem_size.C % alignment_b == 0):
                return cutlass_bindings.conv.IteratorAlgorithm.optimized
                return IteratorAlgorithm.Optimized
            else:
                return cutlass_bindings.conv.IteratorAlgorithm.analytic
        elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
                return IteratorAlgorithm.Analytic
        elif self.conv_kind == ConvKind.Wgrad:
            if (problem_size.K % alignment_a == 0 and
                problem_size.C % alignment_b == 0):
                return cutlass_bindings.conv.IteratorAlgorithm.optimized
                return IteratorAlgorithm.Optimized
            else:
                return cutlass_bindings.conv.IteratorAlgorithm.analytic

        return IteratorAlgorithm.Analytic
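For fprop, the proposal above reduces to a small decision table: fixed-channels when the channel count exactly equals the access alignment, optimized when channels are divisible by the alignment and the filter is at most 32x32, and analytic otherwise. A self-contained mirror of that branch (the enum below only mimics the library's names; it is not the library itself):

```python
from enum import Enum

class IteratorAlgorithm(Enum):
    Analytic = 0
    Optimized = 1
    FixedChannels = 2

def propose_fprop_iterator(C, R, S, alignment_a):
    # C: input channels; R, S: filter extents; alignment_a: elements per access
    if C == alignment_a:
        return IteratorAlgorithm.FixedChannels
    if C % alignment_a == 0 and R <= 32 and S <= 32:
        return IteratorAlgorithm.Optimized
    return IteratorAlgorithm.Analytic

print(propose_fprop_iterator(C=8, R=3, S=3, alignment_a=8))   # FixedChannels
print(propose_fprop_iterator(C=64, R=3, S=3, alignment_a=8))  # Optimized
print(propose_fprop_iterator(C=3, R=3, S=3, alignment_a=2))   # Analytic
```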

    def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
        """
        Validate whether the user-provided iterator algorithm works for the given problem size
        """
        if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
            if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.fixed_channels:
        if self.conv_kind == ConvKind.Fprop:
            if iterator_algorithm == IteratorAlgorithm.FixedChannels:
                return problem_size.C == alignment_a
            elif iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized:
            elif iterator_algorithm == IteratorAlgorithm.Optimized:
                return (problem_size.C % alignment_a == 0 and
                        problem_size.R <= 32 and problem_size.S <= 32)
            elif iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.few_channels:
            elif iterator_algorithm == IteratorAlgorithm.FewChannels:
                return problem_size.C % alignment_a == 0
        elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
            if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized:
        elif self.conv_kind == ConvKind.Dgrad:
            if iterator_algorithm == IteratorAlgorithm.Optimized:
                return (problem_size.K % alignment_a == 0 and
                        problem_size.R <= 32 and problem_size.S <= 32 and
                        problem_size.C % alignment_b == 0)
        elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
            if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized:
        elif self.conv_kind == ConvKind.Wgrad:
            if iterator_algorithm == IteratorAlgorithm.Optimized:
                return (problem_size.K % alignment_a == 0 and
                        problem_size.C % alignment_b == 0)

        return True

    #
    # Stride Support Related
    #

    def _propose_stride_support(self, stride):
        if self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        if self.conv_kind == ConvKind.Dgrad:
            if stride[0] == 1 and stride[1] == 1:
                return cutlass.backend.library.StrideSupport.Unity

        return cutlass.backend.library.StrideSupport.Strided

                return StrideSupport.Unity

        return StrideSupport.Strided
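Only dgrad distinguishes unit-stride from strided problems here; every other case is treated as strided. A minimal mirror of the rule (the enum and string kinds below are illustrative, not the library's types):

```python
from enum import Enum

class StrideSupport(Enum):
    Strided = 0
    Unity = 1

def propose_stride_support(kind: str, stride) -> StrideSupport:
    # Unit-stride dgrad is the only special case; everything else,
    # including all fprop and wgrad problems, is treated as strided.
    if kind == "dgrad" and stride == (1, 1):
        return StrideSupport.Unity
    return StrideSupport.Strided

print(propose_stride_support("dgrad", (1, 1)))  # Unity
print(propose_stride_support("fprop", (1, 1)))  # Strided
```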

    #
    # Construct and Compilation
    #

    def construct(
        self, tile_description: TileDescription = None,
        self, tile_description: TileDescription = None,
        alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
        iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None,
        stride_support = None, swizzling_functor: cutlass.swizzle = None,
        iterator_algorithm: IteratorAlgorithm = None,
        stride_support = None, swizzling_functor: cutlass.swizzle = None,
        epilogue_functor=None) -> cutlass.backend.Conv2dOperation:
        """
        Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current
@@ -536,9 +544,9 @@ class Conv2d(OperationBase):
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param iterator_algorithm: the iterator algorithm used
        :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
        :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
        :param stride_support: the stride support of dgrad
        :type stride_support: cutlass.backend.library.StrideSupport
        :type stride_support: cutlass_library.library.StrideSupport
        :param swizzling_functor: the swizzling functor
        :type swizzling_functor: cutlass.swizzle
        :param epilogue_functor: the epilogue functor
@@ -550,66 +558,55 @@ class Conv2d(OperationBase):
        alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
        alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
        alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)

        tensor_A = TensorDescription(
            datatypes.binding_type(self._element_a),
            datatypes.binding_layout(self._layout_b),
            alignment_A
        )
        tensor_B = TensorDescription(
            datatypes.binding_type(self._element_b),
            datatypes.binding_layout(self._layout_b),
            alignment_B
        )
        tensor_C = TensorDescription(
            datatypes.binding_type(self._element_c),
            datatypes.binding_layout(self._layout_c),
            alignment_C
        )

        tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            if self.tile_description is not None:
                tile_description = self.tile_description
            else:
                op = self.possible_operations.operations(alignment_A)[0]
                min_alignment = min([alignment_A, alignment_B, alignment_C])
                op = self.possible_operations.operations(min_alignment)[0]
                tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description

        if iterator_algorithm is None:
            # If the iterator algorithm is already set
            if self.iterator_algorithm is not None:
            if self.iterator_algorithm is not None:
                iterator_algorithm = self.iterator_algorithm
            else:
                # Otherwise, we conservatively use the analytic iterator for correctness
                iterator_algorithm = cutlass_bindings.conv.IteratorAlgorithm.analytic

                iterator_algorithm = IteratorAlgorithm.Analytic

        if stride_support is None:
            # If the stride support is already set
            if self._stride_support is not None:
                stride_support = self._stride_support
            else:
                # Otherwise, we assume strided
                stride_support = cutlass.backend.library.StrideSupport.Strided

                stride_support = StrideSupport.Strided

        if swizzling_functor is None:
            # If the swizzling functor is already set
            swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))

        if epilogue_functor is None:
            if self.epilogue_functor is not None:
                epilogue_functor = self.epilogue_functor
            else:
                epilogue_functor = self._create_epilogue_functor_activation(self._activation)

        # Reset the alignment of the epilogue functor
        epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)

        operation = Conv2dOperation(
            conv_kind=self.conv_kind,
            conv_kind=self.conv_kind,
            iterator_algorithm=iterator_algorithm,
            arch=self.current_cc,
            tile_description=tile_description,
@@ -618,13 +615,13 @@ class Conv2d(OperationBase):
            epilogue_functor=epilogue_functor,
            swizzling_functor=swizzling_functor,
        )

        return operation

    def compile(self, tile_description: TileDescription = None,
                alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
                iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None,
                stride_support = None, swizzling_functor: cutlass.swizzle = None,
                iterator_algorithm: IteratorAlgorithm = None,
                stride_support = None, swizzling_functor: cutlass.swizzle = None,
                epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation:
        """
        Emits and compiles the kernel currently specified. If ``tile_description`` and any
@@ -641,9 +638,9 @@ class Conv2d(OperationBase):
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param iterator_algorithm: the iterator algorithm used
        :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
        :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
        :param stride_support: the stride support of dgrad
        :type stride_support: cutlass.backend.library.StrideSupport
        :type stride_support: cutlass_library.library.StrideSupport
        :param swizzling_functor: the swizzling functor
        :type swizzling_functor: cutlass.swizzle
        :param epilogue_functor: the epilogue functor
@@ -651,17 +648,17 @@ class Conv2d(OperationBase):
        :return: operation that was compiled
        :rtype: cutlass.backend.Conv2dOperation
        """

        self.operation = self.construct(
            tile_description, alignment_A, alignment_B, alignment_C,
            tile_description, alignment_A, alignment_B, alignment_C,
            iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)

        if print_module:
            print(self.operation.rt_module.emit())

        compiler.add_module([self.operation,])
        return self.operation

    #
    # Run Related
    #

@@ -681,21 +678,19 @@ class Conv2d(OperationBase):
        if dtype != ref_type:
            raise Exception(f'Tensor {name} with type and layout {dtype} '
                            f'does not match the expected type of {ref_type}.')

    def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
        if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        if self.conv_kind == ConvKind.Fprop:
            input = A
            weight = B
            output = C
            output_tensor = "C"
        elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        elif self.conv_kind == ConvKind.Dgrad:
            output = A
            weight = B
            input = C
            output_tensor = "A"
        elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        elif self.conv_kind == ConvKind.Wgrad:
            output = A
            input = B
            weight = C
@@ -703,27 +698,27 @@ class Conv2d(OperationBase):
        else:
            raise Exception(f"Convolution kind {self.conv_kind} is not supported")

        N_, H_, W_, C_ = datatypes.get_tensor_shape(input)
        K_, R_, S_, _ = datatypes.get_tensor_shape(weight)
        _, P_, Q_, _ = datatypes.get_tensor_shape(output)

        problem_size = cutlass_bindings.conv.Conv2dProblemSize(
            cutlass_bindings.Tensor4DCoord(N_, H_, W_, C_),
            cutlass_bindings.Tensor4DCoord(K_, R_, S_, C_),
            cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]),
            cutlass_bindings.MatrixCoord(stride[0], stride[1]),
            cutlass_bindings.MatrixCoord(dilation[0], dilation[1]),
            cutlass_bindings.conv.Mode.cross_correlation,
        N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
        K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
        _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")

        problem_size = Conv2DProblemSize(
            N_, H_, W_, C_,
            K_, R_, S_, C_,
            padding[0], padding[1],
            stride[0], stride[1],
            dilation[0], dilation[1],
            ConvMode.CrossCorrelation,
            1, 1
        )

        if P_ != problem_size.P or Q_ != problem_size.Q:
            raise Exception(
                f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")

        return problem_size
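The `P`/`Q` check above relies on the standard cross-correlation output-extent formula, applied per spatial dimension with symmetric padding; sketched on its own (I am assuming the problem size uses this conventional formula):

```python
def conv_output_extent(extent, filter_extent, pad, stride, dilation):
    # Output extent of a cross-correlation along one spatial dimension,
    # with the same padding applied on both sides.
    return (extent + 2 * pad - dilation * (filter_extent - 1) - 1) // stride + 1

# A 3x3 filter with unit stride/dilation and no padding shrinks each side by 2.
print(conv_output_extent(56, 3, pad=0, stride=1, dilation=1))   # 54
# A 7x7 filter with stride 2 and padding 3 halves the extent.
print(conv_output_extent(224, 7, pad=3, stride=2, dilation=1))  # 112
```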

    def run(self, A=None, B=None, C=None, D=None,

    def run(self, A=None, B=None, C=None, D=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1),
            alpha=None, beta=None,
            split_k=("serial", 1), sync: bool = True,
@@ -732,9 +727,9 @@ class Conv2d(OperationBase):
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
        ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
        parameters provided in the call, or from those
        parameters provided in the call, or from those
        passed in on the construction of this object -- one of the two must be specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
@@ -754,7 +749,7 @@ class Conv2d(OperationBase):
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool

        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.Conv2dArguments
        """
@@ -764,45 +759,49 @@ class Conv2d(OperationBase):
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # handle the case when there is no C
        if C is None:
            if beta != 0:
                raise Exception(f"With beta {beta} != 0, C has to be provided.")
            else:
                C = D

        # Construct problem size based on input
        # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
        problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

        # Propose stride support based on input
        stride_support = self._propose_stride_support(stride)

        # Propose swizzling functor
        swizzling_functor = self._propose_swizzling_functor(stride)

        shape_a = datatypes.get_tensor_shape(A, op="CONV")
        shape_b = datatypes.get_tensor_shape(B, op="CONV")
        shape_c = datatypes.get_tensor_shape(C, op="CONV")

        # Get the alignment
        alignment_a = self.possible_operations.find_alignment(datatypes.get_tensor_shape(A), self._layout_a)
        alignment_b = self.possible_operations.find_alignment(datatypes.get_tensor_shape(B), self._layout_b)
        alignment_c = self.possible_operations.find_alignment(datatypes.get_tensor_shape(C), self._layout_c)

        alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a)
        alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b)
        alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c)

        alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
        alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
        alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)
|
||||
|
||||
|
||||
# Propose iterator algorithm based on input
|
||||
if self._iterator_algorithm is None:
|
||||
# Propose a default itertaor algorithm based on the problem size
|
||||
# Propose a default iterator algorithm based on the problem size
|
||||
iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
|
||||
else:
|
||||
if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
|
||||
iterator_algorithm = self._iterator_algorithm
|
||||
else:
|
||||
raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")
|
||||
|
||||
|
||||
epilogue_args = [alpha, beta]
|
||||
|
||||
|
||||
if hasattr(self, "_activation_args"):
|
||||
if isinstance(self._activation_args, list):
|
||||
epilogue_args += self._activation_args
|
||||
@ -813,43 +812,41 @@ class Conv2d(OperationBase):
|
||||
epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor
|
||||
|
||||
|
||||
# The alignment is determined by the iterator function (I believe)
|
||||
self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
||||
self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
||||
alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
|
||||
swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)
|
||||
|
||||
|
||||
# Create reduction operation for parallel split-k
|
||||
if split_k[0] == "parallel" and split_k[1] > 1:
|
||||
epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
|
||||
self.reduction_operation = ReductionOperation(
|
||||
shape=cutlass_bindings.MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
|
||||
element_accumulator=datatypes.binding_type(self._element_accumulator),
|
||||
element_compute=datatypes.binding_type(self._element_accumulator),
|
||||
shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
|
||||
element_accumulator=self._element_accumulator,
|
||||
element_compute=self._element_accumulator,
|
||||
epilogue_functor=epilogue_functor_reduction,
|
||||
count=alignment_c
|
||||
)
|
||||
if print_module:
|
||||
print(self.reduction_operation.rt_module.emit())
|
||||
compiler.add_module([self.reduction_operation,])
|
||||
|
||||
|
||||
arguments = Conv2dArguments(
|
||||
operation=self.operation, problem_size=problem_size,
|
||||
A=A, B=B, C=C, D=D,
|
||||
output_op=self.operation.epilogue_type(*epilogue_args),
|
||||
split_k_mode=datatypes.getattr_enum(cutlass_bindings.conv.SplitKMode, split_k[0]),
|
||||
split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
|
||||
split_k_slices=split_k[1]
|
||||
)
|
||||
|
||||
|
||||
self.operation.run(arguments)
|
||||
|
||||
|
||||
if split_k[0] == "parallel" and split_k[1] > 1:
|
||||
implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(
|
||||
self.conv_kind, arguments.problem_size
|
||||
)
|
||||
implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
|
||||
reduction_arguments = ReductionArguments(
|
||||
self.reduction_operation,
|
||||
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
|
||||
problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
|
||||
partitions=split_k[1],
|
||||
workspace=arguments.ptr_D,
|
||||
destination=D,
|
||||
@ -857,7 +854,7 @@ class Conv2d(OperationBase):
|
||||
output_op=self.reduction_operation.epilogue_type(*epilogue_args)
|
||||
)
|
||||
self.reduction_operation.run(reduction_arguments)
|
||||
|
||||
|
||||
if sync:
|
||||
if split_k[0] == "parallel" and split_k[1] > 1:
|
||||
reduction_arguments.sync()
|
||||
@ -865,23 +862,23 @@ class Conv2d(OperationBase):
|
||||
arguments.sync()
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
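The parallel split-k path above sizes its reduction using the implicit-GEMM view of the convolution. As a minimal standalone sketch (not the CUTLASS API), the standard fprop mapping from convolution extents to GEMM extents can be written as:

```python
# Hypothetical helper illustrating the fprop implicit-GEMM mapping:
# output pixels become GEMM rows, filters become GEMM columns, and the
# reduction runs over the filter footprint times the input channels.
def implicit_gemm_size_fprop(n, p, q, k, r, s, c):
    m = n * p * q          # GEMM M: number of output elements per filter
    gemm_n = k             # GEMM N: number of filters
    gemm_k = r * s * c     # GEMM K: reduction extent
    return m, gemm_n, gemm_k
```

The split-k reduction then operates over an `m x gemm_n` workspace, which is why only `implicit_gemm_size.m` and `implicit_gemm_size.n` are passed to `ReductionArguments` above.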
    #
    # Helper functions
    #
    @staticmethod
    def output_size(input_size, weight_size, padding, stride, dilation):
        problem_size = cutlass_bindings.conv.Conv2dProblemSize(
            cutlass_bindings.Tensor4DCoord(*input_size),
            cutlass_bindings.Tensor4DCoord(*weight_size),
            cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]),
            cutlass_bindings.MatrixCoord(stride[0], stride[1]),
            cutlass_bindings.MatrixCoord(dilation[0], dilation[1]),
            cutlass_bindings.conv.Mode.cross_correlation,
        problem_size = Conv2DProblemSize(
            *input_size,
            *weight_size,
            padding[0], padding[1],
            stride[0], stride[1],
            dilation[0], dilation[1],
            ConvMode.CrossCorrelation,
            1, 1
        )
        return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)

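The arithmetic behind `output_size` is the standard cross-correlation output formula. A minimal sketch, assuming NHWC input `(n, h, w, c)`, KRSC weight `(k, r, s, c)`, and the symmetric padding used above (the function name is illustrative, not part of the package):

```python
# Sketch of the conv output-size computation: for each spatial dimension,
# output = (input + total_padding - effective_filter) // stride + 1,
# where effective_filter accounts for dilation.
def conv2d_output_size(input_size, weight_size, padding, stride, dilation):
    n, h, w, _ = input_size
    k, r, s, _ = weight_size
    p = (h + 2 * padding[0] - ((r - 1) * dilation[0] + 1)) // stride[0] + 1
    q = (w + 2 * padding[1] - ((s - 1) * dilation[1] + 1)) // stride[1] + 1
    return (n, p, q, k)
```

For example, a 7x7/stride-2 filter with padding 3 over a 224x224 input yields a 112x112 output, matching the familiar ResNet stem.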
#
# Easy to use interfaces for fprop, wgrad, and dgrad
@ -890,23 +887,23 @@ class Conv2d(OperationBase):
class Conv2dFprop(Conv2d):
    def __init__(
        self,
        input=None, weight=None, C=None, output=None, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        A, B, D = input, weight, output
        element_A, element_B, element_D = element_input, element_weight, element_output
        super().__init__(
            "fprop", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False) -> Conv2dArguments:

        A, B, D = input, weight, output
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
@ -915,20 +912,20 @@ class Conv2dFprop(Conv2d):
class Conv2dDgrad(Conv2d):
    def __init__(
        self,
        grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        A, B, D = grad_output, weight, grad_input
        element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
        super().__init__(
            "dgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        #
        A, B, D = grad_output, weight, grad_input
@ -939,20 +936,20 @@ class Conv2dDgrad(Conv2d):
class Conv2dWgrad(Conv2d):
    def __init__(
        self,
        grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        A, B, D = grad_output, input, grad_weight
        element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
        super().__init__(
            "wgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        #
        A, B, D = grad_output, input, grad_weight

@ -116,14 +116,18 @@

from math import prod

import cutlass_bindings

import cutlass
from cutlass import epilogue, swizzle
from cutlass import (
    epilogue,
    swizzle,
    GemmUniversalMode,
)
from cutlass.backend import compiler
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes

@ -245,7 +249,7 @@ class Gemm(OperationBase):
            lay_to_set = lay if lay is not None else layout

            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(datatypes.library_layout(lay_to_set))
            layouts.append(lay_to_set)

        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
@ -265,6 +269,7 @@ class Gemm(OperationBase):

        self.epilogue_functor = None
        self.op_class = None
        self._tile_description = None

        self._reset_operations()

@ -311,6 +316,48 @@ class Gemm(OperationBase):
            raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90')
        self._swizzling_functor = swizzling_functor

    #
    # Tile description Related
    #

    @property
    def tile_description(self) -> TileDescription:
        """
        Returns the tile description
        """
        return self._tile_description

    @tile_description.setter
    def tile_description(
        self, td=None):
        """
        Set the tile description

        :param td: tile description
        :type td: cutlass.backend.TileDescription, or a dict with keys
            {
                "threadblock_shape": [int, int, int],
                "warp_count": [int, int, int],
                "stages": int,
                "instruction_shape": [int, int, int] (optional),
                "cluster_shape": [int, int, int] (optional)
            }
        """
        if td is None:
            return
        if isinstance(td, dict):
            if self._tile_description is None:
                alignment = list(self.possible_operations.kernels_by_alignment.keys())[0]
                op = self.possible_operations.operations(alignment)[0]
                self._tile_description = datatypes.td_from_profiler_op(op)
            td = self._tile_description.clone_and_update(td)

        valid, msg = self._valid_tile_description(td)
        if valid:
            self._tile_description = td
        else:
            raise Exception(msg)

    def _valid_tile_description(self, td: TileDescription) -> tuple:
        """
        Checks whether the provided tile description is valid for the given compute capability. At present,
@ -328,9 +375,7 @@ class Gemm(OperationBase):
        and the second element is a string providing an optional error message.
        :rtype: tuple
        """
        # Check stage count based on the CC to which we are compiling (self.cc), rather
        # than the CC from which we find kernels (self.current_cc)
        valid, msg = check.valid_stage_count(self.cc, td, self._element_c, self._element_d)
        valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
        if not valid:
            return (valid, msg)

@ -378,30 +423,21 @@ class Gemm(OperationBase):

        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        tensor_A = TensorDescription(
            datatypes.binding_type(self._element_a),
            datatypes.binding_layout(self._layout_a),
            alignment_A
        )
        tensor_B = TensorDescription(
            datatypes.binding_type(self._element_b),
            datatypes.binding_layout(self._layout_b),
            alignment_B
        )
        tensor_C = TensorDescription(
            datatypes.binding_type(self._element_c),
            datatypes.binding_layout(self._layout_c),
            alignment_C
        )
        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            op = self.possible_operations.operations(alignment_A)[0]
            tile_description = datatypes.td_from_profiler_op(op)
            if self._tile_description is None:
                op = self.possible_operations.operations(alignment_A)[0]
                tile_description = datatypes.td_from_profiler_op(op)
            else:
                tile_description = self._tile_description
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description
            self._tile_description = tile_description

        operation = GemmOperationUniversal(
            arch=self.current_cc,
@ -473,21 +509,13 @@ class Gemm(OperationBase):
        :return: tuple of batch count dimensions
        :rtype: tuple
        """
        A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple()
        B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple()
        C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple()
        D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple()
        A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
        B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1

        if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]:
            raise Exception(f"Batch count in D must be present in one of operands A, B, and C. "
                            f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}")

        for batch_shape in [A_batch, B_batch, C_batch]:
            if len(batch_shape) > 0 and batch_shape != D_batch:
                raise Exception(f"Batch count for all other operands must either match that of D or be zero."
                                f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.")

        return D_batch
        if 1 not in [A_batch, B_batch]:
            if A_batch != B_batch:
                raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
        return max(A_batch, B_batch)

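The new `_get_batch_count` collapses all leading dimensions into a single integer batch count rather than comparing batch shapes tuple-wise. A standalone sketch of the rule, assuming plain tuple shapes (the function name is illustrative):

```python
from math import prod

# Leading dimensions beyond the last two fold into one batch count;
# A and B must agree unless one of them is unbatched (count 1).
def batch_count(a_shape, b_shape):
    a_batch = prod(a_shape[:-2]) if len(a_shape) > 2 else 1
    b_batch = prod(b_shape[:-2]) if len(b_shape) > 2 else 1
    if 1 not in (a_batch, b_batch) and a_batch != b_batch:
        raise ValueError(f"Invalid batch counts: A={a_batch}, B={b_batch}")
    return max(a_batch, b_batch)
```

So a `(3, 4, 128, 64)` A against an unbatched `(64, 32)` B yields a batch count of 12, while two unbatched operands yield 1.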
    def _get_batch_stride(self, tensor) -> int:
        """
@ -518,38 +546,38 @@ class Gemm(OperationBase):
        :param D: tensor D
        :type D: numpy/cupy/torch array/tensor object

        :return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int)
        :return: tuple containing the problem size (cutlass.shape.GemmCoord), the GEMM mode (cutlass.GemmUniversalMode), and the batch count (int)
        :rtype: tuple
        """
        M, K = A.shape[-2:]
        N = B.shape[-1]
        mode = cutlass_bindings.gemm.Mode.Gemm
        mode = GemmUniversalMode.Gemm

        batch_count = self._get_batch_count(A, B, C, D)
        returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1
        returned_batch_count = batch_count

        # If we are running a batched GEMM in which there is a nonzero batch stride
        # only for A, then we can fold the batched dimension of A into the M dimension
        # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
        # and C are row major. A similar operation can be performed if only B has a nonzero
        # batch dimension
        if len(batch_count) > 0:
        if batch_count > 1:
            A_row = self._layout_a == cutlass.LayoutType.RowMajor
            B_row = self._layout_b == cutlass.LayoutType.RowMajor
            C_row = self._layout_c == cutlass.LayoutType.RowMajor

            batched = lambda x : len(x.shape) == 2 + len(batch_count)
            batched = lambda x : len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count

            if batched(A) and not batched(B) and batched(C) and A_row and C_row:
                M *= prod(batch_count)
                M *= batch_count
                returned_batch_count = 1
            elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row:
                N *= prod(batch_count)
                N *= batch_count
                returned_batch_count = 1
            else:
                mode = cutlass_bindings.gemm.Mode.Batched
                mode = GemmUniversalMode.Batched

        return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count
        return GemmCoord(M, N, K), mode, returned_batch_count

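The folding rule above can be illustrated with a simplified standalone sketch that works on plain tuple shapes and ignores the C operand and layout checks (assumptions: row-major A and C for the fold; the function name and string modes are illustrative, not the CUTLASS API):

```python
from math import prod

# If only A carries the batch dimension, (b, m, k) x (k, n) is rewritten
# as a single (m*b, k) x (k, n) GEMM; otherwise a true batched GEMM runs.
def fold_problem_size(a_shape, b_shape, batch):
    m, k = a_shape[-2:]
    n = b_shape[-1]
    a_batched = len(a_shape) > 2 and prod(a_shape[:-2]) == batch
    b_batched = len(b_shape) > 2 and prod(b_shape[:-2]) == batch
    if batch > 1 and a_batched and not b_batched:
        return (m * batch, n, k), "Gemm", 1       # fold batch into M
    if batch > 1:
        return (m, n, k), "Batched", batch        # keep a batched GEMM
    return (m, n, k), "Gemm", 1
```

The design choice here is that folding avoids the batched-GEMM code path entirely when one large GEMM is mathematically equivalent, which tends to give the scheduler more freedom to balance work.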
    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
@ -570,7 +598,7 @@ class Gemm(OperationBase):
                            f'layout of ({ref_type}, {ref_layout}).')

    def run(self, A=None, B=None, C=None, D=None,
            alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments:
            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@ -612,12 +640,12 @@ class Gemm(OperationBase):

        alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
        alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
        alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
        self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

        if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1:
        if mode == GemmUniversalMode.Gemm or batch_count == 1:
            kwargs = {'split_k_slices': 1}
        else:
            kwargs = {
@ -630,10 +658,15 @@ class Gemm(OperationBase):
            }
        }

        if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
            output_op = self.operation.epilogue_type(visitor_args)
        else:
            output_op = self.operation.epilogue_type(alpha, beta)

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D,
            output_op=self.operation.epilogue_type(alpha, beta),
            output_op=output_op,
            gemm_mode=mode,
            **kwargs
        )

@ -51,19 +51,18 @@
    plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""

import cutlass_bindings

from cutlass import DataTypeSize
from cutlass.backend.gemm_operation import (
    GemmGroupedArguments,
    GemmOperationGrouped,
)
from cutlass.backend.library import (
    DataTypeSize,
    SchedulerMode,
    TensorDescription,
    TileDescription,
)
from cutlass.op.gemm import Gemm
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes

@ -170,21 +169,9 @@ class GroupedGemm(Gemm):

        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        tensor_A = TensorDescription(
            datatypes.binding_type(self._element_a),
            datatypes.binding_layout(self._layout_a),
            alignment_A
        )
        tensor_B = TensorDescription(
            datatypes.binding_type(self._element_b),
            datatypes.binding_layout(self._layout_b),
            alignment_B
        )
        tensor_C = TensorDescription(
            datatypes.binding_type(self._element_c),
            datatypes.binding_layout(self._layout_c),
            alignment_C
        )
        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            op = self.possible_operations.operations(alignment_A)[0]
@ -244,7 +231,7 @@ class GroupedGemm(Gemm):
            Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
            Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
            Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
            problem_sizes.append(cutlass_bindings.gemm.GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
            problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))

        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

@ -38,11 +38,12 @@ from bisect import bisect_left

import cutlass
from cutlass import option_registry, epilogue
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.utils.device import device_cc
from cutlass.epilogue import get_activations
from cutlass.library_defaults import _generator_ccs
from cutlass.library_defaults import KernelsForDataType, _generator_ccs
from cutlass.swizzle import get_swizzling_functors
from cutlass.utils import datatypes
from cutlass.utils import datatypes, check


class OperationBase:
@ -67,7 +68,7 @@ class OperationBase:

        if self.options is None:
            raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")

        # Default activation function: identity
        self._activation = epilogue.identity

@ -120,7 +121,7 @@ class OperationBase:
            raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
        self.current_cc = cc
        self.options = option_registry.options_for_cc(self.current_cc, self.operation_kind)

    def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
        """
        Verifies the following properties:
@ -153,7 +154,7 @@ class OperationBase:
                f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
            )
        return scalar

    def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
        """
        Verifies the following properties:
@ -183,11 +184,11 @@ class OperationBase:

        self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
        return tensor

    #
    # Opcode Related
    #

    @property
    def opclass(self) -> cutlass.OpcodeClass:
        """
@ -197,7 +198,7 @@ class OperationBase:
        :rtype: cutlass.OpcodeClass
        """
        return self.op_class

    @opclass.setter
    def opclass(self, oc: cutlass.OpcodeClass):
        if isinstance(oc, str):
@ -223,11 +224,11 @@ class OperationBase:
        self.possible_operations = self.options.operations(
            self.op_class, self._element_a, self._element_b,
            self._element_accumulator, self._layout_a, self._layout_b)

    #
    # Epilogue
    #

    def _create_epilogue_functor_activation(self, activation):
        """
        Returns the epilogue functor with given activation function
@ -261,44 +262,47 @@ class OperationBase:

        return epilogue.get_activation_epilogue(
            activation,
            datatypes.binding_type(self._element_c),
            self._element_c,
            elements_per_access,
            datatypes.binding_type(self._element_accumulator),
            datatypes.binding_type(self._element_accumulator),
            self._element_accumulator,
            self._element_accumulator,
        )

    def _reset_epilogue_functor_activation(self, activation):
        """
        Set the epilogue functor based on the provided activation function
        """
        self.epilogue_functor = self._create_epilogue_functor_activation(activation)

    def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
        """
        Reset the alignment of the current epilogue functor based on alignment C
        """
        if isinstance(epilogue_functor, EpilogueFunctorVisitor):
            return epilogue_functor

        if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
            # Identity epilogue does not have 'activation_functor'
            activation = epilogue.identity
        else:
            activation = type(epilogue_functor.activation_functor)
            activation = epilogue_functor.activation_functor

        epilogue_functor = epilogue.get_activation_epilogue(
            activation,
            datatypes.binding_type(self._element_c),
            self._element_c,
            alignment,
            datatypes.binding_type(self._element_accumulator),
            datatypes.binding_type(self._element_accumulator),
            self._element_accumulator,
            self._element_accumulator,
        )
        return epilogue_functor

    @property
    def activation(self):
        """
        Returns the type of the current activation function used
        """
        if hasattr(self.epilogue_functor, "activation_functor"):
            return type(self.epilogue_functor.activation_functor)
            return self.epilogue_functor.activation_functor
        else:
            return epilogue.identity

@ -307,7 +311,7 @@ class OperationBase:
        """
        Sets the type of the activation function to use
        Activation can come with a set of arguments

        :param act: type of activation function to use
        :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)

@ -325,4 +329,50 @@ class OperationBase:
            act = getattr(cutlass.backend.epilogue, act)
        self._reset_epilogue_functor_activation(act)
        self._activation = act

    @property
    def epilogue_visitor(self):
        """
        Return the epilogue functor
        """
        return self.epilogue_functor

    @epilogue_visitor.setter
    def epilogue_visitor(self, visitor):
        """
        Create the epilogue visitor
        """
        self.epilogue_functor = EpilogueFunctorVisitor(self.cc, visitor)

        # The epilogue_functor may consume too much shared memory
        # Reset the possible operations
        if self.cc != 90:
            # The shared memory is only a concern for sm90 epilogue
            # In sm80, the epilogue and mainloop share the shared memory
            return
        datatype_comb = self.possible_operations.datatype_comb
        layout_comb = self.possible_operations.layout_comb
        new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
        for operation in self.possible_operations.all_operations:
            td = datatypes.td_from_profiler_op(operation)
            # Filter invalid epilogue schedules
            if td.epilogue_schedule not in [
                cutlass.EpilogueScheduleType.TmaWarpSpecialized,
                cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
                continue
            epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)

            # Verify the maximum number of mainloop stages
            mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
            smem_capacity_bytes = cutlass.SharedMemPerCC[self.cc] << 10
            mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
            if mainloop_stages < 2:
                # Mainloop stages must be >= 2
                continue

            new_possible_operations.add(operation)
        if len(new_possible_operations.all_operations) == 0:
            raise RuntimeError(
                "The epilogue consumes too much shared memory. "
                "No valid tile description is found in the generator.")
        self.possible_operations = new_possible_operations

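The SM90 filtering in the `epilogue_visitor` setter is a shared-memory budget check: once the epilogue visitor reserves its bytes, whatever remains must still fit at least two mainloop pipeline stages. A minimal sketch of that arithmetic with made-up byte counts (the function name and the return-0 convention are illustrative):

```python
# Remaining shared memory after the epilogue's reservation, divided by
# the per-stage mainloop footprint, gives the feasible stage count;
# fewer than two stages means the kernel variant is filtered out.
def max_mainloop_stages(smem_capacity_kb, epilogue_smem_bytes, smem_per_stage_bytes):
    smem_capacity_bytes = smem_capacity_kb << 10
    stages = (smem_capacity_bytes - epilogue_smem_bytes) // smem_per_stage_bytes
    return stages if stages >= 2 else 0  # 0: not enough room, filter out
```

The two-stage floor reflects that a software-pipelined mainloop needs at least one stage loading while another computes; a single stage would serialize the pipeline.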
Some files were not shown because too many files have changed in this diff