CUTLASS 2.10 updates (#622)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
@ -16,7 +16,7 @@ math_inst = MathInstruction(
|
||||
|
||||
tile_description = TileDescription(
|
||||
[128, 128, 8], 4, [2, 4, 1],
|
||||
math_inst, 80, 80
|
||||
math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -31,10 +31,12 @@ C = TensorDescription(
|
||||
cutlass.float32, cutlass.RowMajor, 1
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(cutlass.float32, 1, cutlass.float32, cutlass.float32)
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=cutlass.float32,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -54,7 +56,7 @@ beta = 0.0
|
||||
arguments = GemmArguments(
|
||||
operation=operation, problem_size=problem_size,
|
||||
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op=LinearCombinationFunctorArguments(alpha, beta),
|
||||
output_op=operation.epilogue_type(alpha, beta),
|
||||
gemm_mode=cutlass.gemm.Mode.Gemm, split_k_slices=1
|
||||
)
|
||||
|
||||
@ -68,6 +70,14 @@ assert torch.equal(tensor_D, tensor_D_ref)
|
||||
```
|
||||
PyCUTLASS also provides infrastructure for profiling, compiled artifact management, and a pool memory manager.
|
||||
|
||||
## Supported Features
|
||||
PyCUTLASS currently supports the following operations:
|
||||
* GEMM with mode {Serial, Parallel Split K, Batched GEMM, Array GEMM}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor, Row/ColumnMajorInterleaved<32> for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, swizzling functions {IdentitySwizzle<1,2,4,8>, HorizontalSwizzle, BatchedIdentitySwizzle}, and epilogue {LinearCombination, LinearCombinationClamp}
|
||||
* GEMM grouped with op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, scheduling mode {Host, Device}, and epilogue {LinearCombination, LinearCombinationClamp}.
|
||||
* Conv2d with {Fprop, Dgrad, Wgrad}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {Tensor NHWC, TensorNC32HW32 and TensorC32RSK for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, split-k mode {Parallel, Serial}, and epilogue {LinearCombination, LinearCombinationClamp}
|
||||
|
||||
The tiling size of the above operations can also be customized.
|
||||
|
||||
## Installation
|
||||
|
||||
### Using Docker
|
||||
@ -94,12 +104,19 @@ cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
|
||||
```
|
||||
|
||||
## Examples
|
||||
Examples can be found in `$CUTLASS_PATH/examples/40_cutlass_py`
|
||||
Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutlass_py)
|
||||
|
||||
## Test
|
||||
The test cases are listed in `$CUTLASS_PATH/tools/library/scripts/pycutlass/test`. The unit tests can be run with
|
||||
```shell
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && run_all_example.sh
|
||||
```
|
||||
|
||||
## build documentation
|
||||
Run
|
||||
```shell
|
||||
bash build_doc.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
@ -1,2 +1,4 @@
|
||||
python setup.py develop
|
||||
pip install enum-tools
|
||||
pip install sphinx-toolbox
|
||||
pip install m2r2
|
||||
sphinx-build -b html docs/source/ docs/build/html
|
||||
|
||||
@ -50,7 +50,7 @@
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'PyCutlass'
|
||||
copyright = '2022, Andrew Kerr; Zhaodong Chen; Haicheng Wu; Szymon Migacz; Graham Markall'
|
||||
copyright = '2022, Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall'
|
||||
author = 'Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall'
|
||||
|
||||
|
||||
@ -65,9 +65,12 @@ extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.intersphinx',
|
||||
'enum_tools.autoenum',
|
||||
'sphinx.ext.autosummary'
|
||||
'sphinx.ext.autosummary',
|
||||
'm2r2'
|
||||
]
|
||||
|
||||
source_suffix = [".rst", ".md"]
|
||||
|
||||
autosummary_generate = True
|
||||
autosummary_imported_members = True
|
||||
|
||||
@ -85,7 +88,7 @@ exclude_patterns = []
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = 'classic'
|
||||
html_theme = 'bizstyle'
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
|
||||
@ -1,2 +1,100 @@
|
||||
cutlass
|
||||
=======
|
||||
|
||||
.. rubric:: Operator Classification
|
||||
|
||||
.. autoclass:: cutlass.OpClass
|
||||
:members:
|
||||
|
||||
.. rubric:: GEMM Layout
|
||||
|
||||
.. autoclass:: cutlass.RowMajor
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.ColumnMajor
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.RowMajorInterleaved32
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.ColumnMajorInterleaved32
|
||||
:members:
|
||||
|
||||
.. rubric:: Conv Layout
|
||||
|
||||
.. autoclass:: cutlass.TensorNHWC
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.TensorNC32HW32
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.TensorC32RSK32
|
||||
:members:
|
||||
|
||||
.. rubric:: Threadblock Swizzle
|
||||
|
||||
.. autoclass:: cutlass.dim3
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.IdentitySwizzle1
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.IdentitySwizzle2
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.IdentitySwizzle4
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.IdentitySwizzle8
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.HorizontalSwizzle
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.BatchedIdentitySwizzle
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.StridedDgradIdentitySwizzle1
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.StridedDgradIdentitySwizzle4
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.StridedDgradHorizontalSwizzle
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. rubric:: Coordinates
|
||||
|
||||
.. autoclass:: cutlass.Tensor4DCoord
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.Tensor3DCoord
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.MatrixCoord
|
||||
:special-members:
|
||||
:members:
|
||||
|
||||
|
||||
.. rubric:: Convolution
|
||||
|
||||
.. autoclass:: cutlass.conv.Operator
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.conv.IteratorAlgorithm
|
||||
:members:
|
||||
|
||||
.. autoclass:: cutlass.conv.StrideSupport
|
||||
:members:
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
Descriptions
|
||||
==============
|
||||
|
||||
.. autoclass:: pycutlass.TileDescription
|
||||
:special-members:
|
||||
:members:
|
||||
@ -1,5 +0,0 @@
|
||||
Frontend
|
||||
==============
|
||||
|
||||
.. autoclass:: pycutlass.NumpyFrontend
|
||||
:members:
|
||||
@ -3,27 +3,29 @@
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
Welcome to PyCutlass's documentation!
|
||||
CUTLASS Python Project Documentation
|
||||
=====================================
|
||||
|
||||
.. mdinclude:: ../../README.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Contents:
|
||||
|
||||
|
||||
|
||||
Indices and tables
|
||||
.. Indices and tables
|
||||
.. ==================
|
||||
|
||||
.. * :ref:`genindex`
|
||||
.. * :ref:`modindex`
|
||||
.. * :ref:`search`
|
||||
|
||||
|
||||
Indices
|
||||
==================
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
* :ref:`search`
|
||||
|
||||
|
||||
.. toctree::
|
||||
types
|
||||
cutlass
|
||||
descriptor
|
||||
frontend
|
||||
user_guide
|
||||
visitor_tree
|
||||
gemm_op
|
||||
conv2d_op
|
||||
cutlass
|
||||
|
||||
@ -0,0 +1,225 @@
|
||||
# Epilogue Visitor Tree
|
||||
The Epilogue Visitor Tree is an experimental feature that directly generates epilogues from user-provided Python functions.
|
||||
|
||||
## Usage
|
||||
|
||||
The Epilogue Visitor Tree supports many different operations.
|
||||
|
||||
### Unary functions
|
||||
Epilogue Visitor Tree supports unary functions like activation functions. For example,
|
||||
```python
|
||||
class UnaryEpilogue_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
T = leaky_relu.numpy(accum, 0.2)
|
||||
Z = alpha * T + beta * c
|
||||
return Z
|
||||
epilogue_functor = UnaryEpilogue_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
```
|
||||
|
||||
### Broadcast Operation
|
||||
Epilogue Visitor Tree supports broadcasting row and column vectors to the whole output matrix. To use broadcast, you just need to specify whether the source vector is a `row` vector or a `column` vector. Here is an example.
|
||||
```python
|
||||
class ColumnBroadcast_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
vector: 'column', alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
T = accum + vector
|
||||
scale_T = leaky_relu.numpy(alpha * T, 0.2)
|
||||
Z = scale_T + beta * c
|
||||
return Z, T
|
||||
epilogue_functor = ColumnBroadcast_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
```
|
||||
|
||||
### Reduction Operation
|
||||
|
||||
Epilogue Visitor Tree also supports row and column-wise reduction in each threadblock tile. The syntax for reduction is
|
||||
```python
|
||||
{reduction_output} = reduction_op({input_tensor}, {row|column}, {Add}, {threadblock_shape.n|threadblock_shape.m})
|
||||
```
|
||||
The `{row|column}` indicates whether the `row` vectors are reduced or the `column` vectors are reduced. The `{Add}` specifies the reduction operation. The `{threadblock_shape.n|threadblock_shape.m}` are the reduction lengths.
|
||||
|
||||
**Constraint**
|
||||
* The `{input_tensor}` can only be the name of source or intermediate result. `reduction_op(A + B, ...)` will not work, please use `C = A + B`, `reduction_op(C, ...)` instead.
|
||||
* The `{reduction_output}` cannot be used in the epilogue. It will be directly written to global memory after the reduction is done.
|
||||
```python
|
||||
class RowReduction_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
D = alpha * accum + tanh.numpy(beta * c)
|
||||
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
|
||||
return D, reduction
|
||||
epilogue_functor = RowReduction_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
```
|
||||
|
||||
## Get output_op
|
||||
|
||||
As shown in the user guide, an `output_op` is required by the argument wrapper. We will take the `RowReduction_` as an example to show how to get `output_op`.
|
||||
```python
|
||||
class RowReduction_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
D = alpha * accum + tanh.numpy(beta * c)
|
||||
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
|
||||
return D, reduction
|
||||
epilogue_functor = RowReduction_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
|
||||
cta_n = args.threadblock_shape[1]
|
||||
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
|
||||
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, element_c))
|
||||
# get output op
|
||||
output_op = operation.epilogue_type(
|
||||
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
```
|
||||
Like other epilogue functors such as `LinearCombination`, the output op for EpilogueVisitorTree is also created with `operation.epilogue_type(*)`. However, there are two differences:
|
||||
* The arguments need to be passed as keyword-arguments. The keywords are the argument names in `def __call__`.
|
||||
* An additional `problem_size=[problem_size.m(), problem_size.n()]` is required.
|
||||
|
||||
|
||||
## Add new Unary Operation (e.g. Activation Function)
|
||||
To add additional unary operation into epilogue visitor tree, a new unary op
|
||||
should be created for `VisitorOpUnary`. We will take `tanh` as an example.
|
||||
|
||||
### Step 1: define TanhVisitor
|
||||
|
||||
The visitor defines the parameters and computation required by the unary operation.
|
||||
The unary operations are registered in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h). But you can define it in any header file and include the header file in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h).
|
||||
|
||||
|
||||
* Two template arguments are required:
|
||||
* `T`: data type used to compute the unary operation
|
||||
* `N`: compute fragment length
|
||||
* We also need to provide the `Arguments` and `Params` structures. The `Arguments` will be assembled by [ctypes](https://docs.python.org/3/library/ctypes.html), the `Params` will be generated from `Arguments` automatically. If the unary function takes no argument, an integer like `int tmp` can be provided to ensure the correctness of ctypes.
|
||||
* The constructor can only take the `params` as the single argument.
|
||||
* The operation is defined in `Array<T, N> operator()(Array<T, N> const &frag) const `. One common way to do that is to first define a scalar computation, and then use it for the fragment computation with an unrolled for-loop.
|
||||
* A guard function is required. If it returns `false`, it will disable all the children nodes of the unary node and return zeros to the parent node. This is very helpful for multiplication with a scalar when the scalar is `0`. For general cases, you can just return `true`.
|
||||
```c++
|
||||
// T: data type used to compute the unary operation
|
||||
// N: compute fragment length
|
||||
template <typename T, int N>
|
||||
struct TanhVisitor {
|
||||
/// Argument
|
||||
struct Arguments {
|
||||
// a placeholder argument to ensure correctness of ctypes
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): tmp(0) { };
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { };
|
||||
};
|
||||
|
||||
/// Param
|
||||
struct Params {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(){ };
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TanhVisitor(Params const &params) { }
|
||||
|
||||
// scalar operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
T tanh_op(T const &scalar) const {
|
||||
return fast_tanh(scalar);
|
||||
}
|
||||
|
||||
/// vector operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
Array<T, N> y;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i=0; i < N; ++i) {
|
||||
y[i] = tanh_op(frag[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
// Guard
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Step 2: register Tanh function
|
||||
After defining the function in C++, we need to register it in python. The class below gives an example.
|
||||
* The init function takes the data type `element_compute`, which will be the `T` in the C++ template.
|
||||
In the init function, we also generate the `_Arguments` class as a `ctypes.Structure`. It includes all the data members in the `TanhVisitor::Arguments`.
|
||||
* The `_Arguments` need to be registered as `self.argument_type` of `tanh` class.
|
||||
* An `emit` function is required to emit the namespace and typename of `TanhVisitor`.
|
||||
* A staticmethod as numpy reference is required to implement the python code to parse.
|
||||
|
||||
The built-in functions are defined in [pycutlass/src/pycutlass/epilogue.py](tools/library/scripts/pycutlass/src/pycutlass/epilogue.py). You can define yours in any file as long as it can be found by [/pycutlass/src/pycutlass/parser.py](tools/library/scripts/pycutlass/src/pycutlass/parser.py).
|
||||
```python
|
||||
class tanh(ActivationFunctor):
|
||||
def __init__(self, element_compute) -> None:
|
||||
super().__init__()
|
||||
class _Arguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("tmp", ctypes.c_int)
|
||||
]
|
||||
def __init__(self, *args) -> None:
|
||||
self.tmp = 0
|
||||
self.argument_type = _Arguments
|
||||
|
||||
def emit(self):
|
||||
return "cutlass::TanhVisitor"
|
||||
|
||||
@staticmethod
|
||||
def numpy(x: np.ndarray):
|
||||
return np.tanh(x)
|
||||
```
|
||||
|
||||
### Step 3: Run the function
|
||||
Now the new unary op is ready to use. An epilogue visitor tree can be built with
|
||||
```python
|
||||
class RowReduction_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: NDArray['tensor', 'float32'], c: NDArray['tensor', 'float32'],
|
||||
alpha: 'float32', beta: 'float32'):
|
||||
#
|
||||
D = alpha * accum + tanh.numpy(beta * c)
|
||||
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
|
||||
return D, reduction
|
||||
epilogue_functor = RowReduction_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
```
|
||||
|
||||
## Limitations and Future work
|
||||
|
||||
Although the Epilogue Visitor Tree brings great flexibility to epilogue construction, as the epilogue is formulated as a single tree, there are several limitations.
|
||||
* [Future Work] Serial and Parallel Split-K GEMM are not supported yet.
|
||||
* To support serial split-k, an additional tree transformation pass is required to inject a `binaryOpNode(Add)` + `TensorInputNode` before each `TensorOutputNode` to fetch the partial sum back. The `semaphore` also needs to be passed into the epilogue.
|
||||
* To support parallel split-k, a reduction-with-visitor kernel is required.
|
||||
* [Future Work] Convolution and GEMM Grouped are not supported yet.
|
||||
* To support Conv2d and GEMM Grouped, corresponding *_with_visitor kernels are required.
|
||||
|
||||
* [Limitation] If the same node is used by two operations (except that one of them is reduction), the node and all its offsprings will be executed twice.
|
||||
* [Limitation] The result of reduction can only be used as the return value.
|
||||
283
tools/library/scripts/pycutlass/docs/source/md/basic_idea.md
Normal file
283
tools/library/scripts/pycutlass/docs/source/md/basic_idea.md
Normal file
@ -0,0 +1,283 @@
|
||||
# Basics of PyCUTLASS
|
||||
|
||||
PyCUTLASS handles the following things when launching CUTLASS kernels:
|
||||
* Memory management
|
||||
* Operation Description
|
||||
* Code emission and compilation
|
||||
* Arguments preprocessing
|
||||
* Kernel launching
|
||||
* Result Synchronization
|
||||
|
||||
## Memory management
|
||||
|
||||
PyCUTLASS uses [RMM](https://github.com/rapidsai/rmm) to manage device memory. At the beginning of the program, call
|
||||
```python
|
||||
pycutlass.get_memory_pool({init_pool_size_in_bytes}, {max_pool_size_in_bytes})
|
||||
```
|
||||
We also provide functions to query the allocated size.
|
||||
```python
|
||||
bytes = get_allocated_size()
|
||||
```
|
||||
|
||||
|
||||
## Operation Description
|
||||
PyCUTLASS provides operation descriptions for GEMM, GEMM Grouped and Conv2d operations. These operation descriptions are assembled from four fundamental concepts
|
||||
* Math Instruction: math instruction executed in GPU cores
|
||||
* Tile Description: tiling sizes and pipeline stages
|
||||
* Operand Description: data type, layout, memory alignment
|
||||
* Epilogue Functor: epilogue function
|
||||
|
||||
### Math Instruction
|
||||
|
||||
The math instruction is defined as follows:
|
||||
```python
|
||||
math_inst = MathInstruction(
|
||||
{instruction_shape}, {element_a}, {element_b},
|
||||
{element_acc}, {opclass}, {math_operation}
|
||||
)
|
||||
```
|
||||
The `{instruction_shape}` and `{opclass}` define the instruction size and type. The table below lists valid combinations. `{element_a}`, `{element_b}` define the source operand data type for each instruction, and `{element_acc}` defines the accumulator type. The `{math_operation}` defines the math operation applied.
|
||||
|
||||
|Opclass | element_a/element_b | element_acc | instruction_shape | math_operation |
|
||||
| -- | -- | -- | -- | -- |
|
||||
| cutlass.OpClass.TensorOp | cutlass.float64 | cutlass.float64 | [8, 8, 4] | MathOperation.multiply_add|
|
||||
| | cutlass.float32 cutlass.tfloat32, cutlass.float16 cutlass.bfloat16 | cutlass.float32 | [16, 8, 8] | MathOperation.multiply_add MathOperation.multiply_add_fast_f32 MathOperation.multiply_add_fast_f16 MathOperation.multiply_add_fast_bf16 |
|
||||
| | cutlass.float16 | cutlass.float16/cutlass.float32|[16, 8, 16]| MathOperation.multiply_add |
|
||||
| | cutlass.bfloat16 | cutlass.float32 | [16, 8, 16]|MathOperation.multiply_add |
|
||||
| | cutlass.int8 | cutlass.int32 | [16, 8, 32] | MathOperation.multiply_add_saturate|
|
||||
|cutlass.OpClass.Simt| cutlass.float64 | cutlass.float64 | [1, 1, 1] | MathOperation.multiply_add |
|
||||
| | cutlass.float32 | cutlass.float32 | [1, 1, 1] | MathOperation.multiply_add |
|
||||
|
||||
The `cutlass.OpClass.TensorOp` indicates that the tensor core is used, while `cutlass.OpClass.Simt` uses the SIMT Core.
|
||||
|
||||
The `multiply_add_fast_f32` emulates fast accurate SGEMM kernel which is accelerated
|
||||
using Ampere Tensor Cores. More details can be found in [examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm](examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm).
|
||||
|
||||
### Tile Description
|
||||
The tile description describes the threadblock and warp tiling sizes, as well as the pipeline stages.
|
||||
```python
|
||||
tile_description = TileDescription(
|
||||
{threadblock_shape}, {stages}, {warp_count},
|
||||
math_inst
|
||||
)
|
||||
```
|
||||
The `{threadblock_shape}` is a list of 3 integers `[Tile_M, Tile_N, Tile_K]` that defines the threadblock tiling size. `{stages}` defines the number of software pipeline stages ([detail](https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/)). `{warp_count}` defines the number of warps along `M`, `N`, and `K` dimension. I.e., with `{threadblock_shape}=[Tile_M, Tile_N, Tile_K]` and `{warp_count}=[W_M, W_N, W_K]`, the warp tile size would be `[Tile_M / W_M, Tile_N / W_N, Tile_K / W_K]`.
|
||||
|
||||
### Operand Description
|
||||
The Operand Description defines the data type, layout, and memory alignment of input tensor A, B, and C. The output D shares the same attributes with C. The description is as follows:
|
||||
```python
|
||||
A = TensorDescription(
|
||||
{element_a}, {layout_a}, {alignment_a}
|
||||
)
|
||||
|
||||
B = TensorDescription(
|
||||
{element_b}, {layout_b}, {alignment_b}
|
||||
)
|
||||
|
||||
C = TensorDescription(
|
||||
{element_c}, {layout_c}, {alignment_c}
|
||||
)
|
||||
```
|
||||
The table below lists the supported layout and data types for each operation
|
||||
| Operation | data type | layout |
|
||||
| -- | -- | -- |
|
||||
| GEMM, GEMM Grouped | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.RowMajor, cutlass.ColumnMajor |
|
||||
| | cutlass.int8 | cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32|
|
||||
| Conv2d Fprop, Dgrad, Wgrad | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.TensorNHWC |
|
||||
| Conv2d Fprop | cutlass.int8 | cutlass.TensorNHWC, cutlass.TensorNC32HW32, cutlass.TensorC32RSK32|
|
||||
|
||||
### Epilogue Functor
|
||||
The epilogue functor defines the epilogue executed after mainloop.
|
||||
We expose the following epilogue functors.
|
||||
| Epilogue Functor | Remark |
|
||||
| -- | -- |
|
||||
| LinearCombination | $D=\alpha \times Accm + \beta \times C$ |
|
||||
| LinearCombinationClamp | $D=\alpha \times Accm + \beta \times C$, Output is clamped to the maximum value of the data type output |
|
||||
| FastLinearCombinationClamp | $D=\alpha \times Accm + \beta \times C$, only used for problem size $K\le 256$ for cutlass.int8, with accumulator data type `cutlass.int32` and epilogue compute data type `cutlass.float32` |
|
||||
| LinearCombinationGeneric | $D = activation(\alpha \times Accm + \beta \times C)$, available activations include `relu`, `leaky_relu`, `tanh`, `sigmoid`, `silu`, `hardswish`, and `gelu` |
|
||||
|
||||
The epilogue functors can be created as follows
|
||||
```python
|
||||
# LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
element_C, alignment_c, element_acc, element_epilogue_compute
|
||||
)
|
||||
|
||||
# LinearCombinationClamp
|
||||
epilogue_functor = LinearCombinationClamp(
|
||||
element_C, alignment_c, element_acc, element_epilogue_compute
|
||||
)
|
||||
|
||||
# FastLinearCombinationClamp
|
||||
epilogue_functor = FastLinearCombinationClamp(
|
||||
element_C, alignment_c
|
||||
)
|
||||
|
||||
# LinearCombinationGeneric
|
||||
epilogue_functor = LinearCombinationGeneric(
|
||||
relu(element_epilogue_compute), element_C, alignment_c,
|
||||
element_acc, element_epilogue_compute
|
||||
)
|
||||
```
|
||||
|
||||
We also provide an experimental feature "Epilogue Visitor Tree" for GEMM operation. The details can be found in [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md).
|
||||
|
||||
|
||||
### GEMM Operation
|
||||
|
||||
The GEMM Operation description can be created with
|
||||
```python
|
||||
operation = GemmOperationUniversal(
|
||||
{compute_capability}, tile_description,
|
||||
A, B, C, epilogue_functor,
|
||||
{swizzling_functor}, {visitor}
|
||||
)
|
||||
```
|
||||
* `{compute_capability}` is an integer that indicates the compute capability of the GPU. For A100, it is 80.
|
||||
* `{swizzling_functor}` describes how threadblocks are scheduled on GPU. This is used to improve the L2 Locality ([detail](https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/)). Currently we support `cutlass.{IdentitySwizzle1|IdentitySwizzle2|IdentitySwizzle4|IdentitySwizzle8|BatchedIdentitySwizzle}`. The last one is used for batched or array GEMM.
|
||||
* `{visitor}`: a bool variable that indicates whether the epilogue visitor tree is used.
|
||||
|
||||
### GEMM Grouped Operation
|
||||
The GEMM Grouped Operation description can be created with
|
||||
```python
|
||||
operation = GemmOperationGrouped(
|
||||
compute_capability, tile_description,
|
||||
A, B, C, epilogue_functor,
|
||||
swizzling_functor, {precompute_mode}
|
||||
)
|
||||
```
|
||||
* `{precompute_mode}`: It could be `SchedulerMode.Host` or `SchedulerMode.Device`. See [examples/24_gemm_grouped](examples/24_gemm_grouped) for more details.
|
||||
|
||||
|
||||
### Conv2d Operation
|
||||
The Conv2d Operation description can be created with
|
||||
```python
|
||||
operation = Conv2dOperation(
|
||||
{conv_kind}, {iterator_algorithm},
|
||||
compute_capability, tile_description,
|
||||
A, B, C, {stride_support},
|
||||
epilogue_functor, swizzling_functor
|
||||
)
|
||||
```
|
||||
* `{conv_kind}` defines which convolution is executed. Available options include `fprop`, `dgrad`, and `wgrad`.
|
||||
* `{iterator_algorithm}` specifies the iterator algorithm used by the implicit GEMM in convolution. The options are as follows:
|
||||
* `analytic`: functionally correct in all cases but lower performance
|
||||
* `optimized`: optimized for R <= 32, S <= 32 and unity-stride dgrad
|
||||
* `fixed_channels`: analytic algorithm optimized for fixed channel count (C == AccessSize)
|
||||
* `few_channels`: Analytic algorithm optimized for few channels (C divisible by AccessSize)
|
||||
* `{stride_support}`: distinguishes among partial specializations that accelerate certain problems where convolution
|
||||
stride is unit.
|
||||
* `strided`: arbitrary convolution stride
|
||||
* `unity`: unit convolution stride
|
||||
|
||||
***
|
||||
## Code Emission and Compilation
|
||||
After implementing the operation description, the related host and device code can be compiled with
|
||||
```python
|
||||
import pycutlass
|
||||
|
||||
pycutlass.compiler.add_module([operation,])
|
||||
```
|
||||
Several operations can be compiled together. The `nvcc` at `$CUDA_INSTALL_PATH/bin` is used by default as the compiler backend. But you can also switch to [CUDA Python](https://nvidia.github.io/cuda-python/overview.html)'s `nvrtc` with
|
||||
```python
|
||||
pycutlass.compiler.nvrtc()
|
||||
```
|
||||
We also have an internal compiled artifact manager that caches the compiled kernel in both memory and disk. The `compiled_cache.db` at your workspace is the database that contains the binary files. You can delete the file if you want to recompile the kernels.
|
||||
***
|
||||
## Argument Processing
|
||||
We provide argument wrappers that convert Python tensors to kernel parameters. Currently they support [torch.Tensor](https://pytorch.org/), [numpy.ndarray](https://numpy.org/), and [cupy.ndarray](https://cupy.dev/).
|
||||
### GEMM Arguments
|
||||
The Gemm arguments can be created with
|
||||
```python
|
||||
arguments = GemmArguments(
|
||||
operation=operation, problem_size={problem_size},
|
||||
A={tensor_A}, B={tensor_B}, C={tensor_C}, D={tensor_D},
|
||||
output_op={output_op},
|
||||
gemm_mode={gemm_mode},
|
||||
split_k_slices={split_k_slices}, batch={batch}
|
||||
)
|
||||
```
|
||||
* `problem_size` is a `cutlass.gemm.GemmCoord(M, N, K)` object that defines $M\times N\times K$ matrix multiplication.
|
||||
* `tensor_X`: user-provided tensors.
|
||||
* `output_op`: the params for the epilogue functor.
|
||||
* `gemm_mode`, `split_k_slices`, and `batch`:
|
||||
|
||||
|gemm_mode| split_k_slices | batch | remark|
|
||||
|--|--|--|--|
|
||||
|cutlass.gemm.Mode.Gemm | number of split-K slices | - | the ordinary GEMM or GEMM with serial split-K|
|
||||
|cutlass.gemm.Mode.GemmSplitKParallel | number of split-K slices | - | GEMM Split-K Parallel|
|
||||
|cutlass.gemm.Mode.Batched | - | batch size | Batched GEMM |
|
||||
|cutlass.gemm.Mode.Array | - | batch size | Array GEMM |
|
||||
|
||||
### GEMM Grouped Arguments
|
||||
The GEMM grouped arguments can be created with
|
||||
```python
|
||||
arguments = GemmGroupedArguments(
|
||||
operation, {problem_sizes_coord}, {tensor_As}, {tensor_Bs}, {tensor_Cs}, {tensor_Ds},
|
||||
output_op=output_op
|
||||
)
|
||||
```
|
||||
* `problem_sizes_coord` is a list of `cutlass.gemm.GemmCoord(M, N, K)`, one for each problem.
|
||||
* `tensor_Xs` is a list of user-provided tensors.
|
||||
* `output_op`: the params of the epilogue functor
|
||||
|
||||
### Conv2d Arguments
|
||||
The Conv2d arguments can be created with
|
||||
```python
|
||||
arguments = Conv2dArguments(
|
||||
operation, {problem_size}, {tensor_A},
|
||||
{tensor_B}, {tensor_C}, {tensor_D},
|
||||
{output_op},
|
||||
{split_k_mode},
|
||||
{split_k_slices}
|
||||
)
|
||||
```
|
||||
* `problem_size`: it can be constructed with
|
||||
```python
|
||||
problem_size = cutlass.conv.Conv2dProblemSize(
|
||||
cutlass.Tensor4DCoord(N, H, W, C),
|
||||
cutlass.Tensor4DCoord(K, R, S, C),
|
||||
cutlass.Tensor4DCoord(pad[0], pad[1], pad[2], pad[3]),
|
||||
cutlass.MatrixCoord(stride[0], stride[1]),
|
||||
cutlass.MatrixCoord(dilation[0], dilation[1]),
|
||||
cutlass.conv.Mode.cross_correlation,
|
||||
split_k_slices, 1
|
||||
)
|
||||
```
|
||||
* `tensor_X` are user-provided tensors
|
||||
* `output_op`: the params of the epilogue functor
|
||||
* `split_k_mode`: currently we support `cutlass.conv.SplitKMode.Serial` and `cutlass.conv.SplitKMode.Parallel`.
|
||||
* `split_k_slices`: number of split-k slices
|
||||
|
||||
For ordinary conv2d, just use `cutlass.conv.SplitKMode.Serial` with `split_k_slices=1`.
|
||||
|
||||
### Getting output_op
|
||||
The way to create output_op is listed below
|
||||
```python
|
||||
output_op = operation.epilogue_type(*([alpha, beta] + args.activation_args))
|
||||
```
|
||||
The arguments are passed as a list that starts with the scaling factors `alpha` and `beta`.
|
||||
The `output_op` of EpilogueVisitorTree is slightly different. Please check [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md) for details.
|
||||
|
||||
|
||||
## Kernel Launching
|
||||
|
||||
With the arguments and operations, the kernel can be launched simply with
|
||||
```python
|
||||
operation.run(arguments)
|
||||
```
|
||||
|
||||
## Sync results
|
||||
|
||||
We also provide a function to synchronize the kernel execution. If you use `numpy`, it will also copy the result back to the host. To do that, run
|
||||
```python
|
||||
arguments.sync()
|
||||
```
|
||||
If you use EpilogueVisitorTree, please call
|
||||
```python
|
||||
output_op.sync()
|
||||
```
|
||||
|
||||
## Reduction Kernel behind Parallel Split-K
|
||||
|
||||
If you use parallel-split-K in GEMM or Conv2d, an additional reduction kernel is required. Please check [examples/40_cutlass_py](examples/40_cutlass_py) for details.
|
||||
@ -1,6 +0,0 @@
|
||||
Types
|
||||
========
|
||||
|
||||
|
||||
.. autoenum:: pycutlass.OperationKind
|
||||
:members:
|
||||
@ -0,0 +1,4 @@
|
||||
User Guide
|
||||
=====================================
|
||||
|
||||
.. mdinclude:: ./md/basic_idea.md
|
||||
@ -0,0 +1,4 @@
|
||||
User Guide
|
||||
=====================================
|
||||
|
||||
.. mdinclude:: ./md/EpilogueVisitorTree.md
|
||||
@ -32,6 +32,7 @@
|
||||
|
||||
from pycutlass import *
|
||||
import pycutlass
|
||||
from pycutlass.epilogue import LinearCombination
|
||||
from pycutlass.test.conv2d_testbed import Conv2dLauncher
|
||||
|
||||
|
||||
@ -62,15 +63,16 @@ if __name__ == "__main__":
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=4,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ if __name__ == '__main__':
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[256, 128, 32],
|
||||
stages=3, warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -67,7 +67,7 @@ if __name__ == '__main__':
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ try:
|
||||
Pybind11Extension("cutlass",
|
||||
["src/cpp/cutlass.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=["-fpermissive"])
|
||||
extra_compile_args=["-fpermissive", "-w"])
|
||||
]
|
||||
except ImportError:
|
||||
pass
|
||||
@ -69,7 +69,8 @@ setup(
|
||||
'typeguard',
|
||||
'bfloat16',
|
||||
'typing',
|
||||
'scikit-build'
|
||||
'scikit-build',
|
||||
'treelib'
|
||||
],
|
||||
cmdclass={
|
||||
'rmm': BuildRMM
|
||||
|
||||
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the generic epilogue visitor
|
||||
|
||||
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
||||
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
|
||||
|
||||
#include "epilogue_visitor_op/visitor_op_linear_combination.h"
|
||||
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
|
||||
#include "epilogue_visitor_op/visitor_op_accumulator.h"
|
||||
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
|
||||
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
|
||||
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
|
||||
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
|
||||
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
|
||||
#include "epilogue_visitor_op/visitor_op_unary.h"
|
||||
#include "epilogue_visitor_op/visitor_op_binary.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Generic Epilogue Visitor.
|
||||
template <
|
||||
typename OutputOp_
|
||||
>
|
||||
class EpilogueVisitorGeneric {
|
||||
public:
|
||||
|
||||
using OutputOp = OutputOp_;
|
||||
using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
|
||||
static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
|
||||
using ElementOutput = typename OutputOp::ElementOutput;
|
||||
using OutputTileIterator = typename OutputOp::OutputTileIterator;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
|
||||
///
|
||||
/// End Epilogue Tree
|
||||
///
|
||||
|
||||
/// Shared storage of the visitor: it only wraps the SharedStorage of OutputOp; no additional SMEM buffer is declared here
|
||||
struct SharedStorage {
|
||||
|
||||
typename OutputOp::SharedStorage output_smem;
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
typename OutputOp::Arguments output_op_args;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments() { }
|
||||
|
||||
Arguments(
|
||||
typename OutputOp::Arguments output_op_args
|
||||
):
|
||||
output_op_args(output_op_args)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
typename OutputOp::Params output_op_params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
output_op_params(args.output_op_args)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
OutputOp output_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorGeneric(
|
||||
Params const ¶ms, ///< Parameters routed to the epilogue
|
||||
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
|
||||
MatrixCoord threadblock_offset,
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
output_op.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
output_op.begin_epilogue();
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
output_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
output_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum) {
|
||||
output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
}
|
||||
|
||||
/// Called at the end of a row; forwards row_idx to output_op.end_row()
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
output_op.end_row(row_idx);
|
||||
|
||||
}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
output_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
output_op.end_epilogue();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,84 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the binary ops
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int N>
|
||||
struct VectorAdd {
|
||||
|
||||
struct Arguments {
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():tmp(0){ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VectorAdd(
|
||||
Params const ¶ms
|
||||
) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
||||
cutlass::plus<Array<T, N>> add_op;
|
||||
return add_op(lhs, rhs);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the unary ops
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int N>
|
||||
struct Mult {
|
||||
|
||||
struct Arguments {
|
||||
T alpha;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():alpha(T(1.0)){ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(T alpha): alpha(alpha) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
T alpha; ///< scales accumulators
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():alpha(T(1.0)){ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args): alpha(args.alpha) { }
|
||||
};
|
||||
|
||||
T alpha_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Mult(
|
||||
Params const ¶ms
|
||||
):
|
||||
alpha_(params.alpha)
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &source) const {
|
||||
cutlass::multiplies<Array<T, N>> multiply_op;
|
||||
return multiply_op(source, alpha_);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return alpha_ != T(0);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/// ReLU
|
||||
template <typename T, int N>
|
||||
struct ReLUVisitor {
|
||||
struct Arguments {
|
||||
T threshold;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():threshold(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(T threshold): threshold(threshold) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
T threshold;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():threshold(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args): threshold(args.threshold) { }
|
||||
};
|
||||
|
||||
T threshold_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ReLUVisitor(Params const ¶ms):
|
||||
threshold_(params.threshold) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
maximum<Array<T, N>> mx;
|
||||
return mx(frag, threshold_);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/// leakyReLU
|
||||
template <typename T, int N>
|
||||
struct LeakyReLUVisitor {
|
||||
struct Arguments {
|
||||
T leaky_alpha;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():leaky_alpha(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
|
||||
};
|
||||
|
||||
struct Params {
|
||||
T leaky_alpha;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():leaky_alpha(T(0.0)) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
|
||||
};
|
||||
|
||||
T leaky_alpha_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
LeakyReLUVisitor(Params const ¶ms):
|
||||
leaky_alpha_(params.leaky_alpha) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
|
||||
return leaky_op(frag, leaky_alpha_);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Tanh
|
||||
template <typename T, int N>
|
||||
struct TanhVisitor {
|
||||
/// Argument
|
||||
struct Arguments {
|
||||
// a placeholder argument to ensure correctness of ctypes
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): tmp(0) { };
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { };
|
||||
};
|
||||
|
||||
/// Param
|
||||
struct Params {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(){ };
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TanhVisitor(Params const ¶ms) { }
|
||||
|
||||
// scalar operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
T tanh_op(T const &scalar) const {
|
||||
return fast_tanh(scalar);
|
||||
}
|
||||
|
||||
/// vector operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &frag) const {
|
||||
Array<T, N> y;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i=0; i < N; ++i) {
|
||||
y[i] = tanh_op(frag[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool guard() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,148 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with accumulator
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following Computation
|
||||
///
|
||||
/// ElementAccumulator accum;
|
||||
/// return accum;
|
||||
///
|
||||
/// It can only be the leaf node of the epilogue tree
|
||||
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
int kElementsPerAccess_ ///< Number of elements computed per operation
|
||||
>
|
||||
class VisitorOpAccumulator{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
/// Fragment type for Accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = AccumulatorAccessType;
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
// Note: it is strange that ctypes will return issue with empty arguments
|
||||
int tmp;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int tmp): tmp(tmp) { }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args) { }
|
||||
};
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpAccumulator(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
return accum;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,246 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Binary op
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "binary_ops.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B)
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename VisitorA_, ///< Child node A
|
||||
typename VisitorB_, ///< Child node B
|
||||
template<typename T, int N> typename BinaryOp_
|
||||
>
|
||||
class VisitorOpBinary{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using VisitorA = VisitorA_;
|
||||
using VisitorB = VisitorB_;
|
||||
|
||||
/// Fragment type returned from VisitorA.visit
|
||||
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
|
||||
using ElementA = typename VisitAccessTypeA::Element;
|
||||
|
||||
/// Fragment type returned from VisitorB.visit
|
||||
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
|
||||
using ElementB = typename VisitAccessTypeB::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Combination Op TODO: generalize this
|
||||
using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename VisitorA::SharedStorage storage_a;
|
||||
typename VisitorB::SharedStorage storage_b;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
typename BinaryOp::Arguments binary_arg;
|
||||
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
|
||||
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():binary_arg() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
typename BinaryOp::Arguments binary_arg,
|
||||
typename VisitorA::Arguments visitor_a_arg,
|
||||
typename VisitorB::Arguments visitor_b_arg
|
||||
):
|
||||
binary_arg(binary_arg),
|
||||
visitor_a_arg(visitor_a_arg),
|
||||
visitor_b_arg(visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
typename BinaryOp::Params binary_param;
|
||||
typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a
|
||||
typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
binary_param(args.binary_arg),
|
||||
visitor_a_param(args.visitor_a_arg),
|
||||
visitor_b_param(args.visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
BinaryOp binary_op;
|
||||
|
||||
VisitorA visitor_a_op;
|
||||
VisitorB visitor_b_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpBinary(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
binary_op(params.binary_param),
|
||||
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
|
||||
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_a_op.begin_epilogue();
|
||||
visitor_b_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
visitor_a_op.set_batch_index(batch_idx);
|
||||
visitor_b_op.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
visitor_a_op.begin_step(step_idx);
|
||||
visitor_b_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_a_op.begin_row(row_idx);
|
||||
visitor_b_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor A and visitor B
|
||||
VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
/// Type conversion
|
||||
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
|
||||
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
|
||||
|
||||
return binary_op(
|
||||
source_converter_A(result_A),
|
||||
source_converter_B(result_B)
|
||||
);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_a_op.end_row(row_idx);
|
||||
visitor_b_op.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_a_op.end_step(step_idx);
|
||||
visitor_b_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_a_op.end_epilogue();
|
||||
visitor_b_op.end_epilogue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with broadcasting vector to all columns
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementVector T[i][j] <- device-memory Td[i]
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementFragment_, ///< Data type used to cache vector in register
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
|
||||
>
|
||||
class VisitorOpColumnBroadcast {
|
||||
public:
|
||||
using InputTileIterator = InputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementVector = typename InputTileIterator::Element;
|
||||
using ElementFragment = ElementFragment_;
|
||||
|
||||
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by input tile iterators
|
||||
using ThreadMap = typename InputTileIterator::ThreadMap;
|
||||
|
||||
/// Fragment object used to store the broadcast values
|
||||
using BroadcastFragment = Array<
|
||||
ElementFragment, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Used for the broadcast
|
||||
struct BroadcastDetail {
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
// /// Number of iterations (accesses) the threadblock takes to reduce a row
|
||||
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
||||
};
|
||||
|
||||
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
||||
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementVector *broadcast_ptr,
|
||||
int64_t batch_stride
|
||||
):
|
||||
broadcast_ptr(broadcast_ptr),
|
||||
batch_stride(batch_stride) { }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
broadcast_ptr(args.broadcast_ptr),
|
||||
batch_stride(args.batch_stride) { }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementVector *broadcast_ptr;
|
||||
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
|
||||
MatrixCoord threadblock_offset_;
|
||||
int thread_idx_;
|
||||
MatrixCoord problem_size;
|
||||
|
||||
int thread_start_row_;
|
||||
int state_[3];
|
||||
int thread_offset_row_;
|
||||
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpColumnBroadcast(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
broadcast_ptr(params.broadcast_ptr),
|
||||
threadblock_offset_(threadblock_offset),
|
||||
thread_idx_(thread_idx),
|
||||
problem_size(problem_size),
|
||||
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
|
||||
batch_stride_(params.batch_stride)
|
||||
{
|
||||
state_[0] = state_[1] = state_[2] = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
broadcast_ptr += batch_idx * batch_stride_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
// get pointer
|
||||
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
|
||||
|
||||
ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));
|
||||
|
||||
broadcast_fragment.fill(broadcast_data);
|
||||
|
||||
return broadcast_fragment;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
// run operator ++
|
||||
++state_[0];
|
||||
|
||||
thread_start_row_ += ThreadMap::Shape::kRow;
|
||||
if (state_[0] == ThreadMap::Count::kRow) {
|
||||
state_[0] = 0;
|
||||
++state_[1];
|
||||
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
|
||||
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
|
||||
|
||||
if (state_[1] == ThreadMap::Count::kGroup) {
|
||||
state_[1] = 0;
|
||||
++state_[2];
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
|
||||
|
||||
if (state_[2] == ThreadMap::Count::kCluster) {
|
||||
state_[2] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,342 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with reduction over columns in CTA
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
|
||||
/// device memory <- ElementReduction(R[j])
|
||||
///
|
||||
template <
|
||||
typename ThreadblockShape_, /// Threadblock shape
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementReduction_, ///< Data type of the output reduction in device memory
|
||||
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
|
||||
typename OutputTileIterator_, ///< Tile Iterator type
|
||||
typename Visitor_ ///< preceeding visitor op
|
||||
>
|
||||
class VisitorOpColumnReduction {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementReductionAccumulator = ElementReductionAccumulator_;
|
||||
using ElementReduction = ElementReduction_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using Visitor = Visitor_;
|
||||
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
// TODO: generalize the reduction op
|
||||
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
|
||||
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
|
||||
|
||||
/// Fragment type returned from Visitor
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
using VisitAccessType = VisitAccessTypeVisitor;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of redcution
|
||||
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by output tile iterators
|
||||
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
||||
/// Used for the reduction
|
||||
struct ReductionDetail {
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
/// Number of iterations (accesses) the threadblock takes to reduce a row
|
||||
static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);
|
||||
|
||||
using StorageShape = MatrixShape<
|
||||
kThreadRows,
|
||||
ThreadblockShape::kN
|
||||
>;
|
||||
};
|
||||
|
||||
using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementReduction *reduction_ptr,
|
||||
int64_t batch_stride,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
reduction_ptr(reduction_ptr),
|
||||
batch_stride(batch_stride),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Params visitor_param; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
reduction_ptr(args.reduction_ptr),
|
||||
batch_stride(args.batch_stride),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
|
||||
ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory
|
||||
ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction
|
||||
Visitor visitor_; ///< visitor
|
||||
int thread_idx_;
|
||||
MatrixCoord threadblock_offset;
|
||||
MatrixCoord problem_size_;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpColumnReduction(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
visitor_(params.visitor_param, shared_storage.storage_visitor,
|
||||
thread_idx, threadblock_offset, problem_size),
|
||||
reduction_smem_ptr_(shared_storage.reduction.data()),
|
||||
reduction_output_ptr_(params.reduction_ptr),
|
||||
thread_idx_(thread_idx),
|
||||
threadblock_offset(threadblock_offset),
|
||||
problem_size_(problem_size),
|
||||
batch_stride_(params.batch_stride)
|
||||
{ }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
reduction_output_ptr_ += batch_idx * batch_stride_;
|
||||
visitor_.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_.begin_epilogue();
|
||||
|
||||
// clear the reduction fragment
|
||||
reduction_fragment.clear();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
visitor_.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor
|
||||
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
|
||||
ReductionOp reduction_op;
|
||||
ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast<ReductionAccumulatorAccessType*>(&reduction_fragment);
|
||||
reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_.end_epilogue();
|
||||
//
|
||||
// Store the partially reduced value to SMEM
|
||||
//
|
||||
|
||||
// Guard against uses of the existing SMEM tile
|
||||
__syncthreads();
|
||||
|
||||
using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;
|
||||
|
||||
//
|
||||
// Determine a compact thread arrangement to store to SMEM
|
||||
//
|
||||
|
||||
MatrixCoord thread_offset(
|
||||
thread_idx_ / ReductionDetail::kThreadsPerRow,
|
||||
(thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
|
||||
);
|
||||
|
||||
//
|
||||
// Each thread store its fragment to a SMEM
|
||||
//
|
||||
AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
|
||||
&reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
|
||||
);
|
||||
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
|
||||
&reduction_fragment
|
||||
);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
||||
int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;
|
||||
|
||||
aligned_reduction_ptr[col_idx] = frag_ptr[column];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Now, threads are assigned several columns of the output. The fetch over all rows from
|
||||
// the compacted SMEM tile and perform a reduction.
|
||||
//
|
||||
|
||||
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
|
||||
int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
ElementReductionAccumulator reduction_element = ElementReductionAccumulator();
|
||||
|
||||
int output_column_idx = threadblock_offset.column() + column_idx;
|
||||
|
||||
if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
|
||||
if (row) {
|
||||
auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
|
||||
reduction_element = reduction_op(reduction_element, frag);
|
||||
}
|
||||
else {
|
||||
|
||||
reduction_element = reduction_smem_ptr_[column_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Store
|
||||
reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,266 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Linear Combination
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B))
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename VisitorA_, ///< Child node A
|
||||
typename VisitorB_ ///< Child node B
|
||||
>
|
||||
class VisitorOpLinearCombination{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using VisitorA = VisitorA_;
|
||||
using VisitorB = VisitorB_;
|
||||
|
||||
/// Fragment type returned from VisitorA.visit
|
||||
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
|
||||
using ElementA = typename VisitAccessTypeA::Element;
|
||||
|
||||
/// Fragment type returned from VisitorB.visit
|
||||
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
|
||||
using ElementB = typename VisitAccessTypeB::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Combination Op TODO: generalize this
|
||||
using CombinationOp = cutlass::plus<VisitAccessType>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename VisitorA::SharedStorage storage_a;
|
||||
typename VisitorB::SharedStorage storage_b;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
|
||||
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0))
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
typename VisitorA::Arguments visitor_a_arg,
|
||||
typename VisitorB::Arguments visitor_b_arg
|
||||
):
|
||||
alpha(alpha),
|
||||
beta(beta),
|
||||
visitor_a_arg(visitor_a_arg),
|
||||
visitor_b_arg(visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a
|
||||
typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
alpha(args.alpha),
|
||||
beta(args.beta),
|
||||
visitor_a_param(args.visitor_a_arg),
|
||||
visitor_b_param(args.visitor_b_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
VisitorA visitor_a_op;
|
||||
VisitorB visitor_b_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpLinearCombination(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
alpha_(params.alpha),
|
||||
beta_(params.beta),
|
||||
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
|
||||
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor A and visitor B
|
||||
VisitAccessTypeA result_A;
|
||||
VisitAccessTypeB result_B;
|
||||
|
||||
if (alpha_ != ElementCompute(0)) {
|
||||
result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
} else {
|
||||
// Fill the result A with zeros
|
||||
result_A.clear();
|
||||
}
|
||||
|
||||
if (beta_ != ElementCompute(0)) {
|
||||
result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
} else {
|
||||
// Fill the result B with zeros
|
||||
result_B.clear();
|
||||
}
|
||||
|
||||
/// Type conversion
|
||||
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
|
||||
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
|
||||
|
||||
CombinationOp combination_op;
|
||||
|
||||
cutlass::multiplies<VisitAccessType> multiply_op;
|
||||
|
||||
return combination_op(
|
||||
multiply_op(alpha_, source_converter_A(result_A)),
|
||||
multiply_op(beta_, source_converter_B(result_B))
|
||||
);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
|
||||
if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,258 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with broadcasting vector to all rows
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementVector T[i][j] <- device-memory Td[j]
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementFragment_, ///< Data type used to cache vector in register
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
|
||||
>
|
||||
class VisitorOpRowBroadcast {
|
||||
public:
|
||||
using InputTileIterator = InputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementVector = typename InputTileIterator::Element;
|
||||
using ElementFragment = ElementFragment_;
|
||||
|
||||
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by input tile iterators
|
||||
using ThreadMap = typename InputTileIterator::ThreadMap;
|
||||
|
||||
/// Fragment object used to store the broadcast values
|
||||
using BroadcastFragment = Array<
|
||||
ElementFragment,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Used for the broadcast
|
||||
struct BroadcastDetail {
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
// /// Number of iterations (accesses) the threadblock takes to reduce a row
|
||||
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
||||
};
|
||||
|
||||
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
||||
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementVector *broadcast_ptr,
|
||||
int64_t batch_stride
|
||||
):
|
||||
broadcast_ptr(broadcast_ptr),
|
||||
batch_stride(batch_stride) { }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
broadcast_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
broadcast_ptr(args.broadcast_ptr),
|
||||
batch_stride(args.batch_stride) { }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementVector *broadcast_ptr;
|
||||
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
|
||||
MatrixCoord threadblock_offset_;
|
||||
int thread_idx_;
|
||||
MatrixCoord problem_size;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpRowBroadcast(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
|
||||
threadblock_offset_(threadblock_offset),
|
||||
thread_idx_(thread_idx),
|
||||
problem_size(problem_size),
|
||||
batch_stride_(params.batch_stride) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
broadcast_ptr += batch_idx * batch_stride_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
// load broadcast fragment
|
||||
load_broadcast_fragment_();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
VisitAccessType* broadcast_fragment_ = reinterpret_cast<VisitAccessType*>(&broadcast_fragment);
|
||||
return broadcast_fragment_[column_idx];
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
|
||||
private:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load_broadcast_fragment_() {
|
||||
|
||||
broadcast_fragment.clear();
|
||||
|
||||
// If no pointer is supplied, set with all zeros and avoid memory accesses
|
||||
if (!broadcast_ptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
|
||||
|
||||
int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
|
||||
broadcast_ptr += thread_initial_column;
|
||||
|
||||
NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
|
||||
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
|
||||
using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;
|
||||
|
||||
AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
|
||||
|
||||
AccessType loaded;
|
||||
|
||||
loaded.clear();
|
||||
|
||||
if (thread_column_idx < problem_size.column()) {
|
||||
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
|
||||
}
|
||||
|
||||
AccessFragmentType cvt = converter(loaded);
|
||||
frag_ptr[j] = cvt;
|
||||
|
||||
thread_column_idx += ThreadMap::Delta::kColumn;
|
||||
broadcast_ptr += ThreadMap::Delta::kColumn;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,320 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with reduction over rows in CTA
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "stdio.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementReductionAccumulator R[i] = \sum_j ElementReductionAccumulator(T[i][j])
|
||||
/// device memory <- ElementReduction(R[i])
|
||||
///
|
||||
template <
|
||||
typename ThreadblockShape_, /// Threadblock shape
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementReduction_, ///< Data type of the output reduction in device memory
|
||||
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
|
||||
typename OutputTileIterator_, ///< Tile Iterator type
|
||||
typename Visitor_ ///< preceeding visitor op
|
||||
>
|
||||
class VisitorOpRowReduction {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementReductionAccumulator = ElementReductionAccumulator_;
|
||||
using ElementReduction = ElementReduction_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using Visitor = Visitor_;
|
||||
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
// TODO: generalize the reduction op
|
||||
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
|
||||
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
/// Fragment type returned from Visitor
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
using VisitAccessType = VisitAccessTypeVisitor;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of redcution
|
||||
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Thread map used by output tile iterators
|
||||
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
||||
/// Used for the reduction
|
||||
struct ReductionDetail {
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = ThreadMap::kThreads;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
|
||||
|
||||
/// Half number of threads per row used for cross-thread reduction
|
||||
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementReduction *reduction_ptr,
|
||||
int64_t batch_stride,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
reduction_ptr(reduction_ptr),
|
||||
batch_stride(batch_stride),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Params visitor_param; ///< Argument type of visitor
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): reduction_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
reduction_ptr(args.reduction_ptr),
|
||||
batch_stride(args.batch_stride),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
|
||||
ElementReductionAccumulator reduction_accum;
|
||||
Visitor visitor_; ///< visitor
|
||||
int thread_idx_;
|
||||
MatrixCoord threadblock_offset;
|
||||
MatrixCoord problem_size_;
|
||||
|
||||
int thread_start_row_; /// used to identify
|
||||
int state_[3]; /// used to track row iterator
|
||||
int thread_offset_row_;
|
||||
int64_t batch_stride_;
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpRowReduction(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
visitor_(params.visitor_param, shared_storage.storage_visitor,
|
||||
thread_idx, threadblock_offset, problem_size),
|
||||
reduction_output_ptr_(params.reduction_ptr),
|
||||
thread_idx_(thread_idx),
|
||||
threadblock_offset(threadblock_offset),
|
||||
problem_size_(problem_size),
|
||||
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
|
||||
batch_stride_(params.batch_stride)
|
||||
{
|
||||
state_[0] = state_[1] = state_[2] = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
reduction_output_ptr_ += batch_idx * batch_stride_;
|
||||
visitor_.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
visitor_.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_.begin_row(row_idx);
|
||||
|
||||
reduction_accum = ElementReductionAccumulator(0);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor
|
||||
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
|
||||
ElementReductionAccumulator reduction_accum_ = reduction(result);
|
||||
|
||||
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
|
||||
reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
|
||||
}
|
||||
reduction_accum = reduction_op(reduction_accum, reduction_accum_);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_.end_row(row_idx);
|
||||
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
|
||||
|
||||
bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
|
||||
int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();
|
||||
|
||||
ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;
|
||||
|
||||
arch::global_store<ElementReduction, sizeof(ElementReduction)>(
|
||||
output_converter(reduction_accum),
|
||||
(void *)curr_ptr_reduction,
|
||||
is_write_thread);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_.end_step(step_idx);
|
||||
|
||||
// run operator ++
|
||||
++state_[0];
|
||||
|
||||
thread_start_row_ += ThreadMap::Shape::kRow;
|
||||
if (state_[0] == ThreadMap::Count::kRow) {
|
||||
state_[0] = 0;
|
||||
++state_[1];
|
||||
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
|
||||
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
|
||||
|
||||
if (state_[1] == ThreadMap::Count::kGroup) {
|
||||
state_[1] = 0;
|
||||
++state_[2];
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
|
||||
|
||||
if (state_[2] == ThreadMap::Count::kCluster) {
|
||||
state_[2] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_.end_epilogue();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) {
|
||||
ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
|
||||
sum_ = reduction_op(sum_, result[i]);
|
||||
}
|
||||
|
||||
return sum_;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,188 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Tensor Input
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementInput C <- device memory
|
||||
///
|
||||
/// It can only be a leaf node in the epilogue tree
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename InputTileIterator_ ///< Tile iterator type to read the tensor
|
||||
>
|
||||
class VisitorOpTensorInput {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using InputTileIterator = InputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
|
||||
using ElementInput = typename InputTileIterator::Element;
|
||||
|
||||
using VisitAccessType = Array<ElementInput, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
struct SharedStorage {
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementInput *input_ptr; ///< Pointer to the input tensor in device memory
|
||||
int ldt; ///< Leading dimension of the input tensor operand
|
||||
int64_t batch_stride; ///< batch stride for batched GEMM
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): input_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementInput *input_ptr,
|
||||
int ldt, int64_t batch_stride
|
||||
):
|
||||
input_ptr(input_ptr),
|
||||
ldt(ldt),
|
||||
batch_stride(batch_stride)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
typename InputTileIterator::Params params_input;
|
||||
ElementInput *input_ptr;
|
||||
int64_t batch_stride;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
input_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
params_input(args.ldt),
|
||||
input_ptr(args.input_ptr),
|
||||
batch_stride(args.batch_stride)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
InputTileIterator iterator_T_;
|
||||
typename InputTileIterator::Fragment fragment_T_;
|
||||
MatrixCoord problem_size;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpTensorInput(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
iterator_T_(
|
||||
InputTileIterator(
|
||||
params.params_input,
|
||||
params.input_ptr,
|
||||
problem_size,
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
)
|
||||
),
|
||||
problem_size(problem_size),
|
||||
batch_stride_(params.batch_stride) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_T_.clear();
|
||||
iterator_T_.load(fragment_T_);
|
||||
++iterator_T_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
|
||||
return source;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() { }
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,240 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Tensor Output
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "stdio.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementOutput T = ElementOutput(Visitor)
|
||||
/// T-> device memory
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename OutputTileIterator_, ///< Tile iterator type to write the tensor
|
||||
typename Visitor_ ///< Child visitor that produces the output tensor
|
||||
>
|
||||
class VisitorOpTensorOutput {
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
using Visitor = Visitor_;
|
||||
|
||||
/// Fragment type returned from Visitor
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
using VisitAccessType = VisitAccessTypeVisitor;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of output
|
||||
using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
|
||||
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
/// Host-constructable Argument structure
|
||||
struct Arguments {
|
||||
ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory
|
||||
int ldt; ///< Leading dimension of the output tensor operand
|
||||
int64_t batch_stride; ///< batch stride
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
|
||||
|
||||
/// Methods
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): output_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ElementOutput *output_ptr,
|
||||
int ldt,
|
||||
int64_t batch_stride,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
output_ptr(output_ptr),
|
||||
ldt(ldt),
|
||||
batch_stride(batch_stride),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Param structure
|
||||
struct Params {
|
||||
typename OutputTileIterator::Params params_output;
|
||||
ElementOutput *output_ptr;
|
||||
int64_t batch_stride;
|
||||
typename Visitor::Params visitor_param;
|
||||
|
||||
/// Method
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
output_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
params_output(args.ldt),
|
||||
output_ptr(args.output_ptr),
|
||||
batch_stride(args.batch_stride),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
OutputTileIterator iterator_T_;
|
||||
typename OutputTileIterator::Fragment fragment_T_;
|
||||
MatrixCoord problem_size;
|
||||
Visitor visitor_;
|
||||
int64_t batch_stride_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpTensorOutput(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
|
||||
iterator_T_(
|
||||
OutputTileIterator(
|
||||
params.params_output,
|
||||
params.output_ptr,
|
||||
problem_size,
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
)
|
||||
),
|
||||
problem_size(problem_size),
|
||||
batch_stride_(params.batch_stride) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
|
||||
visitor_.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
visitor_.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_T_.clear();
|
||||
visitor_.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
visitor_.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor
|
||||
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
|
||||
// Column guard
|
||||
MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
|
||||
bool column_guard = (thread_offset_.column() < problem_size.column());
|
||||
|
||||
if (column_guard) {
|
||||
NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
|
||||
OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
visitor_.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
visitor_.end_step(step_idx);
|
||||
iterator_T_.store(fragment_T_);
|
||||
++iterator_T_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
visitor_.end_epilogue();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,226 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
|
||||
\brief A file contains the epilogue visitor Op with Unary operation
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "unary_ops.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = UnaryOp(ElementCompute(Visitor))
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename Visitor_, ///< Child node
|
||||
template<typename T, int N> typename UnaryOp_
|
||||
>
|
||||
class VisitorOpUnary{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using Visitor = Visitor_;
|
||||
|
||||
/// Fragment type returned from Visitor.visit
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisit = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Combination Op TODO: generalize this
|
||||
using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
typename UnaryOp::Arguments unary_arg;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type for visitor
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():unary_arg() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
typename UnaryOp::Arguments unary_arg,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
unary_arg(unary_arg),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
typename UnaryOp::Params unary_param;
|
||||
typename Visitor::Params visitor_param; ///< Argument type for visitor
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():unary_param() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
unary_param(args.unary_arg),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
UnaryOp unary_op;
|
||||
|
||||
Visitor visitor_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpUnary(
|
||||
Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
unary_op(params.unary_param),
|
||||
visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
visitor_op.set_batch_index(batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (unary_op.guard()) visitor_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
if (unary_op.guard()) visitor_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
if (unary_op.guard()) visitor_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get result from visitor A and visitor B
|
||||
VisitAccessTypeVisitor result;
|
||||
|
||||
if (unary_op.guard()){
|
||||
result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
} else {
|
||||
result.clear();
|
||||
}
|
||||
|
||||
/// Type conversion
|
||||
NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;
|
||||
|
||||
cutlass::multiplies<VisitAccessType> multiply_op;
|
||||
|
||||
return unary_op(source_converter(result));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
if (unary_op.guard()) visitor_op.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
if (unary_op.guard()) visitor_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
if (unary_op.guard()) visitor_op.end_epilogue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,481 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A file contains all functioning classes needed by GemmLayernorm.
|
||||
|
||||
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
|
||||
+ lightweight full reduction kernel (ApplyFinalReduction)
|
||||
+ GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ThreadblockShape_,
|
||||
int ThreadCount,
|
||||
typename OutputTileIterator_,
|
||||
typename AccumulatorTile_,
|
||||
typename ElementAccumulator_,
|
||||
typename ElementVariance_,
|
||||
typename ElementMean_,
|
||||
typename ElementLayernormCompute_,
|
||||
typename ElementwiseFunctor_,
|
||||
bool IsShiftedVariance_ = false
|
||||
>
|
||||
class EpilogueVisitorLayerNorm {
|
||||
public:
|
||||
|
||||
using ElementVariance = ElementVariance_;
|
||||
using ElementMean = ElementMean_;
|
||||
using ElementLayernormCompute = ElementLayernormCompute_;
|
||||
|
||||
using AccumulatorTile = AccumulatorTile_;
|
||||
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;
|
||||
|
||||
static int const kThreads = OutputTileIterator::ThreadMap::kThreads;
|
||||
|
||||
static bool const kIsShiftedVariance = IsShiftedVariance_;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;
|
||||
|
||||
/// Array type used in Shift-K Layernorm
|
||||
static int const kRowAccessCount = kIterations * kRowIterations;
|
||||
|
||||
using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;
|
||||
|
||||
// Conducts manual transpose externally (already supported) for column major
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static int const kThreadsInColumn = kThreads / kThreadsPerRow;
|
||||
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
ElementVariance *ptr_Variance;
|
||||
ElementMean *ptr_Mean;
|
||||
ElementOutput *ptr_Shifted_K;
|
||||
MatrixCoord extent;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments():
|
||||
ptr_Variance(nullptr),
|
||||
ptr_Mean(nullptr),
|
||||
ptr_Shifted_K(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
Arguments(
|
||||
typename ElementwiseFunctor::Params elementwise_,
|
||||
ElementVariance *ptr_Variance,
|
||||
ElementMean *ptr_Mean_,
|
||||
ElementOutput *ptr_Shifted_K_ = nullptr,
|
||||
MatrixCoord extent = MatrixCoord(0, 0)
|
||||
):
|
||||
elementwise(elementwise_),
|
||||
ptr_Variance(ptr_Variance),
|
||||
ptr_Mean(ptr_Mean_),
|
||||
ptr_Shifted_K(ptr_Shifted_K_),
|
||||
extent(extent)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
ElementVariance *ptr_Variance;
|
||||
ElementMean *ptr_Mean;
|
||||
ElementOutput *ptr_Shifted_K;
|
||||
MatrixCoord extent;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
ptr_Variance(nullptr),
|
||||
ptr_Mean(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
elementwise(args.elementwise),
|
||||
ptr_Variance(args.ptr_Variance),
|
||||
ptr_Mean(args.ptr_Mean),
|
||||
ptr_Shifted_K(args.ptr_Shifted_K),
|
||||
extent(args.extent)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const & params_;
|
||||
SharedStorage & shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator alpha_;
|
||||
ElementAccumulator beta_;
|
||||
ConvertedShiftFragment shift_k_frag_;
|
||||
|
||||
ElementLayernormCompute accum_sum_square_;
|
||||
ElementLayernormCompute accum_sum_element_;
|
||||
int thread_idx_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
gemm::GemmCoord threadblock_tile_offset_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the per-thread visitor state.
///
/// alpha and beta are resolved once here: the pointer form is dereferenced
/// when it is non-null, otherwise the by-value scalar is used. When beta
/// compares equal to a value-initialized ElementAccumulator, the source
/// iterator's mask is cleared.
CUTLASS_DEVICE
EpilogueVisitorLayerNorm(
  Params const &params,                       ///< Parameters routed to the epilogue
  SharedStorage &shared_storage,              ///< Shared storage needed by the functors here
  MatrixCoord threadblock_offset,             ///< Accepted but not read by this constructor
  gemm::GemmCoord threadblock_tile_offset,    ///< Threadblock tile coordinate in GEMM
  int thread_idx,                             ///< Thread index, kept in thread_idx_
  OutputTileIterator destination_iterator,    ///< Tile iterator for destination
  OutputTileIterator source_iterator          ///< Tile iterator for source
):
  // Initializers follow member-declaration order, which is the order in which
  // C++ actually runs them.
  params_(params),
  shared_storage_(shared_storage),
  extent_(params.extent),
  elementwise_(params.elementwise),
  iterator_C_(source_iterator),
  iterator_D_(destination_iterator),
  thread_idx_(thread_idx),
  threadblock_tile_offset_(threadblock_tile_offset)
{
  alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
  beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

  if (beta_ == ElementAccumulator()) {
    iterator_C_.clear_mask();
  }
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
|
||||
// If shift-K feature is enabled, we load shift-k fragment
|
||||
// at the very beginning of an epilogue
|
||||
if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
|
||||
shift_k_frag_.clear();
|
||||
int thread_offset_row_base = iterator_D_.thread_start_row();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
|
||||
int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int rid = 0; rid < kRowIterations; ++rid) {
|
||||
int row_step_offset = rid * kDeltaRow;
|
||||
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
|
||||
bool is_load = (row_offset < extent_.row());
|
||||
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_D_.clear();
|
||||
|
||||
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
||||
fragment_C_.clear();
|
||||
iterator_C_.load(fragment_C_);
|
||||
++iterator_C_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
/// set the accumulator to 0
|
||||
accum_sum_element_ = ElementLayernormCompute(0);
|
||||
accum_sum_square_ = ElementLayernormCompute(0);
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorFragment const &accum) {
|
||||
|
||||
using Mul = cutlass::multiplies<ElementLayernormCompute>;
|
||||
using Minus = cutlass::minus<ElementLayernormCompute>;
|
||||
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
|
||||
|
||||
Minus minus;
|
||||
Mul mul;
|
||||
Exp exponential;
|
||||
|
||||
LayernormFragment result;
|
||||
|
||||
thread_offset_ =
|
||||
iterator_D_.thread_start() +
|
||||
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
|
||||
|
||||
NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
|
||||
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
|
||||
|
||||
bool column_guard = (thread_offset_.column() < extent_.column());
|
||||
|
||||
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
||||
result = source_converter(elementwise_(accum));
|
||||
}else{
|
||||
result = source_converter(elementwise_(accum, source_vector));
|
||||
}
|
||||
|
||||
|
||||
ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());
|
||||
|
||||
// Fragment is cleared for non-reachable columns so no need to check against column guard
|
||||
ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);
|
||||
|
||||
// Square sum is different. Non-reachable columns should've been computed for shift-k
|
||||
// Otherwise we will incorrectly have some extra k^2 added into square sum.
|
||||
ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);
|
||||
|
||||
if (column_guard) {
|
||||
accum_sum_square_tmp = (kIsShiftedVariance) ? \
|
||||
square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
|
||||
square_sum_accumulator_(result);
|
||||
}
|
||||
|
||||
accum_sum_element_tmp *= inv_scalar;
|
||||
accum_sum_square_tmp *= inv_scalar;
|
||||
|
||||
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
|
||||
accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
|
||||
accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
|
||||
}
|
||||
accum_sum_element_ += accum_sum_element_tmp;
|
||||
accum_sum_square_ += accum_sum_square_tmp;
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
|
||||
using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
|
||||
using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;
|
||||
|
||||
ConvertVarianceOutput convert_variance_output;
|
||||
ConvertMeanOutput convert_mean_output;
|
||||
|
||||
bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
|
||||
int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();
|
||||
|
||||
ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
|
||||
ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;
|
||||
|
||||
arch::global_store<ElementVariance, sizeof(ElementVariance)>(
|
||||
convert_variance_output(accum_sum_square_),
|
||||
(void *)curr_ptr_sum_square,
|
||||
is_write_thread);
|
||||
|
||||
arch::global_store<ElementMean, sizeof(ElementMean)>(
|
||||
convert_mean_output(accum_sum_element_),
|
||||
(void *)curr_ptr_element_sum,
|
||||
is_write_thread);
|
||||
}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/// Loads one shift-K element and converts it to the layernorm compute type.
///
/// \param row_offset  element offset added to params_.ptr_Shifted_K
/// \param is_load     predicate handed to arch::global_load
/// \return the converted value; zero when the predicated load leaves the
///         destination untouched
CUTLASS_DEVICE
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
  using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
  ConvertShiftK convert_shift_k;

  // Start from a defined value: the result is converted and returned even
  // when is_load is false, so it must not be read uninitialized.
  ElementOutput shift_k_val = ElementOutput(0);

  // Computes the address to load shift_k element
  ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset;

  // Conditionally loads from global memory
  arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);

  // Converts data type to return
  return convert_shift_k(shift_k_val);
}
|
||||
|
||||
/// Returns the sum of the squares of every element of the fragment.
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
  ElementLayernormCompute total = ElementLayernormCompute(0);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < LayernormFragment::kElements; ++idx) {
    auto value = accum[idx];
    total += value * value;
  }

  return total;
}
|
||||
|
||||
/// Returns the sum of squared differences between every element of the
/// fragment and shift_k_val.
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
  ElementLayernormCompute total = ElementLayernormCompute(0);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < LayernormFragment::kElements; ++idx) {
    auto shifted = accum[idx] - shift_k_val;
    total += shifted * shifted;
  }

  return total;
}
|
||||
|
||||
/// Returns the plain sum of every element of the fragment.
CUTLASS_DEVICE
ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
  ElementLayernormCompute total = ElementLayernormCompute(0);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < LayernormFragment::kElements; ++idx) {
    total += accum[idx];
  }

  return total;
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,692 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Universal GEMM kernel whose epilogue is carried out by a caller-supplied epilogue visitor.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmUniversalwithEpilogueVisitor {
|
||||
public:
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(
|
||||
128 / sizeof_bits<ElementA>::value,
|
||||
128 / sizeof_bits<ElementB>::value
|
||||
);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
void const * ptr_A;
|
||||
void const * ptr_B;
|
||||
void const * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename LayoutA::Stride stride_a;
|
||||
typename LayoutB::Stride stride_b;
|
||||
typename LayoutC::Stride stride_c;
|
||||
typename LayoutC::Stride stride_d;
|
||||
|
||||
typename LayoutA::Stride::LongIndex lda;
|
||||
typename LayoutB::Stride::LongIndex ldb;
|
||||
typename LayoutC::Stride::LongIndex ldc;
|
||||
typename LayoutC::Stride::LongIndex ldd;
|
||||
|
||||
int const * ptr_gather_A_indices;
|
||||
int const * ptr_gather_B_indices;
|
||||
int const * ptr_scatter_D_indices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr) {}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
typename LayoutA::Stride stride_a,
|
||||
typename LayoutB::Stride stride_b,
|
||||
typename LayoutC::Stride stride_c,
|
||||
typename LayoutC::Stride stride_d,
|
||||
int const *ptr_gather_A_indices = nullptr,
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
lda = 0;
|
||||
ldb = 0;
|
||||
ldc = 0;
|
||||
ldd = 0;
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
typename LayoutA::Stride::LongIndex lda,
|
||||
typename LayoutB::Stride::LongIndex ldb,
|
||||
typename LayoutC::Stride::LongIndex ldc,
|
||||
typename LayoutC::Stride::LongIndex ldd,
|
||||
int const *ptr_gather_A_indices = nullptr,
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
stride_a = make_Coord(lda);
|
||||
stride_b = make_Coord(ldb);
|
||||
stride_c = make_Coord(ldc);
|
||||
stride_d = make_Coord(ldd);
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// Returns arguments for the transposed problem
|
||||
Arguments transposed_problem() const {
|
||||
Arguments args(*this);
|
||||
|
||||
std::swap(args.problem_size.m(), args.problem_size.n());
|
||||
std::swap(args.ptr_A, args.ptr_B);
|
||||
std::swap(args.lda, args.ldb);
|
||||
std::swap(args.stride_a, args.stride_b);
|
||||
std::swap(args.batch_stride_A, args.batch_stride_B);
|
||||
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
|
||||
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
int * ptr_gather_A_indices;
|
||||
int * ptr_gather_B_indices;
|
||||
int * ptr_scatter_D_indices;
|
||||
|
||||
int *semaphore;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_D(0),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr),
|
||||
semaphore(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
int gemm_k_size,
|
||||
void *workspace = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
|
||||
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
|
||||
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
|
||||
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
|
||||
epilogue_visitor(args.epilogue_visitor),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(gemm_k_size),
|
||||
ptr_A(const_cast<void *>(args.ptr_A)),
|
||||
ptr_B(const_cast<void *>(args.ptr_B)),
|
||||
ptr_C(const_cast<void *>(args.ptr_C)),
|
||||
ptr_D(args.ptr_D),
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
|
||||
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
|
||||
semaphore(static_cast<int *>(workspace)) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr) {
|
||||
|
||||
ptr_A = const_cast<void *>(args.ptr_A);
|
||||
ptr_B = const_cast<void *>(args.ptr_B);
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
epilogue_visitor = args.epilogue_visitor;
|
||||
|
||||
semaphore = static_cast<int *>(workspace);
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor; the kernel object itself carries no state to set up.
CUTLASS_DEVICE
GemmUniversalwithEpilogueVisitor() {
}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Convenience overload: checks alignment using only the problem size held in args.
static Status can_implement(Arguments const &args) {
  cutlass::gemm::GemmCoord const &problem_size = args.problem_size;
  return can_implement(problem_size);
}
|
||||
|
||||
/// Returns the number of extra workspace bytes this kernel requests, which is
/// always zero; neither argument is consulted.
static size_t get_extra_workspace_size(
  Arguments const &args,
  cutlass::gemm::GemmCoord const &grid_tiled_shape) {
  return 0;
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
||||
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
|
||||
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm ||
|
||||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{
|
||||
offset_k,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A,
|
||||
ptr_A,
|
||||
{params.problem_size.m(), problem_size_k},
|
||||
thread_idx,
|
||||
tb_offset_A,
|
||||
params.ptr_gather_A_indices);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B,
|
||||
ptr_B,
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B,
|
||||
params.ptr_gather_B_indices);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(
|
||||
gemm_k_iterations,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
// EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
|
||||
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
||||
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// if (params.mode == GemmUniversalMode::kGemm) {
|
||||
|
||||
// // TODO: fix this order
|
||||
// // If performing a reduction via split-K, fetch the initial synchronization
|
||||
// if (params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// // Fetch the synchronization lock initially but do not block.
|
||||
// semaphore.fetch();
|
||||
|
||||
// // Indicate which position in a serial reduction the output operator is currently updating
|
||||
// output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
// }
|
||||
// }
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
|
||||
EpilogueVisitor epilogue_visitor(
|
||||
params.epilogue_visitor,
|
||||
shared_storage.visitor,
|
||||
threadblock_offset,
|
||||
threadblock_tile_offset,
|
||||
thread_idx,
|
||||
params.problem_size.mn()
|
||||
);
|
||||
|
||||
// if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
// ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
// }
|
||||
if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
// TODO: ???
|
||||
// if (threadblock_tile_offset.k()) {
|
||||
// iterator_C = iterator_D;
|
||||
// }
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -50,7 +50,13 @@ void bind_tensor_coord(py::module &m) {
|
||||
R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
|
||||
.def(py::init<int, int, int, int>(),
|
||||
py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
|
||||
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc");
|
||||
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
|
||||
.def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
|
||||
py::arg("dim"),
|
||||
R"pbdoc(Gets the index of a given Coord element)pbdoc")
|
||||
.def("size", [](const cutlass::Tensor4DCoord & coord) {
|
||||
return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);},
|
||||
R"pbdoc(The size of the tensor coord)pbdoc");
|
||||
|
||||
py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
|
||||
R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
|
||||
|
||||
@ -1,7 +1,24 @@
|
||||
from pycutlass.type import *
|
||||
import re
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in ``template`` using ``values``.

    Substitution is repeated until a full pass over ``values`` leaves the
    text unchanged, so a substituted value may itself contain placeholders
    that are expanded on a later pass.

    Both the placeholder and its replacement are treated as literal text.
    Backslashes in a value (for example in a Windows include path or a C++
    escape sequence) are copied verbatim rather than being interpreted as
    regular-expression replacement escapes.

    :param template: text containing ``${key}`` placeholders
    :type template: str
    :param values: mapping from placeholder name to replacement text
    :type values: dict[str, str]

    :return: the template with every known placeholder expanded
    :rtype: str

    NOTE: a value that (directly or indirectly) contains its own placeholder
    never reaches a fixed point, so this loop would not terminate.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Literal replacement: neither the key nor the value is a regex.
            newtext = text.replace("${%s}" % key, value)
            if newtext != text:
                changed = True
            text = newtext
    return text
|
||||
|
||||
from pycutlass.type_hint import *
|
||||
from pycutlass.tensor_ref import *
|
||||
from pycutlass.operation import *
|
||||
from pycutlass.epilogue import *
|
||||
from pycutlass.parser import *
|
||||
from pycutlass.compiler import ArtifactManager
|
||||
from pycutlass.memory_manager import *
|
||||
from pycutlass.arguments import *
|
||||
|
||||
@ -60,6 +60,13 @@ class ArgumentBase:
|
||||
C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
|
||||
D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
|
||||
**kwargs) -> None:
|
||||
|
||||
# tensor_C can be interpreted as the bias with bias=True in keyword args
|
||||
if "bias" in kwargs.keys():
|
||||
self.bias = kwargs["bias"]
|
||||
else:
|
||||
# by default, tensor_C is not bias
|
||||
self.bias = False
|
||||
|
||||
# preprocessing input tensors
|
||||
if isinstance(A, np.ndarray):
|
||||
@ -72,21 +79,28 @@ class ArgumentBase:
|
||||
self.ptr_B = self.buffer_B.ptr
|
||||
self.ptr_C = self.buffer_C.ptr
|
||||
self.ptr_D = self.buffer_D.ptr
|
||||
# number of elements in C
|
||||
self.tensor_c_numel = C.size
|
||||
elif torch_available and isinstance(A, torch.Tensor):
|
||||
self.ptr_A = TorchFrontend.argument(A)
|
||||
self.ptr_B = TorchFrontend.argument(B)
|
||||
self.ptr_C = TorchFrontend.argument(C)
|
||||
self.ptr_D = TorchFrontend.argument(D)
|
||||
# number of elements in C
|
||||
self.tensor_c_numel = C.numel()
|
||||
elif isinstance(A, cuda.CUdeviceptr):
|
||||
self.ptr_A = A
|
||||
self.ptr_B = B
|
||||
self.ptr_C = C
|
||||
self.ptr_D = D
|
||||
|
||||
elif cupy_available and isinstance(A, cp.ndarray):
|
||||
self.ptr_A = CupyFrontend.argument(A)
|
||||
self.ptr_B = CupyFrontend.argument(B)
|
||||
self.ptr_C = CupyFrontend.argument(C)
|
||||
self.ptr_D = CupyFrontend.argument(D)
|
||||
# number of elements in C
|
||||
self.tensor_c_numel = C.size
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported Frontend. Only support numpy and torch")
|
||||
|
||||
@ -63,22 +63,9 @@ dtype2ctype = {
|
||||
}
|
||||
|
||||
|
||||
def get_epilogue_output_op(element_compute_):
|
||||
element_compute = dtype2ctype[element_compute_]
|
||||
def get_gemm_arguments(epilogue_functor):
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", element_compute),
|
||||
("beta", element_compute),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p)
|
||||
]
|
||||
return _EpilogueOutputOpParams
|
||||
|
||||
|
||||
def get_gemm_arguments(element_compute_):
|
||||
|
||||
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _GemmArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
@ -116,8 +103,8 @@ def get_gemm_arguments(element_compute_):
|
||||
|
||||
# include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
|
||||
def get_gemm_grouped_arguments(element_compute_):
|
||||
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
|
||||
def get_gemm_grouped_arguments(epilogue_functor):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _GEMMGroupedArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
@ -214,8 +201,8 @@ class TensorRef2D_(ctypes.Structure):
|
||||
# include/cutlass/conv/kernel/implicit_gemm_convolution.h
|
||||
# split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4
|
||||
|
||||
def get_conv2d_arguments(element_compute_):
|
||||
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
|
||||
def get_conv2d_arguments(epilogue_functor):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _Conv2dArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
@ -236,8 +223,8 @@ def get_conv2d_arguments(element_compute_):
|
||||
############################################################################################
|
||||
|
||||
|
||||
def get_reduction_params(element_compute_):
|
||||
_EpilogueOutputParams = get_epilogue_output_op(element_compute_)
|
||||
def get_reduction_params(epilogue_functor):
|
||||
_EpilogueOutputParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _ReductionParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
|
||||
@ -1,366 +0,0 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
from pycutlass import *
|
||||
from pycutlass.library import SubstituteTemplate
|
||||
import cutlass
|
||||
from cuda import cuda
|
||||
from cuda import nvrtc
|
||||
import tempfile
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
#
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
|
||||
IncludeTemplate = r'''#include "${include}"
|
||||
'''
|
||||
|
||||
#
|
||||
class CompilationOptions:
    '''
    Compilation options.

    Collects the flags, include paths and target architectures of one
    compilation and renders them as the list of byte strings returned
    by :meth:`get`.
    '''

    #
    def __init__(self, architectures = None, include_paths = None):
        """
        :param architectures: compute capabilities to target, e.g. ``[80]``;
            defaults to ``[80]`` when omitted
        :param include_paths: directories added as ``--include-path=``;
            defaults to no extra directories
        """
        self.includes = []
        # Build fresh lists per instance: a list literal used as a default
        # argument would be shared by every instance created without it.
        self.include_paths = [] if include_paths is None else include_paths
        self.flags = ['-std=c++11', '-default-device']
        self.architectures = [80] if architectures is None else architectures

    #
    def get(self):
        """
        Render the options as a list of ``bytes``.

        The order is: flags, then one ``--include-path=<dir>`` entry per
        include path, then a single ``-arch=sm_XX[,sm_YY...]`` entry.
        """
        options = [flag.encode() for flag in self.flags]
        options.extend(
            ('--include-path=%s' % incl).encode()
            for incl in self.include_paths)

        arch_list = "-arch=" + ",".join(
            "sm_%d" % arch for arch in self.architectures)
        options.append(arch_list.encode())

        return options
|
||||
|
||||
def convertToBinaryData(filename):
    """Read the file at ``filename`` in binary mode and return its bytes."""
    with open(filename, 'rb') as handle:
        return handle.read()
|
||||
|
||||
def CDLLBin(host_binary):
    """Load a shared library from its raw bytes.

    The bytes are written to a ``host_func*.so`` temporary file and that
    file is opened with :class:`ctypes.CDLL`.

    Side effect: sets ``tempfile.tempdir`` to ``"./"``, so the temporary
    file is created under the current working directory.

    :param host_binary: contents of a compiled shared object
    :type host_binary: bytes

    :return: handle to the loaded library
    :rtype: ctypes.CDLL
    """
    tempfile.tempdir = "./"
    # Created with delete=True; the handle must stay alive until the
    # library has been loaded from the file's path.
    scratch = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True)
    with open(scratch.name, 'wb') as sink:
        sink.write(host_binary)
    return ctypes.CDLL(scratch.name)
|
||||
|
||||
|
||||
class ArtifactManager:
    """
    Artifact manager

    Compiles operations to a device cubin (NVRTC) and a host shared library
    (g++), and caches the results at two levels:

    * in-process, in two ``cutlass.CompileCache`` objects (device kernels
      and host helper functions), and
    * on disk, in the SQLite database ``./compiled_cache.db``.

    The cache key of an operation is
    ``operation.rt_module.emit() + operation.procedural_name()``.
    """

    # Persistent cache location, relative to the current working directory.
    _DB_PATH = "./compiled_cache.db"

    def __init__(self) -> None:
        # Creating the table is best-effort: a database problem here must
        # not prevent the in-memory caches from being set up.
        connection = None
        try:
            connection = sqlite3.connect(self._DB_PATH)
            cursor = connection.cursor()
            sqlite_create_table_query = """CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)"""
            cursor.execute(sqlite_create_table_query)
            connection.commit()
            cursor.close()
        except sqlite3.Error:
            pass
        finally:
            if connection is not None:
                connection.close()

        self.compiled_cache_device = cutlass.CompileCache()
        self.compiled_cache_host = cutlass.CompileCache()

    def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
        """
        Store one compiled operation in the on-disk cache.

        :param op_key: cache key of the operation
        :param cubin: device binary image
        :param hostfile: path of the compiled host shared library
        :param op_name: kernel name inside the cubin
        :param op_attrs: JSON-serializable list; as written by
            :meth:`add_module` it is ``[param_size, *extra_func_suffixes]``

        An existing row with the same ``op_key`` is left untouched
        (``INSERT OR IGNORE``).
        """
        sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""

        hostbin = convertToBinaryData(hostfile)
        data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))

        connection = sqlite3.connect(self._DB_PATH)
        try:
            cursor = connection.cursor()
            cursor.execute(sqlite_insert_blob_query, data_tuple)
            connection.commit()
            cursor.close()
        finally:
            connection.close()

    def load_operation(self, op_key):
        """
        Load an operation from the on-disk cache into the in-memory caches.

        :param op_key: cache key of the operation

        :return: ``True`` if a matching row was found and loaded,
            ``False`` if the database holds no row for ``op_key``
        :raises RuntimeError: if the CUDA driver rejects the stored cubin
            or the kernel cannot be found in it
        """
        sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""

        connection = sqlite3.connect(self._DB_PATH)
        try:
            cursor = connection.cursor()
            cursor.execute(sqlite_fetch_blob_query, (op_key, ))
            records = cursor.fetchall()
            cursor.close()
        finally:
            connection.close()

        if not records:
            return False

        for row in records:
            key, cubin_image, host_binary, operation_name, op_attr = row
            op_attr = json.loads(op_attr)
            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))
            self.compiled_cache_device.insert(key, kernel)

            compiled_host_fns = {}
            host_lib = CDLLBin(host_binary)

            # op_attr[0] is the parameter-struct size stored at insert time.
            func_name = operation_name + '_get_params'
            func = getattr(host_lib, func_name)
            func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
            compiled_host_fns['get_args'] = func

            func_name = operation_name + '_shared_memory_size'
            func = getattr(host_lib, func_name)
            compiled_host_fns['shared_memory_capacity'] = func()

            # The string entries of op_attr name the extra host functions.
            for attr in op_attr:
                if isinstance(attr, str):
                    func_name = operation_name + '_' + attr
                    func = getattr(host_lib, func_name)
                    compiled_host_fns[attr] = func

            self.compiled_cache_host.insert(key, compiled_host_fns)
        return True

    @staticmethod
    def _host_compiler_args(options, output_path):
        """Translate the device compile options into a g++ argument list.

        Device-only options (``-default-device``, ``-std=c++11`` and any
        ``-arch=...``) are dropped, ``--include-path=X`` becomes ``-IX`` and
        every other option is passed through. The source is read from stdin
        (``-``) and a shared object is written to ``output_path``.
        """
        args = ["g++", "-x", "c++", "-fpermissive", "-w", "-fPIC"]
        for opt in options:
            opt = opt.decode("utf-8")
            if opt in ('-default-device', '-std=c++11') or opt.startswith('-arch='):
                continue
            if '--include-path=' in opt:
                args.append(opt.replace('--include-path=', '-I'))
            else:
                args.append(opt)
        args += ["-", "-shared", "-o", output_path]
        return args

    def emit_compile_(self, operation_list, compilation_options):
        """
        Compile a list of kernels and store them into database

        Emits one device translation unit and one host translation unit for
        ``operation_list``, compiles the former with NVRTC and the latter
        with g++.

        :param operation_list: runtime modules to compile
        :param compilation_options: object whose ``get()`` returns the
            compile options as a list of ``bytes``

        :return: ``(cubin_image, host_lib, temp)`` where ``temp`` is the
            temporary file holding the host shared library
        :raises RuntimeError: on an NVRTC error or a failed host compilation
        """
        # Function-scope import: this module's import block does not
        # provide subprocess.
        import subprocess

        # 1. include
        includes = []
        for operation in operation_list:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        includes_host = [
            "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes

        device_parts = [
            SubstituteTemplate(IncludeTemplate, {'include': incl})
            for incl in includes]
        # Headers under a "/device/" directory are left out of the host source.
        host_parts = [
            SubstituteTemplate(IncludeTemplate, {'include': incl})
            for incl in includes_host if "/device/" not in incl]

        # 2. Operations
        for operation in operation_list:
            device_parts.append(operation.emit())
            host_parts.append(operation.emit())
            values = {
                'operation_name': operation.name(),
                'operation_suffix': operation.emitter.operation_suffix
            }
            device_parts.append(SubstituteTemplate(operation.KernelTemplate, values))
            host_parts.append(SubstituteTemplate(operation.HostTemplate, values))

        source_buffer_device = "".join(device_parts)
        source_buffer_host = "".join(host_parts)

        # 3. compile
        err, program = nvrtc.nvrtcCreateProgram(
            str.encode(source_buffer_device),
            bytes(str.encode("module.cu")),
            0, [], [])

        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        # Compile program
        options = compilation_options.get()

        err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:

            error_string = 'NVRTC Error: {}\n'.format(err)

            # Get log from compilation
            err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            log = b' ' * logSize
            err, = nvrtc.nvrtcGetProgramLog(program, log)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            raise RuntimeError(error_string + log.decode() + source_buffer_device)

        # Get data from compilation
        err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        cubin_image = b' ' * dataSize
        err, = nvrtc.nvrtcGetCUBIN(program, cubin_image)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        # compile the host code
        tempfile.tempdir = "./"
        temp = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True)

        # The source goes to g++ on stdin and the arguments are passed as a
        # list, so no shell quoting is involved and the exit status is
        # available for checking.
        result = subprocess.run(
            self._host_compiler_args(compilation_options.get(), temp.name),
            input=source_buffer_host.encode(),
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            raise RuntimeError(
                'Host compilation failed:\n'
                + result.stderr.decode(errors='replace'))

        host_lib = ctypes.CDLL(temp.name)

        # `temp` is returned so the caller can still read the library file.
        return cubin_image, host_lib, temp

    def add_module(self, operations, compile_options=None):
        """
        Insert a new compiled device module

        Operations already present in the in-memory or on-disk cache are
        bound to their cached kernel and host functions; the remaining ones
        are compiled together, bound, and written to both caches.

        :param operations: operations to make runnable
        :param compile_options: compile options; when ``None`` they are
            derived from the ``CUTLASS_PATH`` and ``CUDA_INSTALL_PATH``
            environment variables and the operations' tile descriptions
        :raises RuntimeError: on a CUDA driver or compilation failure
        """
        if compile_options is None:
            cutlass_path = os.getenv('CUTLASS_PATH')
            assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
            cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
            assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
            architectures = []
            for operation in operations:
                if hasattr(operation, "tile_description"):
                    cc = operation.tile_description.minimum_compute_capability
                    if cc not in architectures:
                        architectures.append(cc)
            include_paths = [
                cuda_install_path + '/include',
                cutlass_path + '/include',
                cutlass_path + '/tools/util/include',
            ]
            compile_options = CompilationOptions(architectures, include_paths)

        # Split the operations into cache hits (bound immediately) and
        # misses (collected for a single compilation below).
        pending_keys = []
        pending_modules = []
        for operation in operations:
            # step 1: get kernel string as key
            key = operation.rt_module.emit() + operation.procedural_name()
            # step 2: check if the operation is in cache
            compiled_kernel = self.compiled_cache_device.at(key)

            if compiled_kernel is None:
                hit = self.load_operation(key)
                if hit:
                    compiled_kernel = self.compiled_cache_device.at(key)
                    assert compiled_kernel is not None

            if compiled_kernel is not None:
                operation.rt_module.kernel = compiled_kernel
                compiled_host_fns = self.compiled_cache_host.at(key)
                assert compiled_host_fns is not None
                for attr_name in compiled_host_fns.keys():
                    setattr(operation.rt_module, attr_name, compiled_host_fns[attr_name])
                operation.rt_module.initialize()
            else:
                pending_modules.append(operation.rt_module)
                pending_keys.append(key)

        if not pending_modules:
            return

        cubin_image, host_lib, host_file = self.emit_compile_(pending_modules, compile_options)

        err, module = cuda.cuModuleLoadData(cubin_image)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))

        operation_names = []
        operation_attrs = []
        for operation, key in zip(pending_modules, pending_keys):
            # get device kernels
            err, operation.kernel = cuda.cuModuleGetFunction(
                module,
                bytes(str.encode(operation.name()))
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))
            operation_names.append(operation.name())
            self.compiled_cache_device.insert(key, operation.kernel)

            # get host functions
            compiled_host_fns = {}
            op_attr = []

            # get param size
            func_name = operation.name() + '_get_param_size'
            func = getattr(host_lib, func_name)
            param_size = func()

            func_name = operation.name() + '_get_params'
            func = getattr(host_lib, func_name)
            # NOTE(review): ctypes consults `argtypes`, not `argtype`, so this
            # assignment does not enable argument conversion. Left unchanged
            # to preserve the current calling behavior - confirm the intent.
            func.argtype = operation.argtype
            func.restype = ctypes.POINTER(ctypes.c_char * param_size)
            setattr(operation, 'get_args', func)
            compiled_host_fns['get_args'] = func

            # set shared memory size
            func_name = operation.name() + '_shared_memory_size'
            func = getattr(host_lib, func_name)
            setattr(operation, 'shared_memory_capacity', func())
            compiled_host_fns['shared_memory_capacity'] = func()
            # set the maximum dynamic shared size
            operation.initialize()

            # Persisted attributes: the parameter size first, then the
            # suffix of every extra host function.
            op_attr.append(param_size)

            if hasattr(operation, "extra_funcs"):
                for suffix in operation.extra_funcs:
                    func_name = operation.name() + '_' + suffix
                    func = getattr(host_lib, func_name)
                    setattr(operation, suffix, func)
                    compiled_host_fns[suffix] = func
                    op_attr.append(suffix)

            operation_attrs.append(op_attr)
            self.compiled_cache_host.insert(key, compiled_host_fns)

        # Every operation of this batch shares the same cubin and host library.
        for key, name, attrs in zip(pending_keys, operation_names, operation_attrs):
            self.insert_operation(key, cubin_image, host_file.name, name, attrs)
|
||||
|
||||
|
||||
artifact_manager = ArtifactManager()
|
||||
@ -30,7 +30,6 @@
|
||||
#
|
||||
#################################################################################################
|
||||
from pycutlass import *
|
||||
from pycutlass.library import SubstituteTemplate
|
||||
import cutlass
|
||||
from cuda import cuda
|
||||
from cuda import nvrtc
|
||||
@ -132,13 +131,15 @@ class ArtifactManager:
|
||||
except:
|
||||
pass
|
||||
|
||||
self.nvcc()
|
||||
self.compiled_cache_device = cutlass.CompileCache()
|
||||
self.compiled_cache_host = cutlass.CompileCache()
|
||||
|
||||
def nvrtc(self):
|
||||
self.backend = "nvrtc"
|
||||
self.default_compile_options = [
|
||||
'-std=c++11', '-default-device',
|
||||
]
|
||||
self.compiled_cache_device = cutlass.CompileCache()
|
||||
self.compiled_cache_host = cutlass.CompileCache()
|
||||
|
||||
def nvcc(self):
|
||||
self.backend = "nvcc"
|
||||
self.default_compile_options = [
|
||||
@ -335,13 +336,14 @@ class ArtifactManager:
|
||||
architectures = []
|
||||
for operation in operations:
|
||||
if hasattr(operation, "tile_description"):
|
||||
cc = operation.tile_description.minimum_compute_capability
|
||||
cc = operation.arch
|
||||
if cc not in architectures:
|
||||
architectures.append(cc)
|
||||
include_paths = [
|
||||
cuda_install_path + '/include',
|
||||
cutlass_path + '/include',
|
||||
cutlass_path + '/tools/util/include',
|
||||
cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include'
|
||||
]
|
||||
compile_options = CompilationOptions(
|
||||
self.default_compile_options, architectures, include_paths)
|
||||
|
||||
@ -48,6 +48,9 @@ class Conv2dArguments(ArgumentBase):
|
||||
:param operation: the Conv2d operation to take the argument
|
||||
:type operation: :class:`pycutlass.Conv2dOperation`
|
||||
|
||||
:param problem_size: the Conv2d problem size
|
||||
:type problem_size: :class:`cutlass.conv.Conv2dProblemSize`
|
||||
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
|
||||
@ -78,6 +81,7 @@ class Conv2dArguments(ArgumentBase):
|
||||
split_k_mode: 'cutlass.conv.SplitKMode'
|
||||
= cutlass.conv.SplitKMode.Serial, **kwargs) -> None:
|
||||
|
||||
self.operation = operation
|
||||
#: convolution kind
|
||||
self.conv_kind: cutlass.conv.Operator = operation.conv_kind
|
||||
self.layout_A: cutlass.layout = operation.A.layout
|
||||
@ -93,15 +97,12 @@ class Conv2dArguments(ArgumentBase):
|
||||
|
||||
super().__init__(A, B, C, D, **kwargs)
|
||||
# preprocessing output ops
|
||||
if "output_op" in kwargs.keys() and \
|
||||
|
||||
if 'output_op' in kwargs.keys() and \
|
||||
split_k_mode != cutlass.conv.SplitKMode.Parallel:
|
||||
self.alpha = kwargs["output_op"].alpha
|
||||
self.beta = kwargs["output_op"].beta
|
||||
self.output_op = kwargs['output_op']
|
||||
else:
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
|
||||
self.element_compute = operation.element_epilogue
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
if "split_k_slices" in kwargs.keys():
|
||||
self.split_k_mode = split_k_mode
|
||||
@ -114,7 +115,12 @@ class Conv2dArguments(ArgumentBase):
|
||||
self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size
|
||||
self.problem_size.split_k_slices = self.split_k_slices
|
||||
|
||||
self.operation = operation
|
||||
if hasattr(self, "tensor_c_numel"):
|
||||
c_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
|
||||
self.conv_kind, problem_size)
|
||||
if (self.tensor_c_numel == c_coord.at(3) and
|
||||
self.tensor_c_numel < c_coord.size()):
|
||||
self.bias = True
|
||||
|
||||
#
|
||||
# initialize the argument
|
||||
@ -159,6 +165,9 @@ class Conv2dArguments(ArgumentBase):
|
||||
self.conv_kind, problem_size)
|
||||
else:
|
||||
raise ValueError("unknown operand: " + operand)
|
||||
# Zero stride trick
|
||||
if operand == "c" and self.bias:
|
||||
tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0)
|
||||
|
||||
layout = tensor_layout.packed(tensor_coord)
|
||||
|
||||
@ -174,24 +183,9 @@ class Conv2dArguments(ArgumentBase):
|
||||
ref_D = TensorRef_(self.get_tensor_ref(
|
||||
self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))
|
||||
|
||||
if self.element_compute == cutlass.float16:
|
||||
alpha = cutlass.float16(self.alpha).storage
|
||||
beta = cutlass.float16(self.beta).storage
|
||||
elif self.element_compute == cutlass.int32:
|
||||
alpha = int(self.alpha)
|
||||
beta = int(self.beta)
|
||||
else:
|
||||
alpha = self.alpha
|
||||
beta = self.beta
|
||||
|
||||
argument_type, epilogue_type = get_conv2d_arguments(
|
||||
self.operation.element_epilogue)
|
||||
|
||||
output_op = epilogue_type(alpha, beta, 0, 0)
|
||||
|
||||
self.c_arguments = argument_type(
|
||||
self.c_arguments = self.operation.argument_type(
|
||||
Conv2DProblemSize(self.problem_size),
|
||||
ref_A, ref_B, ref_C, ref_D, output_op, self.split_k_mode
|
||||
ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode
|
||||
)
|
||||
|
||||
self.semaphore = semaphore
|
||||
@ -296,9 +290,8 @@ extern "C" {
|
||||
|
||||
def __init__(self, operation: 'Conv2dOperation'):
|
||||
super().__init__(operation)
|
||||
|
||||
self.argtype = [ctypes.POINTER(get_conv2d_arguments(
|
||||
operation.element_epilogue)[0]), ctypes.c_void_p]
|
||||
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
|
||||
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
|
||||
self.conv_kind = operation.conv_kind
|
||||
|
||||
self.operation: Conv2dOperation = operation
|
||||
@ -410,9 +403,7 @@ class Conv2dOperation:
|
||||
iterator_algorithm: cutlass.conv.IteratorAlgorithm,
|
||||
arch: int, tile_description: TileDescription,
|
||||
A: TensorDescription, B: TensorDescription, C: TensorDescription,
|
||||
element_epilogue: Union[cutlass.int8, cutlass.int32, cutlass.float16,
|
||||
cutlass.bfloat16, cutlass.float32, cutlass.float64],
|
||||
stride_support, epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support, epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1):
|
||||
|
||||
self.operation_kind: OperationKind = OperationKind.Conv2d
|
||||
@ -422,13 +413,14 @@ class Conv2dOperation:
|
||||
self.A: TensorDescription = A
|
||||
self.B: TensorDescription = B
|
||||
self.C: TensorDescription = C
|
||||
self.element_epilogue = element_epilogue
|
||||
self.epilogue_functor = epilogue_functor
|
||||
self.iterator_algorithm = iterator_algorithm
|
||||
self.stride_support = stride_support
|
||||
self.swizzling_functor = swizzling_functor()
|
||||
|
||||
self.rt_module: Conv2dRT = Conv2dRT(self)
|
||||
self.argument_type = self.rt_module.argument_type
|
||||
self.epilogue_type = self.rt_module.epilogue_type
|
||||
|
||||
def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
|
||||
"""
|
||||
@ -577,12 +569,7 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
||||
${stages},
|
||||
${math_operator},
|
||||
@ -629,8 +616,7 @@ struct ${operation_name}${operation_suffix}:
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': operation.epilogue_functor.emit(),
|
||||
'swizzling_functor': operation.swizzling_functor.tag(),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -116,6 +116,12 @@ class GemmArguments(ArgumentBase):
|
||||
else:
|
||||
self.problem_size = cutlass.gemm.GemmCoord(
|
||||
problem_size.m(), problem_size.n(), problem_size.k())
|
||||
|
||||
# if the number of elements in C = problem_size.n
|
||||
# C is treated as the bias
|
||||
if hasattr(self, "tensor_c_numel"):
|
||||
if (self.tensor_c_numel == self.problem_size.n() and
|
||||
self.problem_size.m() != 1): self.bias = True
|
||||
|
||||
# get the leading dimension
|
||||
self.lda = operation.A.layout.packed(self.problem_size.mk()).stride()
|
||||
@ -123,27 +129,69 @@ class GemmArguments(ArgumentBase):
|
||||
self.ldc = operation.C.layout.packed(self.problem_size.mn()).stride()
|
||||
self.ldd = self.ldc
|
||||
|
||||
# stride 0 trick
|
||||
if self.bias:
|
||||
self.ldc = 0
|
||||
|
||||
if 'output_op' in kwargs.keys() and \
|
||||
gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel:
|
||||
self.alpha = kwargs['output_op'].alpha
|
||||
self.beta = kwargs['output_op'].beta
|
||||
self.output_op = kwargs['output_op']
|
||||
else:
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
# get number of slices on k dimension
|
||||
self.gemm_mode = gemm_mode
|
||||
if 'split_k_slices' in kwargs.keys():
|
||||
self.split_k_slices = kwargs['split_k_slices']
|
||||
else:
|
||||
self.split_k_slices = 1
|
||||
if gemm_mode in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]:
|
||||
if 'split_k_slices' in kwargs.keys():
|
||||
self.batch_count = kwargs['split_k_slices']
|
||||
else:
|
||||
self.batch_count = 1
|
||||
self.split_k_slices = self.batch_count
|
||||
|
||||
self.batch_count = self.split_k_slices
|
||||
if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]:
|
||||
if 'batch' in kwargs.keys():
|
||||
self.batch_count = kwargs['batch']
|
||||
else:
|
||||
self.batch_count = 1
|
||||
|
||||
self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
|
||||
self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
|
||||
self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
|
||||
self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
|
||||
if self.bias:
|
||||
self.batched_stride_C = self.problem_size.n()
|
||||
|
||||
# support GEMM Mode Array
|
||||
if gemm_mode == cutlass.gemm.Mode.Array:
|
||||
self.ptr_A_array = []
|
||||
self.ptr_B_array = []
|
||||
self.ptr_C_array = []
|
||||
self.ptr_D_array = []
|
||||
|
||||
ptr_A_addr = int(self.ptr_A)
|
||||
ptr_B_addr = int(self.ptr_B)
|
||||
ptr_C_addr = int(self.ptr_C)
|
||||
ptr_D_addr = int(self.ptr_D)
|
||||
|
||||
stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
|
||||
stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
|
||||
stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
|
||||
stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
|
||||
for _ in range(self.batch_count):
|
||||
self.ptr_A_array.append(ptr_A_addr)
|
||||
self.ptr_B_array.append(ptr_B_addr)
|
||||
self.ptr_C_array.append(ptr_C_addr)
|
||||
self.ptr_D_array.append(ptr_D_addr)
|
||||
|
||||
ptr_A_addr += stride_A
|
||||
ptr_B_addr += stride_B
|
||||
ptr_C_addr += stride_C
|
||||
ptr_D_addr += stride_D
|
||||
|
||||
self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
|
||||
self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
|
||||
self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
|
||||
self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)
|
||||
|
||||
if isinstance(self.operation, GemmOperationUniversal):
|
||||
self.initialize()
|
||||
@ -195,28 +243,28 @@ class GemmArguments(ArgumentBase):
|
||||
self.grid_tiled_shape.z
|
||||
)
|
||||
)
|
||||
|
||||
argument_type, epilogue_type = get_gemm_arguments(
|
||||
self.operation.element_epilogue)
|
||||
|
||||
if self.operation.element_epilogue == cutlass.float16:
|
||||
self.alpha = cutlass.float16(self.alpha).storage
|
||||
self.beta = cutlass.float16(self.beta).storage
|
||||
elif self.operation.element_epilogue == cutlass.int32:
|
||||
self.alpha = int(self.alpha)
|
||||
self.beta = int(self.beta)
|
||||
|
||||
output_op = epilogue_type(self.alpha, self.beta, 0, 0)
|
||||
|
||||
arguments = argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, output_op,
|
||||
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
|
||||
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
|
||||
self.batched_stride_D,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
)
|
||||
if self.gemm_mode == cutlass.gemm.Mode.Array:
|
||||
arguments = self.operation.argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
|
||||
int(self.ptr_A_array_buffer.ptr),
|
||||
int(self.ptr_B_array_buffer.ptr),
|
||||
int(self.ptr_C_array_buffer.ptr),
|
||||
int(self.ptr_D_array_buffer.ptr),
|
||||
0, 0, 0, 0,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
)
|
||||
else:
|
||||
arguments = self.operation.argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
|
||||
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
|
||||
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
|
||||
self.batched_stride_D,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
)
|
||||
|
||||
self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size
|
||||
|
||||
@ -381,13 +429,12 @@ class GemmGroupedArguments:
|
||||
else:
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
|
||||
if 'output_op' in kwargs.keys():
|
||||
self.output_op = kwargs['output_op']
|
||||
else:
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
if self.operation.element_epilogue == cutlass.float16:
|
||||
self.alpha = cutlass.float16(self.alpha).storage
|
||||
self.beta = cutlass.float16(self.beta).storage
|
||||
elif self.operation.element_epilogue == cutlass.int32:
|
||||
self.alpha = int(self.alpha)
|
||||
self.beta = int(self.beta)
|
||||
|
||||
# get host problem size
|
||||
self.host_problem_size_ptr = np.array(
|
||||
@ -398,12 +445,7 @@ class GemmGroupedArguments:
|
||||
self.initialize()
|
||||
|
||||
def get_arguments(self):
|
||||
|
||||
argument_type, epilogue_type = get_gemm_grouped_arguments(
|
||||
self.operation.element_epilogue)
|
||||
self.output_op = epilogue_type(self.alpha, self.beta, 0, 0)
|
||||
|
||||
return argument_type(
|
||||
return self.operation.argument_type(
|
||||
self.problem_size_buffer.ptr, self.problem_count, self.total_tiles,
|
||||
self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr,
|
||||
self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr,
|
||||
@ -492,16 +534,6 @@ ${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
||||
#: number of threads per threadblock
|
||||
self.threads: int = operation.tile_description.num_threads
|
||||
|
||||
if (operation.epilogue_functor in
|
||||
[
|
||||
EpilogueFunctor.LinearCombination,
|
||||
EpilogueFunctor.FastLinearCombinationClamp,
|
||||
EpilogueFunctor.LinearCombinationClamp
|
||||
]):
|
||||
self.output_op = LinearCombinationFunctor()
|
||||
else:
|
||||
raise ValueError("unknown epilogue functor")
|
||||
|
||||
#
|
||||
def emit(self):
|
||||
return self.emitter.emit(self.operation)
|
||||
@ -568,9 +600,11 @@ extern "C" {
|
||||
def __init__(self, operation: 'GemmOperation'):
|
||||
super(GemmRTUniversal, self).__init__(operation)
|
||||
self.emitter = EmitGemmUniversalInstance(
|
||||
'_type', operation.direct_store)
|
||||
'_type', operation.direct_store, operation.visitor)
|
||||
|
||||
self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
|
||||
self.argtype = [
|
||||
ctypes.POINTER(get_gemm_arguments(operation.element_epilogue)[0]),
|
||||
ctypes.POINTER(self.argument_type),
|
||||
ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
|
||||
]
|
||||
|
||||
@ -673,8 +707,8 @@ class GemmRTGrouped(GemmRTbase):
|
||||
self.extra_funcs = ['precompute']
|
||||
|
||||
self.emitter = EmitGemmGroupedInstance('_type')
|
||||
self.argtype = [ctypes.POINTER(get_gemm_grouped_arguments(
|
||||
operation.element_epilogue)[0]), ctypes.c_int, ctypes.c_void_p]
|
||||
self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
|
||||
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]
|
||||
|
||||
def host_precompute(self, arguments, workspace_bytes):
|
||||
self.precompute.argtype = [
|
||||
@ -717,7 +751,7 @@ class GemmOperationBase:
|
||||
def __init__(
|
||||
self, gemm_kind, arch, tile_description: TileDescription,
|
||||
A: TensorDescription, B: TensorDescription, C: TensorDescription,
|
||||
element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
|
||||
#: operation kind
|
||||
@ -749,7 +783,7 @@ class GemmOperationBase:
|
||||
#: Operand C
|
||||
self.C: TensorDescription = copy.deepcopy(C)
|
||||
self.switched = False
|
||||
self.element_epilogue = element_epilogue
|
||||
|
||||
self.epilogue_functor = epilogue_functor
|
||||
self.swizzling_functor = swizzling_functor()
|
||||
|
||||
@ -757,6 +791,11 @@ class GemmOperationBase:
|
||||
self.direct_store = kwargs["direct_store"]
|
||||
else:
|
||||
self.direct_store = False
|
||||
|
||||
if "visitor" in kwargs:
|
||||
self.visitor = kwargs["visitor"]
|
||||
else:
|
||||
self.visitor = False
|
||||
|
||||
def run(self, arguments: GemmArguments) -> cuda.CUresult:
|
||||
"""
|
||||
@ -895,22 +934,26 @@ class GemmOperationBase:
|
||||
|
||||
|
||||
class GemmOperationUniversal(GemmOperationBase):
|
||||
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
|
||||
epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description,
|
||||
A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs)
|
||||
A, B, C, epilogue_functor, swizzling_functor, **kwargs)
|
||||
self.rt_module = GemmRTUniversal(self)
|
||||
self.argument_type = self.rt_module.argument_type
|
||||
self.epilogue_type = self.rt_module.epilogue_type
|
||||
|
||||
|
||||
class GemmOperationGrouped(GemmOperationBase):
|
||||
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
|
||||
epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
|
||||
A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs)
|
||||
A, B, C, epilogue_functor, swizzling_functor, **kwargs)
|
||||
assert "precompute_mode" in kwargs.keys(
|
||||
), "missing keyword arguement 'precompute_mode'."
|
||||
self.precompute_mode = kwargs["precompute_mode"]
|
||||
self.rt_module = GemmRTGrouped(self)
|
||||
self.argument_type = self.rt_module.argument_type
|
||||
self.epilogue_type = self.rt_module.epilogue_type
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@ -918,228 +961,14 @@ class GemmOperationGrouped(GemmOperationBase):
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
|
||||
|
||||
class EmitGemmInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self, operation_suffix=''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
|
||||
${element_a}, ${layout_a},
|
||||
${element_b}, ${layout_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${align_a},
|
||||
${align_b},
|
||||
false,
|
||||
${math_operation}
|
||||
${residual}
|
||||
>;
|
||||
"""
|
||||
self.gemm_complex_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
|
||||
${element_a}, ${layout_a},
|
||||
${element_b}, ${layout_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${transform_a},
|
||||
${transform_b},
|
||||
${math_operation}
|
||||
${residual}
|
||||
>;
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] //
|
||||
operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
|
||||
epilogue_vector_length = int(min(
|
||||
operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
residual = ''
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
'layout_a': LayoutTag[operation.A.layout],
|
||||
'element_b': DataTypeTag[operation.B.element],
|
||||
'layout_b': LayoutTag[operation.B.layout],
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[operation.C.layout],
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
||||
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
||||
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
||||
'warp_shape_m': str(warp_shape[0]),
|
||||
'warp_shape_n': str(warp_shape[1]),
|
||||
'warp_shape_k': str(warp_shape[2]),
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'swizzling_functor': operation.swizzling_functor.tag(),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
||||
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
||||
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
||||
'residual': residual
|
||||
}
|
||||
|
||||
template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
|
||||
|
||||
return SubstituteTemplate(template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
class EmitSparseGemmInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self, operation_suffix=''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
|
||||
${element_a}, ${layout_a},
|
||||
${element_b}, ${layout_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${align_a},
|
||||
${align_b},
|
||||
false,
|
||||
${math_operation}
|
||||
${residual}
|
||||
>;
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] //
|
||||
operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
|
||||
epilogue_vector_length = int(min(
|
||||
operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
residual = ''
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
'layout_a': LayoutTag[operation.A.layout],
|
||||
'element_b': DataTypeTag[operation.B.element],
|
||||
'layout_b': LayoutTag[operation.B.layout],
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[operation.C.layout],
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
||||
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
||||
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
||||
'warp_shape_m': str(warp_shape[0]),
|
||||
'warp_shape_n': str(warp_shape[1]),
|
||||
'warp_shape_k': str(warp_shape[2]),
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'swizzling_functor': operation.swizzling_functor.tag(),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
||||
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
||||
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
||||
'residual': residual
|
||||
}
|
||||
|
||||
template = self.gemm_template
|
||||
|
||||
return SubstituteTemplate(template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
#
|
||||
class EmitGemmUniversalInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self, operation_suffix='', direct_store=False):
|
||||
def __init__(self, operation_suffix='', direct_store=False, visitor=False):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.direct_store = direct_store
|
||||
self.visitor = visitor
|
||||
self.includes = [
|
||||
"cutlass/cutlass.h",
|
||||
"cutlass/numeric_types.h",
|
||||
@ -1150,46 +979,15 @@ class EmitGemmUniversalInstance:
|
||||
"cutlass/gemm/device/gemm_universal_adapter.h",
|
||||
"cutlass/gemm/kernel/default_gemm_universal.h",
|
||||
]
|
||||
if self.visitor:
|
||||
self.includes += [
|
||||
"gemm/gemm_universal_with_visitor.h",
|
||||
"epilogue/epilogue_visitor_with_layernorm.h",
|
||||
"epilogue/epilogue_visitor_generic.h"
|
||||
]
|
||||
if self.direct_store:
|
||||
self.includes.append(
|
||||
"cutlass/epilogue/threadblock/default_epilogue_direct_store.h")
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
self.builtin_epilogue_functor_template_clamp = """
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length}
|
||||
>
|
||||
"""
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operation}
|
||||
>::GemmKernel;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
self.gemm_template_interleaved = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
@ -1241,6 +1039,42 @@ using ${operation_name}_base =
|
||||
${operation_name}_default::ThreadblockSwizzle
|
||||
>;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
self.gemm_template_visitor = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_default =
|
||||
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${elementwise_epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operation}
|
||||
>::GemmKernel;
|
||||
|
||||
${epilogue_visitor}
|
||||
|
||||
using ${operation_name}_Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
|
||||
${operation_name}_EpilogueVisitor,
|
||||
typename ${operation_name}_default::Epilogue>::Epilogue;
|
||||
|
||||
using ${operation_name}_base =
|
||||
cutlass::gemm::kernel::GemmUniversalwithEpilogueVisitor<
|
||||
${operation_name}_default::Mma,
|
||||
${operation_name}_Epilogue,
|
||||
${operation_name}_default::ThreadblockSwizzle
|
||||
>;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
@ -1284,32 +1118,12 @@ ${compile_guard_end}
|
||||
(operation.A.layout, operation.B.layout, operation.C.layout)
|
||||
if self.direct_store:
|
||||
gemm_template = self.gemm_template_direct_store
|
||||
elif self.visitor:
|
||||
gemm_template = self.gemm_template_visitor
|
||||
else:
|
||||
gemm_template = self.gemm_template_interleaved
|
||||
#
|
||||
|
||||
# Support built-in epilogue functors or user-defined functions
|
||||
if isinstance(operation.epilogue_functor, enum.Enum):
|
||||
|
||||
epilogue_vector_length = \
|
||||
min(operation.C.alignment *
|
||||
DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
|
||||
|
||||
values = {
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
}
|
||||
if operation.epilogue_functor == EpilogueFunctor.FastLinearCombinationClamp:
|
||||
epilogue_functor = SubstituteTemplate(
|
||||
self.builtin_epilogue_functor_template_clamp, values)
|
||||
else:
|
||||
epilogue_functor = SubstituteTemplate(
|
||||
self.builtin_epilogue_functor_template, values)
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
#
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
@ -1331,7 +1145,6 @@ ${compile_guard_end}
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_functor': epilogue_functor,
|
||||
'swizzling_functor': operation.swizzling_functor.tag(),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
@ -1341,6 +1154,12 @@ ${compile_guard_end}
|
||||
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
|
||||
}
|
||||
|
||||
if self.visitor:
|
||||
values['epilogue_visitor'] = operation.epilogue_functor.emit(operation)
|
||||
values['elementwise_epilogue_functor'] = operation.epilogue_functor.elementwise_functor.emit()
|
||||
else:
|
||||
values['epilogue_functor'] = operation.epilogue_functor.emit()
|
||||
|
||||
return SubstituteTemplate(gemm_template, values)
|
||||
|
||||
###################################################################################################
|
||||
@ -1348,185 +1167,6 @@ ${compile_guard_end}
|
||||
#
|
||||
|
||||
|
||||
class EmitGemmPlanarComplexInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self, operation_suffix=''):
|
||||
self.operation_suffix = operation_suffix
|
||||
self.includes = []
|
||||
self.template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
|
||||
${element_c}, cutlass::layout::RowMajor,
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
|
||||
${element_c},
|
||||
${alignment_c},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
${stages},
|
||||
${math_operator}
|
||||
>::GemmKernel;
|
||||
|
||||
struct ${operation_name} :
|
||||
public Operation_${operation_name} { };
|
||||
"""
|
||||
|
||||
#
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
#
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] //
|
||||
operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
|
||||
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
|
||||
transposed_layout_A = TransposedLayout[operation.A.layout]
|
||||
transposed_layout_B = TransposedLayout[operation.B.layout]
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.B.element],
|
||||
'layout_a': LayoutTag[transposed_layout_B],
|
||||
'transform_a': ComplexTransformTag[operation.B.complex_transform],
|
||||
'alignment_a': str(operation.B.alignment),
|
||||
'element_b': DataTypeTag[operation.A.element],
|
||||
'layout_b': LayoutTag[transposed_layout_A],
|
||||
'transform_b': ComplexTransformTag[operation.A.complex_transform],
|
||||
'alignment_b': str(operation.A.alignment),
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[operation.C.layout],
|
||||
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
||||
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
||||
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
||||
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
||||
'warp_shape_m': str(warp_shape[0]),
|
||||
'warp_shape_n': str(warp_shape[1]),
|
||||
'warp_shape_k': str(warp_shape[2]),
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'alignment_c': str(operation.C.alignment),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'math_operator': 'cutlass::arch::OpMultiplyAdd'
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
|
||||
|
||||
class EmitGemmPlanarComplexArrayInstance:
    """Emits the C++ definition of a planar-complex array GEMM kernel.

    ``emit`` fills ``self.template`` with values read from an operation
    description and returns the resulting C++ source text.
    """

    def __init__(self, operation_suffix=''):
        self.operation_suffix = operation_suffix
        self.includes = []
        # The C layout is hard-coded to row-major in this template; emit()
        # compensates by swapping the A and B operands.
        self.template = """
  // Gemm operator ${operation_name}
  using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
    ${element_c}, cutlass::layout::RowMajor,
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    cutlass::epilogue::thread::LinearCombinationPlanarComplex<
      ${element_c},
      ${alignment_c},
      ${element_accumulator},
      ${element_epilogue}
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    ${stages},
    ${math_operator}
  >::GemmArrayKernel;

  struct ${operation_name} : public Operation_${operation_name} { };
"""

    #
    def instance_template(self):
        """Return the template text that appends this kernel to a manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):
        """Return ``self.template`` with its placeholders filled from ``operation``."""
        tile = operation.tile_description
        math_inst = tile.math_instruction

        # Per-axis warp tile: threadblock extent divided by the warp count.
        warp_shape = []
        for axis in range(3):
            warp_shape.append(tile.threadblock_shape[axis] // tile.warp_count[axis])

        # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
        transposed_layout_A = TransposedLayout[operation.A.layout]
        transposed_layout_B = TransposedLayout[operation.B.layout]

        # Template slot "a" is filled from operand B and slot "b" from operand A.
        slot_a = operation.B
        slot_b = operation.A

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[slot_a.element],
            'layout_a': LayoutTag[transposed_layout_B],
            'transform_a': ComplexTransformTag[slot_a.complex_transform],
            'alignment_a': str(slot_a.alignment),
            'element_b': DataTypeTag[slot_b.element],
            'layout_b': LayoutTag[transposed_layout_A],
            'transform_b': ComplexTransformTag[slot_b.complex_transform],
            'alignment_b': str(slot_b.alignment),
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[math_inst.element_accumulator],
            'opcode_class': OpcodeClassTag[math_inst.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tile.threadblock_shape[0]),
            'threadblock_shape_n': str(tile.threadblock_shape[1]),
            'threadblock_shape_k': str(tile.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_inst.instruction_shape[0]),
            'instruction_shape_n': str(math_inst.instruction_shape[1]),
            'instruction_shape_k': str(math_inst.instruction_shape[2]),
            'alignment_c': str(operation.C.alignment),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'stages': str(tile.stages),
            'math_operator': 'cutlass::arch::OpMultiplyAdd'
        }

        return SubstituteTemplate(self.template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
|
||||
|
||||
class EmitGemmGroupedInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
@ -1541,14 +1181,6 @@ class EmitGemmGroupedInstance:
|
||||
"cutlass/gemm/kernel/gemm_grouped.h",
|
||||
"cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
@ -1598,23 +1230,8 @@ ${compile_guard_end}
|
||||
#
|
||||
|
||||
# Support built-in epilogue functors or user-defined functions
|
||||
if isinstance(operation.epilogue_functor, enum.Enum):
|
||||
|
||||
epilogue_vector_length = \
|
||||
min(operation.C.alignment *
|
||||
DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
|
||||
|
||||
values = {
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(
|
||||
self.builtin_epilogue_functor_template, values)
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
#
|
||||
|
||||
epilogue_functor = operation.epilogue_functor.emit()
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
|
||||
@ -478,27 +478,6 @@ SharedMemPerCC = {
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in ``template`` using ``values``.

    Substitution repeats until a full pass over ``values`` changes nothing,
    so a substituted value may itself contain placeholders that are expanded
    on a later pass.

    Args:
        template: text containing ``${key}`` placeholders.
        values: mapping from placeholder name to its replacement string.

    Returns:
        The text with every known placeholder expanded. Placeholders whose
        key is not in ``values`` are left as they are.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Literal replacement: neither the key nor the value is
            # interpreted as a regular expression, so backslashes and regex
            # metacharacters pass through unchanged.
            newtext = text.replace("${%s}" % key, value)
            if newtext != text:
                changed = True
            text = newtext
    return text
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
|
||||
|
||||
class GemmKind(enum.Enum):
|
||||
Gemm = enum_auto()
|
||||
Sparse = enum_auto()
|
||||
@ -557,22 +536,6 @@ SymmKindNames = {
|
||||
#
|
||||
|
||||
|
||||
class EpilogueFunctor(enum.Enum):
    """Epilogue functors that can be selected by name.

    Each member is a key of ``EpilogueFunctorTag``, which holds the
    corresponding C++ type name.
    """

    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()
    FastLinearCombinationClamp = enum_auto()
|
||||
|
||||
|
||||
#
|
||||
# C++ type name for each epilogue functor. Every tag is the member's own name
# qualified with the cutlass::epilogue::thread namespace.
EpilogueFunctorTag = {
    functor: 'cutlass::epilogue::thread::' + functor.name
    for functor in (
        EpilogueFunctor.LinearCombination,
        EpilogueFunctor.LinearCombinationClamp,
        EpilogueFunctor.FastLinearCombinationClamp,
    )
}
|
||||
|
||||
#
|
||||
|
||||
|
||||
class SwizzlingFunctor(enum.Enum):
|
||||
Identity1 = enum_auto()
|
||||
Identity2 = enum_auto()
|
||||
@ -700,7 +663,7 @@ class MathInstruction:
|
||||
|
||||
class TileDescription:
|
||||
|
||||
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):
|
||||
def __init__(self, threadblock_shape, stages, warp_count, math_instruction):
|
||||
self.threadblock_shape = threadblock_shape
|
||||
|
||||
#: number of pipeline stages
|
||||
@ -710,11 +673,6 @@ class TileDescription:
|
||||
self.warp_count: list[int] = warp_count
|
||||
self.math_instruction = math_instruction
|
||||
|
||||
#: minimum compute capability
|
||||
self.minimum_compute_capability: int = min_compute
|
||||
#: maximum compute capability
|
||||
self.maximum_compute_capability: int = max_compute
|
||||
|
||||
#: number threads per threadblock
|
||||
self.num_threads: int = 32
|
||||
for cnt in self.warp_count:
|
||||
|
||||
619
tools/library/scripts/pycutlass/src/pycutlass/parser.py
Normal file
619
tools/library/scripts/pycutlass/src/pycutlass/parser.py
Normal file
@ -0,0 +1,619 @@
|
||||
################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
from treelib import Tree
|
||||
import numpy as np
|
||||
|
||||
from pycutlass import *
|
||||
import pycutlass
|
||||
|
||||
import ast
|
||||
import textwrap
|
||||
import inspect
|
||||
|
||||
################################################################################
|
||||
# Type annotation for input arguments
|
||||
################################################################################
|
||||
|
||||
# Type parameters accepted by the NDArray annotation below.
Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")


class NDArray(np.ndarray, Generic[Ttype, Dtype]):
    """``numpy.ndarray`` subclass that can be subscripted with two type parameters.

    It adds no behaviour of its own; it exists so that arguments can be
    annotated as ``NDArray[..., ...]``.
    """
|
||||
|
||||
################################################################################
|
||||
# Operations
|
||||
################################################################################
|
||||
|
||||
# Operator name for each supported Python AST operator class. BinOpNode looks
# the name up by the type of the AST node's ``op`` field.
operators = dict([
    (ast.Add, "Add"),
    (ast.Div, "Div"),
    (ast.Eq, "Equal"),
    (ast.Mult, "Mult"),
])
|
||||
|
||||
################################################################################
|
||||
# AST Node abstractions
|
||||
################################################################################
|
||||
class UnaryNode:
    """Epilogue-tree payload for an operation that is emitted as a ``UnaryOp``.

    The operation name is taken either from an existing ``BinOpNode`` or from
    an ``ast.Call`` node. ``args`` holds the extra operands of the operation:
    each entry is either a key into the runtime keyword arguments or a literal
    value.
    """

    # Number of UnaryNode instances created so far; used to make ids unique.
    cnt = 0

    # Concept: this is created by the BinOp Node in python ast
    def __init__(self,
                 element_accumulator, element_compute, elements_per_access,
                 node, args) -> None:
        """Build the node.

        Args:
            element_accumulator: accumulator data type, stored as given.
            element_compute: compute data type; also passed to the epilogue
                op constructor looked up on ``pycutlass``.
            elements_per_access: stored as given.
            node: a ``BinOpNode`` or an ``ast.Call``.
            args: list of extra operands (names or constants).

        Raises:
            TypeError: if ``node`` is neither a ``BinOpNode`` nor an
                ``ast.Call`` whose callee is a name or an attribute access.
        """
        if isinstance(node, BinOpNode):
            self.op = node.op
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                self.op = node.func.id
            elif isinstance(node.func, ast.Attribute):
                # For ``x.method(...)`` the operation name is ``x``.
                self.op = node.func.value.id
            else:
                raise TypeError(
                    "unsupported callee expression: %s" % type(node.func).__name__)
        else:
            raise TypeError(
                "expected BinOpNode or ast.Call, got %s" % type(node).__name__)
        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
        self.id = self.op + str(UnaryNode.cnt)
        self.args = args
        UnaryNode.cnt += 1

        self.type = "tensor"

        # The epilogue op class is looked up on the pycutlass package by name.
        self.epilogue_op = getattr(pycutlass, self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Create ``self.epilogue_node`` from the child visitors."""
        self.epilogue_node = UnaryOp(
            self.element_accumulator, self.element_compute,
            self.elements_per_access, *visitors, self.epilogue_op)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from the child arguments and ``kwargs``.

        Every entry of ``self.args`` that is a key of ``kwargs`` is replaced
        by the mapped value; any other entry is passed through unchanged.
        """
        epilogue_ops = []
        for arg in self.args:
            try:
                epilogue_ops.append(kwargs[arg])
            except (KeyError, TypeError):
                # KeyError: not a runtime argument; TypeError: unhashable
                # literal. Either way the entry is used directly.
                epilogue_ops.append(arg)  # direct arguments like constant
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(*epilogue_ops), *visitor_args)
|
||||
|
||||
|
||||
class BinOpNode:
    """Epilogue-tree payload created for a Python ``ast.BinOp`` node."""

    # Number of BinOpNode instances created so far; used to make ids unique.
    cnt = 0

    # Concept: this is created by the BinOp Node in python ast
    def __init__(self,
                 element_accumulator, element_compute, elements_per_access,
                 node) -> None:
        op_name = operators[type(node.op)]
        serial = str(BinOpNode.cnt)
        BinOpNode.cnt += 1

        self.op = op_name
        self.tag = "Binary" + op_name + serial
        self.id = op_name + serial
        self.args = None

        self.type = "tensor"

        # The epilogue op class is ``pycutlass.Vector<OpName>``.
        op_class = getattr(pycutlass, "Vector" + op_name)
        self.epilogue_op = op_class(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Create ``self.epilogue_node`` from the child visitors."""
        self.epilogue_node = BinaryOp(
            self.element_accumulator, self.element_compute,
            self.elements_per_access, *visitors, self.epilogue_op)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from the child arguments; ``kwargs`` is unused."""
        op_argument = self.epilogue_op.argument_type(self.args)
        self.argument = self.epilogue_node.argument_type(op_argument, *visitor_args)
|
||||
|
||||
|
||||
class NameNode:
    """Base payload for epilogue-tree nodes that are identified by a name.

    ``id`` and ``tag`` are both set to the name; subclasses overwrite ``tag``.
    """

    # Concept: this is created by the Name Node in python ast
    def __init__(self, node) -> None:
        """Take the name from ``node``.

        ``node`` is either an object with an ``id`` attribute (such as
        ``ast.Name``) or an assignment node, in which case the name of its
        first target is used.
        """
        try:
            self.id = node.id
        except AttributeError:
            # No ``id`` attribute: treat the node as an assignment.
            self.id = node.targets[0].id
        self.tag = self.id
|
||||
|
||||
class ScalarInputNode(NameNode):
    """Named node whose ``type`` is ``"scalar"``."""

    # Concept: scalar
    def __init__(self, node) -> None:
        super().__init__(node)
        self.type = "scalar"
        self.tag = f"Scalar:{self.tag}"
|
||||
|
||||
class AccumulatorNode(NameNode):
    """Named node that is emitted as an ``AccumulatorOp``."""

    # Concept: VisitorOpAccumulator
    def __init__(self,
                 element_accumulator, elements_per_access, node) -> None:
        super().__init__(node)
        self.type = "tensor"
        self.tag = f"Accum:{self.tag}"

        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Create ``self.epilogue_node``; ``visitors`` is not used."""
        self.epilogue_node = AccumulatorOp(
            self.element_accumulator, self.elements_per_access)

    def get_argument(self, visitor_args, kwargs):
        """Create ``self.argument``; the argument type is constructed without inputs."""
        self.argument = self.epilogue_node.argument_type()
|
||||
|
||||
class TensorInputNode(NameNode):
    """Named node that is emitted as a ``TensorInputOp``."""

    # Concept: VisitorOpTensorInput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.type = "tensor"
        self.tag = f"TensorInput:{self.tag}"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, *args):
        """Create ``self.epilogue_node``; positional arguments are ignored."""
        self.epilogue_node = TensorInputOp(self.element_accumulator)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from ``kwargs["<id>_ptr"]`` and ``kwargs["problem_size"]``."""
        make_argument = self.epilogue_node.argument_type
        ptr = kwargs[self.id + "_ptr"]
        problem_size = kwargs["problem_size"]
        # Second and third arguments: problem_size[1], and the product of
        # the first two problem-size entries.
        self.argument = make_argument(
            ptr, problem_size[1], problem_size[0] * problem_size[1])
|
||||
|
||||
class RowBroadcastNode(NameNode):
    """Named node that is emitted as a ``RowBroadcastOp``."""

    # Concept: VisitorOpRowBroadcast
    def __init__(self, element_accumulator, element_fragment, node) -> None:
        super().__init__(node)
        #
        self.type = "tensor"
        self.tag = f"RowBroadcast:{self.tag}"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        """Create ``self.epilogue_node``; positional arguments are ignored."""
        self.epilogue_node = RowBroadcastOp(
            self.element_accumulator, self.element_fragment)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from ``kwargs["<id>_ptr"]`` and ``problem_size[1]``."""
        make_argument = self.epilogue_node.argument_type
        ptr = kwargs[self.id + "_ptr"]
        self.argument = make_argument(ptr, kwargs["problem_size"][1])
|
||||
|
||||
class ColumnBroadcastNode(NameNode):
    """Named node that is emitted as a ``ColumnBroadcastOp``."""

    # Concept: VisitorOpColumnBroadcast
    def __init__(self, element_accumulator, element_fragment, node) -> None:
        super().__init__(node)
        self.type = "tensor"
        self.tag = f"ColumnBroadcast:{self.tag}"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        """Create ``self.epilogue_node``; positional arguments are ignored."""
        self.epilogue_node = ColumnBroadcastOp(
            self.element_accumulator, self.element_fragment)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from ``kwargs["<id>_ptr"]`` and ``problem_size[0]``."""
        make_argument = self.epilogue_node.argument_type
        ptr = kwargs[self.id + "_ptr"]
        self.argument = make_argument(ptr, kwargs["problem_size"][0])
|
||||
|
||||
class TensorOutputNode(NameNode):
    """Named node that is emitted as a ``TensorOutputOp``."""

    # Concept: VisitorOpTensorOutput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.type = "tensor"
        self.tag = f"TensorOutput:{self.tag}"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, visitors):
        """Create ``self.epilogue_node`` from the child visitors."""
        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from the output pointer, the problem size and the child arguments."""
        make_argument = self.epilogue_node.argument_type
        ptr = kwargs[self.id + "_ptr"]
        problem_size = kwargs["problem_size"]
        # Child arguments sit between problem_size[1] and the trailing
        # product of the first two problem-size entries.
        self.argument = make_argument(
            ptr, problem_size[1], *visitor_args,
            problem_size[0] * problem_size[1])
|
||||
|
||||
class RowReductionNode:
    """Epilogue-tree payload that is emitted as a ``RowReductionOp``."""

    # Concept: RowReductionOp
    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, id, factor) -> None:
        #
        self.id = id
        self.type = "tensor"
        self.tag = f"RowReduction:{id}"
        self.factor = factor
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator

    def get_epilogue_node(self, visitors):
        """Create ``self.epilogue_node`` from the child visitors."""
        self.epilogue_node = RowReductionOp(
            self.element_accumulator, self.element_reduction,
            self.element_reduction_accumulator, *visitors)

    def get_batch_stride(self, problem_size):
        """Return ``problem_size[0]`` times ``problem_size[1] / factor`` rounded up."""
        blocks = (problem_size[1] + self.factor - 1) // self.factor
        return problem_size[0] * blocks

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from the output pointer, child arguments and batch stride."""
        make_argument = self.epilogue_node.argument_type
        ptr = kwargs[self.id + "_ptr"]
        batch_stride = self.get_batch_stride(kwargs["problem_size"])
        self.argument = make_argument(ptr, *visitor_args, batch_stride)
|
||||
|
||||
class ColumnReductionNode:
    """Epilogue-tree payload that is emitted as a ``ColumnReductionOp``."""

    # Concept: ColumnReductionOp
    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, id, factor) -> None:
        #
        self.id = id
        self.type = "tensor"
        self.tag = f"ColumnReduction:{id}"
        self.factor = factor
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator

    def get_epilogue_node(self, visitors):
        """Create ``self.epilogue_node`` from the child visitors."""
        self.epilogue_node = ColumnReductionOp(
            self.element_accumulator, self.element_reduction,
            self.element_reduction_accumulator, *visitors)

    def get_batch_stride(self, problem_size):
        """Return ``problem_size[1]`` times ``problem_size[0] / factor`` rounded up."""
        blocks = (problem_size[0] + self.factor - 1) // self.factor
        return problem_size[1] * blocks

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` from the output pointer, child arguments and batch stride."""
        make_argument = self.epilogue_node.argument_type
        ptr = kwargs[self.id + '_ptr']
        batch_stride = self.get_batch_stride(kwargs["problem_size"])
        self.argument = make_argument(ptr, *visitor_args, batch_stride)
|
||||
|
||||
################################################################################
|
||||
# Epilogue parser function
|
||||
################################################################################
|
||||
class EpilogueAST(ast.NodeVisitor):
    """Parses an epilogue's ``__call__`` source into a tree of epilogue nodes.

    The constructor reads the source of ``epilogue.__call__``, parses it with
    ``ast`` and visits it, filling ``self.epilogue_tree``. The ``pass_*``
    methods then rewrite that tree, and ``get_arguments`` builds the runtime
    argument structures bottom-up.
    """

    def __init__(self, epilogue,
                 tile_description,
                 element_accumulator, elements_per_access,
                 element_compute, element_output) -> None:
        #

        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        self.epilogue = epilogue

        self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
        self.ast_tree = ast.parse(self.source)
        self.epilogue_tree = Tree()

        # print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose

        # input arguments: name -> [annotation string]
        self.input_args = {}
        # return nodes
        self.returns = []
        # reduction source nodes: source name -> [arg 1, arg 2, target name]
        self.reduction_source = {}

        # stack used to keep the parent node id
        self.stack = []

        # visit the AST
        self.visit(self.ast_tree)

    @staticmethod
    def _callee_name(call):
        """Return the name a call refers to: ``f`` for ``f(...)``, ``x`` for ``x.m(...)``."""
        func = call.func
        if isinstance(func, ast.Name):
            return func.id
        if isinstance(func, ast.Attribute):
            return func.value.id
        raise TypeError(
            "unsupported callee expression: %s" % type(func).__name__)

    # visit the name node
    def visit_Name(self, node):
        """Record a returned name, or add a tree node for a referenced name."""
        # append the return ids into self.returns
        if self.stack[-1] == "return":
            self.returns.append(node.id)
            return

        if node.id == "accum":
            # accum is produced from accumulator node
            name_node = AccumulatorNode(
                self.element_accumulator, self.elements_per_access, node)
        elif node.id in self.input_args:
            # for input nodes: the annotation string selects the node class
            arg_kind = self.input_args[node.id][0]
            if arg_kind == "tensor":
                name_node = TensorInputNode(self.element_accumulator, node)
            elif arg_kind == "row":
                name_node = RowBroadcastNode(
                    self.element_accumulator, self.element_compute, node)
            elif arg_kind == "column":
                name_node = ColumnBroadcastNode(
                    self.element_accumulator, self.element_compute, node)
            elif arg_kind == "scalar":
                name_node = ScalarInputNode(node)
            else:
                raise ValueError(arg_kind)
        else:
            # for output nodes
            name_node = TensorOutputNode(self.element_accumulator, node)
        self.epilogue_tree.create_node(
            name_node.tag, name_node.id, data=name_node, parent=self.stack[-1])

    def visit_Assign(self, node):
        """Attach the assigned expression at the place its target occupies in the tree."""
        target_id = node.targets[0].id
        pre_assign_node = self.epilogue_tree.get_node(target_id)
        if pre_assign_node is None:
            # The assign is to a root node
            # skip the reduction nodes
            if isinstance(node.value, ast.Call):
                if self._callee_name(node.value) == 'reduction_op':
                    call_args = node.value.args
                    self.reduction_source[call_args[0].id] = [
                        call_args[1].value, call_args[2].value, target_id]
                    return
            name_node = TensorOutputNode(self.element_accumulator, node)
            self.epilogue_tree.create_node(
                name_node.tag, name_node.id, data=name_node)
            self.stack.append(name_node.id)
        elif target_id in self.returns or target_id in self.reduction_source:
            # The target stays in the tree and becomes the parent.
            self.stack.append(target_id)
        else:
            # The target is replaced by its defining expression, which is
            # attached to the target's former parent.
            self.stack.append(
                pre_assign_node.predecessor(self.epilogue_tree.identifier))
            self.epilogue_tree.remove_node(target_id)

        # get child tag
        self.visit(node.value)
        self.stack.pop()

    def visit_Call(self, node):
        """Add a ``UnaryNode`` for a call; ``reduction_op`` calls only forward their first argument."""
        if self._callee_name(node) == "reduction_op":
            self.visit(node.args[0])
            return

        # Arguments after the first become operands of the unary op.
        arg_list = []
        for arg in node.args[1:]:
            if isinstance(arg, ast.Constant):
                arg_list.append(arg.value)
            elif isinstance(arg, ast.Name):
                arg_list.append(arg.id)
            else:
                raise TypeError(
                    "unsupported call argument: %s" % type(arg).__name__)

        unary_node = UnaryNode(
            self.element_accumulator, self.element_compute,
            self.elements_per_access, node, arg_list)
        self.epilogue_tree.create_node(
            unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node)
        self.stack.append(unary_node.id)
        self.visit(node.args[0])
        self.stack.pop()

    def visit_BinOp(self, node):
        """Add a ``BinOpNode`` and visit its left and right operands beneath it."""
        binop = BinOpNode(self.element_accumulator, self.element_compute,
                          self.elements_per_access, node)
        self.epilogue_tree.create_node(
            binop.tag, binop.id, data=binop, parent=self.stack[-1])
        self.stack.append(binop.id)
        self.visit(node.left)
        self.visit(node.right)
        self.stack.pop()

    def visit_Return(self, node):
        """Visit the returned expression with ``"return"`` on top of the stack."""
        self.stack.append("return")
        self.visit(node.value)
        self.stack.pop()

    # # A function definition
    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Record annotated arguments, then visit the body from last statement to first."""
        # visit args: only arguments with a constant annotation are recorded
        for arg in node.args.args:
            if arg.arg == "self":
                continue
            if isinstance(arg.annotation, ast.Constant):
                self.input_args[arg.arg] = [arg.annotation.value, ]
        # visit the assign in the reverse order
        for statement in reversed(node.body):
            self.visit(statement)

    #
    # Tree optimization pass
    #

    # pass 1: lower Binary to Unary
    def pass_binary_2_unary(self, tree, nid):
        """Replace binary nodes that have one scalar operand by unary nodes.

        The scalar child is removed from the tree and its id becomes the
        single entry of the new ``UnaryNode``'s ``args``.
        NOTE(review): both operand orders produce the same unary node, which
        looks wrong for a non-commutative operator such as Div - confirm.
        """
        node = tree.get_node(nid)
        if not isinstance(node.data, BinOpNode):
            for child in node.successors(tree.identifier):
                self.pass_binary_2_unary(tree, child)
            return

        children = node.successors(tree.identifier)
        lhs_node = tree.get_node(children[0])
        left_type = lhs_node.data.type
        rhs_node = tree.get_node(children[1])
        right_type = rhs_node.data.type

        if left_type == "scalar" and right_type == "tensor":
            scalar_node, tensor_node = lhs_node, rhs_node
        elif left_type == "tensor" and right_type == "scalar":
            scalar_node, tensor_node = rhs_node, lhs_node
        else:
            self.pass_binary_2_unary(tree, lhs_node.data.id)
            self.pass_binary_2_unary(tree, rhs_node.data.id)
            return

        # The scalar's name is read from the payload (``data.id``) for both
        # operand orders.
        node.data = UnaryNode(
            self.element_accumulator, self.element_compute,
            self.elements_per_access,
            node.data, [scalar_node.data.id, ])
        node.tag = node.data.tag
        tree.remove_node(scalar_node.data.id)
        self.pass_binary_2_unary(tree, tensor_node.data.id)

    # pass 2: inject reduction nodes
    def pass_inject_reduction(self, tree, nid):
        """Insert reduction nodes at the output nodes listed in ``self.reduction_source``."""
        node = tree.get_node(nid)
        is_reduction_source = (
            isinstance(node.data, TensorOutputNode)
            and node.data.id in self.reduction_source)
        if not is_reduction_source:
            for child in node.successors(tree.identifier):
                self.pass_inject_reduction(tree, child)
            return

        config = self.reduction_source[node.data.id]
        direction = config[0]
        target = config[-1]
        if direction == 'row':
            reduction_node = RowReductionNode(
                self.element_accumulator, self.element_output,
                self.element_accumulator, target,
                self.tile_description.threadblock_shape[1])
        elif direction == "column":
            reduction_node = ColumnReductionNode(
                self.element_accumulator, self.element_output,
                self.element_accumulator, target,
                self.tile_description.threadblock_shape[0])
        else:
            raise ValueError(direction)

        child_nid = node.successors(tree.identifier)[0]
        if node.data.id not in self.returns:
            # if this output node is injected only for reduction, the
            # reduction node takes its place
            node.data = reduction_node
            node.tag = reduction_node.tag
            self.pass_inject_reduction(tree, child_nid)
        else:
            # if this output node is also a tensor output, inject reduction as its children
            tree.create_node(
                reduction_node.tag, reduction_node.id,
                data=reduction_node, parent=node.data.id)
            tree.move_node(child_nid, reduction_node.id)
            child = tree.get_node(child_nid)
            for grand_child in child.successors(tree.identifier):
                self.pass_inject_reduction(tree, grand_child)

    def pass_inject_epilogue_op(self, tree, nid):
        """Create the epilogue node of ``nid`` after its children's; return it."""
        node = tree.get_node(nid)
        visitors = [
            self.pass_inject_epilogue_op(tree, child)
            for child in node.successors(tree.identifier)
        ]

        node.data.get_epilogue_node(visitors)
        return node.data.epilogue_node

    def get_arguments(self, tree, nid, kwargs):
        """Build the argument of ``nid`` after its children's; return it."""
        node = tree.get_node(nid)
        visitor_args = [
            self.get_arguments(tree, child, kwargs)
            for child in node.successors(tree.identifier)
        ]

        node.data.get_argument(visitor_args, kwargs)
        return node.data.argument
|
||||
|
||||
class EpilogueVisitTree:
    """Builds the visitor tree, argument type and C++ text of an epilogue.

    ``initialize`` parses this object's ``__call__`` with ``EpilogueAST``,
    runs the tree passes and publishes the generated ctypes argument class as
    ``self.epilogue_type``.
    """

    KernelTemplate = """
${visitor}

using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""

    def __init__(self, elementwise_functor, tile_description,
                 element_accumulator, elements_per_access,
                 element_compute, element_output) -> None:
        #
        # data types
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        # TODO: deprecate this
        self.elementwise_functor = elementwise_functor

    def initialize(self):
        """Parse the epilogue, run the tree passes and define ``self.epilogue_type``."""
        function = EpilogueAST(self, self.tile_description,
                               self.element_accumulator, self.elements_per_access,
                               self.element_compute, self.element_output)
        #
        tree = function.epilogue_tree
        self.tree = tree
        # self.tree.show() # for debug
        function.pass_binary_2_unary(self.tree, self.tree.root)
        # self.tree.show() # for debug
        function.pass_inject_reduction(self.tree, self.tree.root)
        # self.tree.show() # for debug
        function.pass_inject_epilogue_op(self.tree, self.tree.root)

        visitor = self.tree.get_node(self.tree.root).data.epilogue_node
        self.visitor = visitor

        class _Argument(ctypes.Structure):
            # The structure's only field is the root visitor's argument.
            _fields_ = [
                ("visitor_arg", visitor.argument_type)
            ]

            def __init__(self, **kwargs) -> None:
                # process input args
                _kwargs = {}
                for input_key in function.input_args.keys():
                    if input_key == "accum":
                        continue
                    if function.input_args[input_key][0] == "scalar":
                        # scalars are taken from kwargs directly
                        continue
                    # tensor input: keep the buffer object on the instance and
                    # expose its pointer as "<name>_ptr"
                    buffer_attr = "buffer_tensor_" + input_key
                    ptr_attr = input_key + "_ptr"
                    setattr(self, buffer_attr,
                            NumpyFrontend.argument(kwargs[input_key], False))
                    setattr(self, ptr_attr, int(getattr(self, buffer_attr).ptr))
                    _kwargs[ptr_attr] = getattr(self, ptr_attr)
                # process the return args
                for ret in function.returns:
                    buffer_attr = "buffer_tensor_" + ret
                    ptr_attr = ret + "_ptr"
                    setattr(self, buffer_attr,
                            NumpyFrontend.argument(kwargs[ret], True))
                    setattr(self, ptr_attr, int(getattr(self, buffer_attr).ptr))
                    _kwargs[ptr_attr] = getattr(self, ptr_attr)
                    # the host array is kept for the copy-back in sync()
                    setattr(self, "host_tensor_" + ret, kwargs[ret])

                # caller-supplied keyword arguments override the derived ones
                _kwargs.update(kwargs)
                function.get_arguments(tree, tree.root, _kwargs)
                self.visitor_arg = tree.get_node(tree.root).data.argument

            def sync(self, stream_sync=True):
                """Optionally synchronize the device, then copy every returned tensor to its host array."""
                if stream_sync:
                    err, = cudart.cudaDeviceSynchronize()
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

                for ret in function.returns:
                    host_tensor = getattr(self, "host_tensor_" + ret)
                    device_ptr = cuda.CUdeviceptr(getattr(self, ret + "_ptr"))
                    num_bytes = (getattr(self, "host_tensor_" + ret).size
                                 * getattr(self, "host_tensor_" + ret).itemsize)
                    err, = cuda.cuMemcpyDtoH(host_tensor, device_ptr, num_bytes)
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _Argument

    def emit(self, operation):
        """Return ``KernelTemplate`` filled in for ``operation``."""
        return SubstituteTemplate(self.KernelTemplate, {
            'visitor': self.visitor.emit(operation),
            'operation_name': operation.procedural_name(),
            'visitor_name': self.visitor.instance_name
        })
|
||||
@ -58,6 +58,13 @@ class ReductionArguments:
|
||||
destination: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
|
||||
source: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', **kwargs) -> None:
|
||||
|
||||
# tensor_C can be interpreted as the bias with bias=True in keyword args
|
||||
if "bias" in kwargs.keys():
|
||||
self.bias = kwargs["bias"]
|
||||
else:
|
||||
# by default, tensor_C is not bias
|
||||
self.bias = False
|
||||
|
||||
self.operation = operation
|
||||
#: pointer to the workspace
|
||||
self.ptr_workspace = workspace
|
||||
@ -89,11 +96,9 @@ class ReductionArguments:
|
||||
problem_size[1] * DataTypeSize[operation.C.element] // 8
|
||||
|
||||
if "output_op" in kwargs.keys():
|
||||
self.alpha = kwargs["output_op"].alpha
|
||||
self.beta = kwargs["output_op"].beta
|
||||
self.output_op = kwargs['output_op']
|
||||
else:
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
# get arguments
|
||||
self.get_arguments()
|
||||
@ -109,31 +114,25 @@ class ReductionArguments:
|
||||
ref_workspace = ReductionArguments.get_tensor_ref(
|
||||
extent=[self.problem_size.row, self.problem_size.column],
|
||||
device_ptr=self.ptr_workspace, layout=cutlass.RowMajor)
|
||||
|
||||
ref_source = ReductionArguments.get_tensor_ref(
|
||||
extent=[self.problem_size.row, self.problem_size.column],
|
||||
device_ptr=self.ptr_source, layout=cutlass.RowMajor)
|
||||
if self.bias:
|
||||
ref_source = ReductionArguments.get_tensor_ref(
|
||||
extent=[0, 0],
|
||||
device_ptr=self.ptr_source, layout=cutlass.RowMajor)
|
||||
else:
|
||||
ref_source = ReductionArguments.get_tensor_ref(
|
||||
extent=[self.problem_size.row, self.problem_size.column],
|
||||
device_ptr=self.ptr_source, layout=cutlass.RowMajor)
|
||||
|
||||
ref_destination = ReductionArguments.get_tensor_ref(
|
||||
extent=[self.problem_size.row, self.problem_size.column],
|
||||
device_ptr=self.ptr_destination, layout=cutlass.RowMajor)
|
||||
|
||||
argument_type, epilogue_type = get_reduction_params(
|
||||
self.operation.element_compute)
|
||||
|
||||
if self.operation.element_compute == cutlass.float16:
|
||||
self.alpha = cutlass.float16(self.alpha).storage
|
||||
self.beta = cutlass.float16(self.beta).storage
|
||||
elif self.operation.element_compute == cutlass.int32:
|
||||
self.alpha = int(self.alpha)
|
||||
self.beta = int(self.beta)
|
||||
|
||||
output_op = epilogue_type(self.alpha, self.beta, 0, 0)
|
||||
self.c_arguments = argument_type(
|
||||
self.c_arguments = self.operation.argument_type(
|
||||
self.problem_size, self.partitions,
|
||||
self.partition_stride, ref_workspace,
|
||||
ref_destination, ref_source,
|
||||
output_op
|
||||
self.output_op
|
||||
)
|
||||
|
||||
params_ = self.operation.rt_module.get_args(
|
||||
@ -210,8 +209,8 @@ extern "C" {
|
||||
self.emitter = EmitReductionInstance('_type')
|
||||
|
||||
self.elements_per_access = self.operation.count
|
||||
self.argtype = [ctypes.POINTER(
|
||||
get_reduction_params(operation.element_compute)[0])]
|
||||
self.argument_type, self.epilogue_type = get_reduction_params(operation.epilogue_functor)
|
||||
self.argtype = [ctypes.POINTER(self.argument_type)]
|
||||
|
||||
def emit(self):
|
||||
return self.emitter.emit(self.operation)
|
||||
@ -247,14 +246,14 @@ class ReductionOperation:
|
||||
|
||||
def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription,
|
||||
element_accumulator, element_workspace=None,
|
||||
element_compute=None, epilogue_functor: EpilogueFunctor = EpilogueFunctor.LinearCombination,
|
||||
element_compute=None, epilogue_functor=None,
|
||||
count: int = 1, partitions_per_stage: int = 4) -> None:
|
||||
""" Constructor
|
||||
"""
|
||||
|
||||
self.shape = shape
|
||||
#: epilogue functor (default: LinearCombination)
|
||||
self.epilogue_functor: EpilogueFunctor = epilogue_functor
|
||||
self.epilogue_functor = epilogue_functor
|
||||
#: datatype of accumulator
|
||||
self.element_accumulator = element_accumulator
|
||||
|
||||
@ -285,6 +284,8 @@ class ReductionOperation:
|
||||
self.partitions_per_stage: int = partitions_per_stage
|
||||
|
||||
self.rt_module: ReductionRT = ReductionRT(self)
|
||||
self.argument_type = self.rt_module.argument_type
|
||||
self.epilogue_type = self.rt_module.epilogue_type
|
||||
|
||||
#
|
||||
def extended_name(self):
|
||||
@ -363,12 +364,7 @@ class EmitReductionInstance:
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::reduction::kernel::ReduceSplitK<
|
||||
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
|
||||
${epilogue_functor}<
|
||||
${element_output},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_compute}
|
||||
>,
|
||||
${epilogue_functor},
|
||||
cutlass::reduction::thread::ReduceAdd<
|
||||
${element_accumulator},
|
||||
${element_output},
|
||||
@ -389,7 +385,7 @@ struct ${operation_name}${operation_suffix}:
|
||||
'operation_suffix': self.operation_suffix,
|
||||
'shape_row': str(operation.shape.row()),
|
||||
'shape_column': str(operation.shape.column()),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'epilogue_functor': operation.epilogue_functor.emit(),
|
||||
'element_output': DataTypeTag[operation.element_output],
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_accumulator': DataTypeTag[operation.element_accumulator],
|
||||
|
||||
@ -68,4 +68,3 @@ class TensorRef:
|
||||
# the dtype(0) is used to overload between different data types
|
||||
# with the same layout
|
||||
self.tensor_ref = cutlass.get_tensor_ref(int(ptr), dtype(0), layout)
|
||||
|
||||
|
||||
@ -124,7 +124,7 @@ class Conv2dLauncher:
|
||||
self.reduction_operation = ReductionOperation(
|
||||
shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
|
||||
C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
||||
element_compute=operation.element_epilogue,
|
||||
element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
|
||||
count=operation.C.alignment
|
||||
)
|
||||
|
||||
@ -183,7 +183,7 @@ class Conv2dLauncher:
|
||||
# Get the host reference function
|
||||
#
|
||||
|
||||
self.element_compute = operation.element_epilogue
|
||||
self.element_compute = operation.epilogue_functor.element_epilogue
|
||||
|
||||
self.host_conv2d = cutlass.test.conv.host.conv2d
|
||||
|
||||
@ -441,7 +441,7 @@ class Conv2dLauncher:
|
||||
arguments = Conv2dArguments(
|
||||
operation=self.operation, problem_size=problem_size, A=tensor_A,
|
||||
B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op = LinearCombinationFunctorArguments(alpha, beta),
|
||||
output_op = self.operation.epilogue_type(alpha, beta),
|
||||
split_k_slices=problem_size.split_k_slices,
|
||||
split_k_mode=split_k_mode
|
||||
)
|
||||
@ -454,7 +454,7 @@ class Conv2dLauncher:
|
||||
workspace=arguments.ptr_D,
|
||||
destination=tensor_D,
|
||||
source=tensor_C,
|
||||
output_op = LinearCombinationFunctorArguments(alpha, beta)
|
||||
output_op = self.reduction_operation.epilogue_type(alpha, beta)
|
||||
)
|
||||
|
||||
self.operation.run(arguments)
|
||||
|
||||
@ -68,7 +68,7 @@ class TestbedGrouped:
|
||||
self.scope_min = -8
|
||||
|
||||
#: compute type
|
||||
self.compute_type = operation.element_epilogue
|
||||
self.compute_type = operation.epilogue_functor.element_epilogue
|
||||
|
||||
self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
|
||||
|
||||
@ -176,7 +176,7 @@ class TestbedGrouped:
|
||||
arguments = GemmGroupedArguments(
|
||||
operation=self.operation, problem_sizes=problem_sizes,
|
||||
A=tensor_As, B=tensor_Bs, C=tensor_Cs, D=tensor_Ds,
|
||||
output_op=LinearCombinationFunctorArguments(alpha, beta)
|
||||
output_op=self.operation.epilogue_type(alpha, beta)
|
||||
)
|
||||
|
||||
self.operation.run(arguments)
|
||||
|
||||
@ -143,7 +143,7 @@ class GemmUniversalLauncher:
|
||||
self.reduction_operation: ReductionOperation = ReductionOperation(
|
||||
shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
|
||||
C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
||||
element_compute=operation.element_epilogue,
|
||||
element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
|
||||
count=operation.C.alignment
|
||||
)
|
||||
|
||||
@ -200,7 +200,7 @@ class GemmUniversalLauncher:
|
||||
self.interleaved = interleaved
|
||||
|
||||
#: compute type
|
||||
self.compute_type = operation.element_epilogue
|
||||
self.compute_type = operation.epilogue_functor.element_epilogue
|
||||
self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
|
||||
|
||||
def print_problem_size(self, p, mode, batch_count):
|
||||
@ -391,7 +391,7 @@ class GemmUniversalLauncher:
|
||||
arguments = GemmArguments(
|
||||
operation=self.operation, problem_size=problem_size,
|
||||
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op=LinearCombinationFunctorArguments(alpha, beta),
|
||||
output_op=self.operation.epilogue_type(alpha, beta),
|
||||
gemm_mode=mode, split_k_slices=batch_count
|
||||
)
|
||||
|
||||
@ -403,7 +403,7 @@ class GemmUniversalLauncher:
|
||||
workspace=arguments.ptr_D,
|
||||
destination=tensor_D,
|
||||
source=tensor_C,
|
||||
output_op=LinearCombinationFunctorArguments(alpha, beta)
|
||||
output_op=self.reduction_operation.epilogue_type(alpha, beta)
|
||||
)
|
||||
|
||||
self.operation.run(arguments)
|
||||
|
||||
@ -34,6 +34,7 @@ import numpy as np
|
||||
import cutlass
|
||||
from pycutlass.library import TensorDescription
|
||||
from typing import Union
|
||||
from bfloat16 import bfloat16
|
||||
try:
|
||||
import torch
|
||||
torch_available = True
|
||||
@ -46,7 +47,7 @@ class ReferenceModule:
|
||||
self.layout_B = B.layout
|
||||
self.layout_C = C.layout
|
||||
|
||||
def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0):
|
||||
def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0, bias=False, batch=1):
|
||||
"""
|
||||
Compute the reference result on CPU
|
||||
Args:
|
||||
@ -57,27 +58,38 @@ class ReferenceModule:
|
||||
M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
|
||||
if isinstance(A, np.ndarray):
|
||||
if self.layout_A == cutlass.RowMajor:
|
||||
A_row = np.reshape(A, newshape=(M, K))
|
||||
A_row = np.reshape(A, newshape=(batch, M, K))
|
||||
else:
|
||||
A_col = np.reshape(A, newshape=(K, M))
|
||||
A_row = np.transpose(A_col, axes=(1, 0))
|
||||
A_col = np.reshape(A, newshape=(batch, K, M))
|
||||
A_row = np.transpose(A_col, axes=(0, 2, 1))
|
||||
|
||||
if self.layout_B == cutlass.RowMajor:
|
||||
B_row = np.reshape(B, newshape=(K, N))
|
||||
B_row = np.reshape(B, newshape=(batch, K, N))
|
||||
else:
|
||||
B_col = np.reshape(B, newshape=(N, K))
|
||||
B_row = np.transpose(B_col, axes=(1, 0))
|
||||
B_col = np.reshape(B, newshape=(batch, N, K))
|
||||
B_row = np.transpose(B_col, axes=(0, 2, 1))
|
||||
|
||||
if self.layout_C == cutlass.RowMajor:
|
||||
C_row = np.reshape(C, newshape=(M, N))
|
||||
if bias:
|
||||
C_row = np.reshape(C, newshape=(batch, 1, N))
|
||||
else:
|
||||
C_row = np.reshape(C, newshape=(batch, M, N))
|
||||
else:
|
||||
C_col = np.reshape(C, newshape=(N, M))
|
||||
C_row = np.transpose(C_col, axes=(1, 0))
|
||||
if bias:
|
||||
C_row = np.reshape(C, newshape=(batch, M, 1))
|
||||
else:
|
||||
C_col = np.reshape(C, newshape=(batch, N, M))
|
||||
C_row = np.transpose(C_col, axes=(0, 2, 1))
|
||||
|
||||
out_row = np.matmul(A_row, B_row) * alpha + C_row * beta
|
||||
if A_row.dtype == bfloat16:
|
||||
# numpy's einsum doesn't support bfloat16
|
||||
out_row = np.einsum("bik,bkj->bij", A_row.astype(np.float32), B_row.astype(np.float32)) * alpha + C_row * beta
|
||||
out_row = out_row.astype(C_row.dtype)
|
||||
else:
|
||||
out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta
|
||||
|
||||
if self.layout_C == cutlass.ColumnMajor:
|
||||
out = np.transpose(out_row, axes=(1, 0))
|
||||
out = np.transpose(out_row, axes=(0, 2, 1))
|
||||
else:
|
||||
out = out_row
|
||||
|
||||
@ -128,7 +140,7 @@ if torch_available:
|
||||
def run(self,
|
||||
A: Union[np.ndarray, torch.Tensor],
|
||||
B: Union[np.ndarray, torch.Tensor],
|
||||
C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0) -> np.ndarray:
|
||||
C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0, bias=False) -> np.ndarray:
|
||||
"""
|
||||
Compute the reference result on CPU
|
||||
"""
|
||||
@ -184,7 +196,10 @@ if torch_available:
|
||||
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_C == cutlass.TensorNHWC:
|
||||
C_nhwc = C.view((k, r, s, c))
|
||||
if bias:
|
||||
C_nhwc = C.view((1, 1, 1, c))
|
||||
else:
|
||||
C_nhwc = C.view((k, r, s, c))
|
||||
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
|
||||
elif self.kind == cutlass.conv.Operator.dgrad:
|
||||
if self.layout_A == cutlass.TensorNHWC:
|
||||
@ -196,7 +211,10 @@ if torch_available:
|
||||
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_C == cutlass.TensorNHWC:
|
||||
C_nhwc = C.view((n, h, w, c))
|
||||
if bias:
|
||||
C_nhwc = C.view((1, 1, 1, c))
|
||||
else:
|
||||
C_nhwc = C.view((n, h, w, c))
|
||||
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
|
||||
else:
|
||||
if self.layout_A == cutlass.TensorNHWC:
|
||||
@ -208,7 +226,10 @@ if torch_available:
|
||||
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.layout_C == cutlass.TensorNHWC:
|
||||
C_nhwc = C.view((n, p, q, k))
|
||||
if bias:
|
||||
C_nhwc = C.view((1, 1, 1, k))
|
||||
else:
|
||||
C_nhwc = C.view((n, p, q, k))
|
||||
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
|
||||
|
||||
if self.kind == cutlass.conv.Operator.fprop:
|
||||
|
||||
@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -106,15 +112,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -156,15 +165,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=4,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -143,15 +152,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=4,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Unity,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -97,15 +97,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -135,15 +138,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=2,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -79,15 +79,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -117,15 +120,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -155,15 +161,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -173,15 +182,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -241,15 +253,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -30,15 +30,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle2
|
||||
)
|
||||
|
||||
@ -68,15 +71,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64], stages=3,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -155,15 +164,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -193,15 +205,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment, math_inst.element_accumulator,
|
||||
cutlass.float16
|
||||
)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +71,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment, math_inst.element_accumulator,
|
||||
cutlass.float16
|
||||
)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 256, 32], stages=3,
|
||||
warp_count=[1, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -143,15 +152,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -193,15 +205,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -30,15 +30,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -68,15 +71,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1,80 @@
|
||||
# Runs the PyCUTLASS example scripts (gemm.py, gemm_grouped.py, conv2d.py)
# from $CUTLASS_PATH/examples/40_cutlass_py/ over a fixed list of
# configurations. Every command targets compute capability 80 (-cc 80).
#
# NOTE: pushd/popd are not POSIX sh builtins; run this script with bash.
#
# Abort early with a clear message when CUTLASS_PATH is unset/empty or the
# directory cannot be entered; otherwise every command below would run from
# the wrong working directory. The path is quoted so it may contain spaces.
pushd "${CUTLASS_PATH:?CUTLASS_PATH must point to the CUTLASS source tree}/examples/40_cutlass_py/" || exit 1

# ---- gemm.py: -gm Gemm / GemmSplitKParallel over several data types ----

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3

python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5

python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1

# ---- gemm_grouped.py: problem sizes read from ./grouped_gemm_problem_size.csv ----

python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device

python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host

python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device

python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device

# ---- conv2d.py: fprop / wgrad / dgrad (-co) with several iterator algorithms (-ia) ----

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0

python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

# ---- -bias with an -activ function (relu, leaky_relu, tanh, sigmoid, silu, hardswish, gelu) ----

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh

# The "-bias -activ sigmoid" pair was previously passed twice on this line;
# the redundant repetition has been removed.
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish

python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu

# ---- gemm.py: -gm Batched / -gm Array with -batch ----

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2

# ---- gemm.py: -epv option (RowBroadcast, ColumnBroadcast, RowReduction, ColumnReduction) ----

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3

# Return to the directory the script was started from.
popd
|
||||
@ -49,7 +49,7 @@ class Test_Frontend(unittest.TestCase):
|
||||
|
||||
tile_description = TileDescription(
|
||||
[128, 128, 8], 4, [2, 4, 1],
|
||||
math_inst, 80, 80
|
||||
math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -64,10 +64,14 @@ class Test_Frontend(unittest.TestCase):
|
||||
cutlass.float32, cutlass.RowMajor, 1
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
self.operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=cutlass.float32,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -89,7 +93,7 @@ class Test_Frontend(unittest.TestCase):
|
||||
arguments = GemmArguments(
|
||||
operation=self.operation, problem_size=problem_size,
|
||||
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op=LinearCombinationFunctorArguments(alpha, beta),
|
||||
output_op=self.operation.epilogue_type(alpha, beta),
|
||||
gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
|
||||
)
|
||||
|
||||
@ -119,7 +123,7 @@ class Test_Frontend(unittest.TestCase):
|
||||
arguments = GemmArguments(
|
||||
operation=self.operation, problem_size=problem_size,
|
||||
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op=LinearCombinationFunctorArguments(alpha, beta),
|
||||
output_op=self.operation.epilogue_type(alpha, beta),
|
||||
gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
|
||||
)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 128, 64],
|
||||
stages=4, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -33,15 +33,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
|
||||
alignment=4
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -58,7 +58,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 128, 32],
|
||||
stages=6, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -74,15 +74,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
|
||||
alignment=8
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -36,13 +36,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.BatchedIdentitySwizzle
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
|
||||
direct_store=True
|
||||
)
|
||||
@ -60,7 +62,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -78,13 +80,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -101,7 +105,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64],
|
||||
stages=3, warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -119,13 +123,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -142,7 +148,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[256, 128, 64],
|
||||
stages=3, warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -160,13 +166,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -183,7 +191,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 64, 64],
|
||||
stages=3, warp_count=[2, 1, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -201,13 +209,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float16
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -224,7 +234,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 64, 32],
|
||||
stages=10, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -242,13 +252,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float16
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -265,7 +277,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[256, 128, 64],
|
||||
stages=3, warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -283,13 +295,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -306,7 +320,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 64, 64],
|
||||
stages=3, warp_count=[2, 1, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -324,13 +338,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -347,7 +363,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64],
|
||||
stages=3, warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -365,13 +381,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -388,7 +406,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64],
|
||||
stages=3, warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -406,13 +424,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -37,13 +37,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -61,7 +63,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -79,13 +81,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -102,7 +106,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 64, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -120,13 +124,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ class GemmF64TensorOpSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[32, 32, 16],
|
||||
stages=4, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
# alignment is restricted to 1 for double
|
||||
@ -36,13 +36,15 @@ class GemmF64TensorOpSm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float64
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -59,7 +61,7 @@ class GemmF64TensorOpSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 64, 16],
|
||||
stages=4, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
# alignment 1 restricted for double
|
||||
@ -78,13 +80,15 @@ class GemmF64TensorOpSm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float64
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -37,14 +37,15 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
swizzling_functor = cutlass.BatchedIdentitySwizzle
|
||||
|
||||
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
|
||||
operation = GemmOperationGrouped(
|
||||
tile_description.minimum_compute_capability,
|
||||
80,
|
||||
tile_description, A, B, C,
|
||||
element_epilogue,
|
||||
epilogue_functor, swizzling_functor,
|
||||
precompute_mode=precompute_mode
|
||||
)
|
||||
@ -64,7 +65,7 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 64, 16],
|
||||
stages=4, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -83,14 +84,15 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float64
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
swizzling_functor = cutlass.BatchedIdentitySwizzle
|
||||
|
||||
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
|
||||
operation = GemmOperationGrouped(
|
||||
tile_description.minimum_compute_capability,
|
||||
80,
|
||||
tile_description, A, B, C,
|
||||
element_epilogue,
|
||||
epilogue_functor, swizzling_functor,
|
||||
precompute_mode=precompute_mode
|
||||
)
|
||||
@ -110,7 +112,7 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 64, 8],
|
||||
stages=4, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -129,14 +131,15 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
swizzling_functor = cutlass.BatchedIdentitySwizzle
|
||||
|
||||
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
|
||||
operation = GemmOperationGrouped(
|
||||
tile_description.minimum_compute_capability,
|
||||
80,
|
||||
tile_description, A, B, C,
|
||||
element_epilogue,
|
||||
epilogue_functor, swizzling_functor,
|
||||
precompute_mode=precompute_mode
|
||||
)
|
||||
@ -156,7 +159,7 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -175,14 +178,15 @@ class GemmGroupedSm80(unittest.TestCase):
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
swizzling_functor = cutlass.BatchedIdentitySwizzle
|
||||
|
||||
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
|
||||
operation = GemmOperationGrouped(
|
||||
tile_description.minimum_compute_capability,
|
||||
80,
|
||||
tile_description, A, B, C,
|
||||
element_epilogue,
|
||||
epilogue_functor, swizzling_functor,
|
||||
precompute_mode=precompute_mode
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.epilogue import LinearCombinationClamp
|
||||
from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
@ -17,7 +18,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 64, 64],
|
||||
stages=6, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -33,15 +34,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
alignment=8
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
|
||||
epilogue_functor = FastLinearCombinationClamp(
|
||||
C.element, C.alignment
|
||||
)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -58,7 +59,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 128],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -74,15 +75,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
alignment=16
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
|
||||
epilogue_functor = FastLinearCombinationClamp(
|
||||
C.element, C.alignment
|
||||
)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -99,7 +100,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 128],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -115,15 +116,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
alignment=16
|
||||
)
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
|
||||
epilogue_functor = FastLinearCombinationClamp(
|
||||
C.element, C.alignment
|
||||
)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -140,7 +141,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 128],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -158,13 +159,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.int32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombinationClamp
|
||||
epilogue_functor = LinearCombinationClamp(
|
||||
C.element, C.alignment, math_inst.element_accumulator,
|
||||
element_epilogue
|
||||
)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -181,7 +185,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 128],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -199,13 +203,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.int32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombinationClamp
|
||||
epilogue_functor = LinearCombinationClamp(
|
||||
C.element, C.alignment, math_inst.element_accumulator,
|
||||
element_epilogue
|
||||
)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
|
||||
@ -348,3 +348,16 @@ conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1
|
||||
conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 164942943 4259285988 984016853 888753301
|
||||
conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2823094147 1681845497 4242738907 3244428635
|
||||
conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 s8nhwc_s8nhwc_inhwc_i_i 4060010502 2881035321 3927119619 3311661122
|
||||
conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 1030377090 3211227145
|
||||
conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1479940693 2379046159 2482639965
|
||||
conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 1871463331 2718290800 1797658305
|
||||
conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 664160900 3954982568 985899371
|
||||
conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1924855848 1728786974 3821277575
|
||||
conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 1764715518 3998637379 2782670608
|
||||
conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 666906244 2107859856 831363691
|
||||
conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1575210381 2486552517 3268706408
|
||||
conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 2316839359 1729888024 2308314800
|
||||
conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 544154978
|
||||
conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 3191247524
|
||||
conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 554790212 956712535 1281779197
|
||||
conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3184127693 835105643 4011933753 3207244654
|
||||
|
||||
@ -42,8 +42,7 @@ import unittest
|
||||
#
|
||||
|
||||
def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False,
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
"""
|
||||
Test GEMM Operation based on configuration
|
||||
"""
|
||||
@ -68,7 +67,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
|
||||
|
||||
tile_description = TileDescription(
|
||||
tiling[0], tiling[1], tiling[2],
|
||||
math_inst, arch, arch
|
||||
math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -84,11 +83,15 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
|
||||
)
|
||||
|
||||
element_epilogue = data_type[3]
|
||||
if epilogue_functor is None:
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
if gemm_kind == GemmKind.Universal:
|
||||
operation = GemmOperationUniversal(
|
||||
arch=arch, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
if A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]:
|
||||
@ -99,7 +102,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
|
||||
elif gemm_kind == GemmKind.Grouped:
|
||||
operation = GemmOperationGrouped(
|
||||
arch, tile_description, A, B, C,
|
||||
element_epilogue, epilogue_functor, swizzling_functor,
|
||||
epilogue_functor, swizzling_functor,
|
||||
precompute_mode=kwargs["precompute_mode"]
|
||||
)
|
||||
testbed = TestbedGrouped(operation=operation)
|
||||
@ -110,7 +113,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
|
||||
|
||||
def TestConv2dOperator(math_inst, alignment, tiling, arch,
|
||||
stride_supports=[StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided],
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
epilogue_functor=None,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs):
|
||||
"""
|
||||
Test Conv2d Operation based on configurations
|
||||
@ -167,20 +170,24 @@ def TestConv2dOperator(math_inst, alignment, tiling, arch,
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=tiling[0], stages=tiling[1],
|
||||
warp_count=tiling[2],
|
||||
math_instruction=math_inst,
|
||||
min_compute=arch, max_compute=arch
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided:
|
||||
swizzling_functor = cutlass.StridedDgradIdentitySwizzle1
|
||||
else:
|
||||
swizzling_functor = default_swizzling_functor
|
||||
|
||||
if epilogue_functor is None:
|
||||
epilogue_functor_ = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, data_type[3])
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=arch, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=data_type[3], stride_support=stride_support,
|
||||
epilogue_functor=epilogue_functor,
|
||||
stride_support=stride_support,
|
||||
epilogue_functor=epilogue_functor_,
|
||||
swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -369,7 +376,11 @@ class Test_SM80(unittest.TestCase):
|
||||
tiling = ([256, 64, 64], 4, [4, 1, 1])
|
||||
data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32]
|
||||
|
||||
self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=EpilogueFunctor.FastLinearCombinationClamp))
|
||||
epilogue_functor = FastLinearCombinationClamp(
|
||||
data_type_mixed[2], alignment_mixed[2]
|
||||
)
|
||||
|
||||
self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=epilogue_functor))
|
||||
stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
|
||||
layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32]
|
||||
results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True)
|
||||
@ -378,59 +389,59 @@ class Test_SM80(unittest.TestCase):
|
||||
|
||||
def SM80_SparseTensorOp_16832(self):
|
||||
pass
|
||||
def test_SM80_PlanarComplexTensorOp_16816(self):
|
||||
def SM80_PlanarComplexTensorOp_16816(self):
|
||||
pass
|
||||
def test_SM80_SparseTensorOp_16816_fast_math(self):
|
||||
def SM80_SparseTensorOp_16816_fast_math(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_complex(self):
|
||||
def SM80_TensorOp_1688_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_fast_fp32_math_complex(self):
|
||||
def SM80_TensorOp_1688_fast_fp32_math_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_rank_k(self):
|
||||
def SM80_TensorOp_1688_rank_k(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_rank_k_complex(self):
|
||||
def SM80_TensorOp_1688_rank_k_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_trmm(self):
|
||||
def SM80_TensorOp_1688_trmm(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_trmm_complex(self):
|
||||
def SM80_TensorOp_1688_trmm_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_symm(self):
|
||||
def SM80_TensorOp_1688_symm(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_1688_symm_complex(self):
|
||||
def SM80_TensorOp_1688_symm_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_complex(self):
|
||||
def SM80_TensorOp_884_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_complex_gaussian(self):
|
||||
def SM80_TensorOp_884_complex_gaussian(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_rank_k(self):
|
||||
def SM80_TensorOp_884_rank_k(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_rank_k_complex(self):
|
||||
def SM80_TensorOp_884_rank_k_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_rank_k_complex_gaussian(self):
|
||||
def SM80_TensorOp_884_rank_k_complex_gaussian(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_trmm(self):
|
||||
def SM80_TensorOp_884_trmm(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_trmm_complex(self):
|
||||
def SM80_TensorOp_884_trmm_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_trmm_complex_gaussian(self):
|
||||
def SM80_TensorOp_884_trmm_complex_gaussian(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_symm(self):
|
||||
def SM80_TensorOp_884_symm(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_symm_complex(self):
|
||||
def SM80_TensorOp_884_symm_complex(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_884_symm_complex_gaussian(self):
|
||||
def SM80_TensorOp_884_symm_complex_gaussian(self):
|
||||
pass
|
||||
def test_SM80_SparseTensorOp_16864_TN(self):
|
||||
def SM80_SparseTensorOp_16864_TN(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_16864_TN(self):
|
||||
def SM80_TensorOp_16864_TN(self):
|
||||
pass
|
||||
def test_SM80_SparseTensorOp_168128_TN(self):
|
||||
def SM80_SparseTensorOp_168128_TN(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_16864_Interleaved(self):
|
||||
def SM80_TensorOp_16864_Interleaved(self):
|
||||
pass
|
||||
def test_SM80_TensorOp_168256(self):
|
||||
def SM80_TensorOp_168256(self):
|
||||
pass
|
||||
def test_SM80_Simt_complex(self):
|
||||
def SM80_Simt_complex(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user