CUTLASS 2.10 updates (#622)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
ANIKET SHIVAM
2022-09-12 18:26:30 -07:00
committed by GitHub
parent beae168f90
commit e773429f7e
96 changed files with 8365 additions and 1667 deletions

View File

@ -1,11 +1,19 @@
# NVIDIA CUTLASS Changelog
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention)
* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, and Grouped GEMM for different data types as well as different epilogue flavours.
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling is improved, and some computation can be moved to the host side when applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added as well.
* Optimizations for [GEMM+Softmax](examples/35_gemm_softmax). All of the reduction computation is fused into the preceding GEMM, and additional template arguments are provided to fine-tune performance.
* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention). This grouped-GEMM-based MHA does not require all GEMMs to share the same sequence length, which makes it especially useful for natural language processing.
* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts, which are fused into the preceding and following GEMMs respectively. In addition to using the sum of squares to compute the layernorm variance, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided in case the sum of squares raises numerical issues (a brief sketch appears after the changelog entries below).
* [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can apply a user-provided permutation layout mapping in the GEMM epilogue.
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one. The restrictions are: 1) the input and output channel counts must be multiples of the group count; 2) split-K is not supported. The implementation has two modes (a small illustrative check appears after the changelog entries below):
* kSingleGroup: the output channel count per group is a multiple of the Threadblock tile N.
* kMultipleGroup: the Threadblock tile N is a multiple of the output channel count per group.
* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now. The restrictions are: 1) SIMT only; 2) no split-K; 3) the input channel count equals the output channel count equals the group count.
* Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM's K dimension be a multiple of the Threadblock Tile K dimension.
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
@ -47,7 +55,7 @@
* New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
* [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix, which helps launch the intended number of threadblocks to fully occupy the GPU.
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
* Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads)
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
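The [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) variance mentioned in the 2.10 layer norm item above is the standard shifted-data formula; a minimal numpy sketch of the idea (illustrative only, not CUTLASS's kernel code, and the helper name is made up):

```python
import numpy as np

# Shift-K (shifted-data) variance: subtracting a shift K close to the mean before
# squaring reduces cancellation error compared with the plain E[x^2] - E[x]^2 form.
def shifted_variance(x, K=None):
    x = np.asarray(x, dtype=np.float32)
    K = x.flat[0] if K is None else K          # any value near the mean works
    d = x - K
    n = x.size
    return (np.sum(d * d) - np.sum(d) ** 2 / n) / n
```

Likewise, the two grouped-convolution modes listed above amount to a divisibility check between the Threadblock tile N and the number of output channels per group; the helper below merely restates those restrictions and is not part of the CUTLASS API:

```python
# Hypothetical mode selection restating the grouped-convolution restrictions above.
def grouped_conv2d_mode(output_channels, groups, threadblock_tile_n):
    k_per_group = output_channels // groups      # output channels per group
    if k_per_group % threadblock_tile_n == 0:
        return "kSingleGroup"                    # channels per group is a multiple of tile N
    if threadblock_tile_n % k_per_group == 0:
        return "kMultipleGroup"                  # tile N is a multiple of channels per group
    raise ValueError("configuration not covered by the Analytical grouped kernel")
```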

View File

@ -39,11 +39,16 @@ supported at each level of the execution model hierarchy.
# What's New in CUTLASS 2.10
CUTLASS 2.10 is an update to CUTLASS adding:
- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention)
- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
- [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, and Grouped GEMM for different data types as well as different epilogue flavours.
- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. Some scheduling computation can be moved to the host side when applicable.
- Optimizations for [GEMM+Softmax](examples/35_gemm_softmax).
- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) is a general MHA that does not require equal sequence length in every GEMM.
- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) can fuse the layernorm into GEMMs before and after.
- [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can permute the GEMM output before storing.
- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one.
- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now.
- Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
- [Back-to-back GEMM](examples/13_two_tensor_op_fusion) enhancements.
- Updates and bugfixes from the community (thanks!)
- **Deprecation announcement:** CUTLASS plans to deprecate the following:
- Maxwell and Pascal GPU architectures

View File

@ -1528,6 +1528,9 @@ int main(int argc, char const **args) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
4>::GemmKernel;

View File

@ -187,8 +187,8 @@ struct Options {
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t;; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t;; // <- data type of elements in input matrix B
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices.
@ -252,7 +252,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
SwizzleThreadBlock,
NumStages,
8, /*alignmentA*/
8, /*alignmengB*/
8, /*alignmentB*/
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,

View File

@ -1366,7 +1366,12 @@ int main(int argc, char const **args) {
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, 1,
ElementAccumulator, ElementAccumulator>;
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
const int kStages = 4;
const bool kSplitKSerial = false;
using Operator = cutlass::arch::OpMultiplyAdd;

View File

@ -92,25 +92,35 @@ Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```
### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.
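Conceptually, the verification is a plain numpy reference computation compared against each kernel's output; a minimal sketch of the idea (the scripts themselves use a `ReferenceModule` helper rather than this code):
```python
import numpy as np

# Illustrative reference check in the spirit of the grouped examples' verification.
m, n, k, alpha, beta = 64, 32, 16, 1.0, 0.0
A = np.random.rand(m, k).astype(np.float32)
B = np.random.rand(k, n).astype(np.float32)
C = np.random.rand(m, n).astype(np.float32)
D_ref = alpha * (A @ B) + beta * C   # compared against the kernel's D via np.array_equal / np.allclose
```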
Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
@ -160,3 +170,61 @@ Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nh
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
## Epilogue
### Bias
To replace C with a bias vector (a single row or column, depending on the layout of C), add the `-bias` flag.
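For instance, the following invocation (adapted from the ReLU example below by dropping the activation flags; the parameter values are illustrative, not verified) exercises the bias path on its own:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias
```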
### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```

View File

@ -33,6 +33,7 @@ import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
import torch.nn.functional as F
import argparse
@ -127,6 +128,13 @@ parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stri
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
@ -138,6 +146,8 @@ except:
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
@ -152,7 +162,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
@ -172,7 +182,16 @@ C = TensorDescription(
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
if (args.activation_function == "identity"
or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
@ -181,7 +200,7 @@ conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support,
A=A, B=B, C=C, stride_support=stride_support,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -191,10 +210,18 @@ if args.print_cuda:
operations = [operation,]
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
@ -219,9 +246,18 @@ tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
conv_kind, problem_size
)
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
if args.bias:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
).at(3)
else:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
if args.element_a != "int8":
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
@ -238,12 +274,12 @@ if args.element_c != "int8":
else:
tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)
tensor_D = torch.ones_like(tensor_C)
tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")
arguments = Conv2dArguments(
operation=operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = LinearCombinationFunctorArguments(args.alpha, args.beta),
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
split_k_slices=problem_size.split_k_slices
)
@ -257,7 +293,8 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op = LinearCombinationFunctorArguments(args.alpha, args.beta)
output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
operation.run(arguments)
@ -270,8 +307,12 @@ else:
reference_model = Conv2dReferenceModule(A, B, C, conv_kind)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta)
assert torch.equal(tensor_D, tensor_D_ref)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
if (args.activation_function != "identity"):
tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))
try:
assert torch.equal(tensor_D, tensor_D_ref)
except:
assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
print("Passed.")

View File

@ -99,9 +99,11 @@ parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
parser.add_argument("-epv", "--epilogue_visitor", default=None,
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# Argument
@ -113,17 +115,22 @@ parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
parser.add_argument("-beta", "--beta", default=0.0, type=float,
help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
choices=["Gemm", "GemmSplitKParallel"],
choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
# parser.add_argument('-h', '--help', action="store_true",
# help="print help information")
try:
args = parser.parse_args()
@ -131,6 +138,9 @@ except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
@ -146,7 +156,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
@ -166,13 +176,83 @@ C = TensorDescription(
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
if (args.activation_function == "identity"
or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
visitor = args.epilogue_visitor is not None
if args.epilogue_visitor == "ColumnReduction":
class ColumnReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + beta * c
reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
return D, reduction
epilogue_functor = ColumnReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowBroadcast":
class RowBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'row', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = alpha * T
Z = relu.numpy(scale_T + beta * c)
return Z, T
epilogue_functor = RowBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
class ColumnBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'column', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = leaky_relu.numpy(alpha * T, 0.2)
Z = scale_T + beta * c
return Z, T
epilogue_functor = ColumnBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
else:
epilogue_functor = epilogue_functor
operation = GemmOperationUniversal(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
visitor=visitor
)
if args.print_cuda:
@ -181,10 +261,19 @@ if args.print_cuda:
operations = [operation, ]
if args.gemm_mode == "GemmSplitKParallel":
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
@ -196,47 +285,102 @@ pycutlass.compiler.add_module(operations)
problem_size = cutlass.gemm.GemmCoord(
args.problem_size[0], args.problem_size[1], args.problem_size[2])
tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(bfloat16)
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(bfloat16)
else:
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(getattr(np, args.element_a))
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.k(),)).astype(getattr(np, args.element_a))
tensor_A = np.random.uniform(
low=-2, high=2, size=(tensor_a_size,)
).astype(getattr(np, args.element_a))
tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(bfloat16)
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(bfloat16)
else:
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(getattr(np, args.element_b))
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
* problem_size.n(),)).astype(getattr(np, args.element_b))
tensor_B = np.random.uniform(
low=-2, high=2, size=(tensor_b_size,)
).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.element_c == "bfloat16":
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(bfloat16)
if args.bias:
if args.layout_c == "RowMajor":
tensor_c_size = args.batch * problem_size.n()
elif args.layout_c == "ColumnMajor":
tensor_c_size = args.batch * problem_size.m()
else:
raise ValueError(args.layout_c)
else:
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(getattr(np, args.element_c))
tensor_c_size = args.batch * problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.n(),)).astype(getattr(np, args.element_c))
tensor_C = np.random.uniform(
low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.ones_like(tensor_C)
tensor_D = np.zeros(
shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
if args.epilogue_visitor == "RowReduction":
cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnReduction":
cta_m = args.threadblock_shape[0]
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "RowBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
else:
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta),
output_op=output_op,
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
split_k_slices=args.split_k_slices
split_k_slices=args.split_k_slices, batch=args.batch
)
if args.gemm_mode == "GemmSplitKParallel":
@ -245,7 +389,8 @@ if args.gemm_mode == "GemmSplitKParallel":
problem_size=[problem_size.m(), problem_size.n()],
partitions=args.split_k_slices, workspace=arguments.ptr_D,
destination=tensor_D, source=tensor_C,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
operation.run(arguments)
@ -259,8 +404,42 @@ else:
# run the host reference module
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
assert np.array_equal(tensor_D, tensor_D_ref)
if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, reduction_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
args.alpha, args.beta
)
tensor_D_ref = tensor_D_ref.flatten()
reduction_ref = reduction_ref.flatten()
assert np.allclose(reduction_ref, reduction, atol=1e-2)
elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, tensor_T_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
vector, args.alpha, args.beta)
tensor_D_ref = tensor_D_ref.flatten()
tensor_T_ref = tensor_T_ref.flatten()
assert np.array_equal(tensor_t, tensor_T_ref)
try:
assert np.array_equal(tensor_D, tensor_D_ref)
except:
assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)
print("Passed.")

View File

@ -99,7 +99,10 @@ parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
help="This option describes how thread blocks are scheduled on GPU. \
NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
This parameter is passed in at present to match the APIs of other kernels. The parameter \
is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
default="Device", type=str, choices=["Host", "Device"],
@ -109,7 +112,13 @@ parser.add_argument("-p", "--problem_size_dir", type=str,
help="path to the csv file contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
@ -120,6 +129,8 @@ except:
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
@ -134,7 +145,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
@ -154,13 +165,19 @@ C = TensorDescription(
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
if args.activation_function == "identity":
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
operation = GemmOperationGrouped(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
precompute_mode=precompute_mode
)
@ -214,28 +231,45 @@ for problem_size in problem_sizes:
* problem_size.n(),)).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.element_c == "bfloat16":
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(bfloat16)
if args.bias:
if args.layout_c == "RowMajor":
c_size = problem_size.n()
elif args.layout_c == "ColumnMajor":
c_size = problem_size.m()
else:
raise ValueError(args.layout_c)
else:
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(getattr(np, args.element_c))
c_size = problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.n(),)).astype(getattr(np, args.element_c))
tensor_D = np.zeros_like(tensor_C)
tensor_C = np.random.uniform(
low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.zeros(
shape=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_As.append(tensor_A)
tensor_Bs.append(tensor_B)
tensor_Cs.append(tensor_C)
tensor_Ds.append(tensor_D)
tensor_D_refs.append(reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta))
tensor_D_ref = reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size,
args.alpha, args.beta, args.bias)
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
tensor_D_refs.append(tensor_D_ref)
problem_sizes_coord.append(problem_size)
arguments = GemmGroupedArguments(
operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
)
operation.run(arguments)
@ -243,6 +277,9 @@ operation.run(arguments)
arguments.sync()
for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
assert np.array_equal(tensor_d, tensor_d_ref)
try:
assert np.array_equal(tensor_d, tensor_d_ref)
except:
assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5)
print("Passed.")

View File

@ -69,7 +69,7 @@ struct global_load;
/////////////////////////////////////////////////////////////////////////////////////////////////
// The redundant mov PTX instruction is used to enforce the compiler to
// initialize data to zero before ld.global
// keep the initializing code before ld.global
template <typename AccessType>
struct global_load<AccessType,
32

View File

@ -59,6 +59,38 @@ struct Identity {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct LinearCombinationGenericParams {
T alpha; ///< scales accumulators
T beta; ///< scales source tensor
T const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
T const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
LinearCombinationGenericParams():
alpha(T(1)),
beta(T(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
LinearCombinationGenericParams(
T alpha,
T beta = T(0)
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
LinearCombinationGenericParams(
T const *alpha_ptr,
T const *beta_ptr = nullptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// ReLu operator - propagates NaNs
@ -79,6 +111,14 @@ struct ReLu {
return mx(value, T(0));
}
/// Host-constructable parameters structure
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T value, Params const &params_) const {
return this->operator()(value);
}
};
template <typename T, int N>
@ -96,20 +136,74 @@ struct ReLu<Array<T, N>> {
maximum<Array<T, N> > mx;
return mx(frag, T(0));
}
/// Host-constructable parameters structure
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag, Params const &params_) const {
return this->operator()(frag);
}
};
// Leaky Relu operator
template <typename T>
struct LeakyReLU {
struct Params: LinearCombinationGenericParams<T> {
T leaky_alpha; ///< leaky_alpha
// Methods
using LinearCombinationGenericParams<T>::LinearCombinationGenericParams;
CUTLASS_HOST_DEVICE
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
T beta,
T leaky_alpha = T(1)
): LinearCombinationGenericParams<T>(alpha, beta), leaky_alpha(leaky_alpha) {}
};
CUTLASS_HOST_DEVICE
T operator()(T const &value, T const & alpha_recip) const {
T res = value > T(0) ? value : value * alpha_recip;
return res;
}
CUTLASS_HOST_DEVICE
T operator()(T const &value, Params const &params_) const {
return this->operator()(value, params_.leaky_alpha);
}
};
template <typename T, int N>
struct LeakyReLU<Array<T, N> > {
struct Params: LinearCombinationGenericParams<T> {
T leaky_alpha; ///< leaky_alpha
using LinearCombinationGenericParams<T>::LinearCombinationGenericParams;
// Methods
CUTLASS_HOST_DEVICE
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
T beta,
T leaky_alpha = T(1)
): LinearCombinationGenericParams<T>(alpha, beta), leaky_alpha(leaky_alpha) {}
};
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, T const & alpha_recip) const {
Array<T, N> y;
@ -122,6 +216,11 @@ struct LeakyReLU<Array<T, N> > {
return y;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs, params_.leaky_alpha);
}
};
// Tanh operator
@ -131,6 +230,13 @@ struct Tanh {
T operator()(T const &scalar) const {
return fast_tanh(scalar);
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <typename T, int N>
@ -147,6 +253,13 @@ struct Tanh<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
template <int N>
@ -159,6 +272,13 @@ struct Tanh<Array<half_t, N>> {
return tanh(z);
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
// Sigmoid operator
@ -168,6 +288,13 @@ struct Sigmoid {
T operator()(T const &scalar) const {
return T(1) / (T(1) + fast_exp(-scalar));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <typename T, int N>
@ -184,6 +311,13 @@ struct Sigmoid<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
template <int N>
@ -208,6 +342,12 @@ struct Sigmoid<Array<half_t, N>> {
fast_exp(neg(z))));
#endif
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &z, Params const &params_) const {
return this->operator()(z);
}
};
// SiLu (swish) operator introduced by Elfwing et al. in the following paper
@ -222,6 +362,13 @@ struct SiLu {
Sigmoid<T> sigmoid;
return scalar * sigmoid(scalar);
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <typename T, int N>
@ -232,6 +379,13 @@ struct SiLu<Array<T, N>> {
multiplies<Array<T, N>> mul;
return mul(rhs, sigmoid_op(rhs));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
// Hardswish operator introduced by Howard et al. in the following paper
@ -248,6 +402,13 @@ struct HardSwish {
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * relu6 / T(6);
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &x, Params const &params_) const {
return this->operator()(x);
}
};
template <>
@ -261,6 +422,13 @@ struct HardSwish<float> {
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * relu6 * 0.16666667f;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &x, Params const &params_) const {
return this->operator()(x);
}
};
template <typename T, int N>
@ -277,6 +445,13 @@ struct HardSwish<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &x, Params const &params_) const {
return this->operator()(x);
}
};
template <int N>
@ -292,6 +467,13 @@ struct HardSwish<Array<half_t, N> > {
return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &x, Params const &params_) const {
return this->operator()(x);
}
};
//
@ -311,6 +493,13 @@ struct GELU {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <>
@ -320,6 +509,13 @@ struct GELU<float> {
return cutlass::constants::half<float>() * scalar *
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_HOST_DEVICE
float operator()(float const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <>
@ -329,6 +525,13 @@ struct GELU<double> {
return cutlass::constants::half<double>() * scalar *
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
}
using Params = LinearCombinationGenericParams<double>;
CUTLASS_HOST_DEVICE
double operator()(double const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <typename T, int N>
@ -345,6 +548,13 @@ struct GELU<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
// GELU operator implemented using the Taylor series approximation
@ -360,6 +570,9 @@ struct GELU_taylor {
return T(cutlass::constants::half<T>() * z *
(cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
}
using Params = LinearCombinationGenericParams<T>;
};
template <int N>
@ -386,6 +599,8 @@ struct GELU_taylor<Array<half_t, N> > {
return y;
}
using Params = LinearCombinationGenericParams<half_t>;
};
template <typename T, int N>
@ -403,6 +618,8 @@ struct GELU_taylor<Array<T, N> > {
return y;
}
using Params = LinearCombinationGenericParams<T>;
};
/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and

View File

@ -40,6 +40,7 @@
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/epilogue/thread/linear_combination_params.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -76,22 +77,23 @@ public:
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using ParamsBase = LinearCombinationParams;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
struct Params : ParamsBase{
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ParamsBase(
ElementCompute(1),
ElementCompute(0)
),
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
@ -101,30 +103,43 @@ public:
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
):
ParamsBase(alpha, beta),
alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha
): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
):
ParamsBase(alpha, ElementCompute(0)),
alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
):
ParamsBase(*alpha_ptr, *beta_ptr),
alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
):
ParamsBase(*alpha_ptr, ElementCompute(0)),
alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ParamsBase const& base
): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) {
#if defined(__CUDA_ARCH__)
alpha = reinterpret_cast<ElementCompute const&>(base.alpha_data);
beta = reinterpret_cast<ElementCompute const&>(base.beta_data);
#else
memcpy( &alpha, base.alpha_data, sizeof(ElementCompute) );
memcpy( &beta, base.beta_data, sizeof(ElementCompute) );
#endif
}
};
@ -142,7 +157,6 @@ public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombination(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}

View File

@ -83,40 +83,7 @@ public:
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta = ElementCompute(0)
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr = nullptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
using Params = typename ActivationFunctor<FragmentCompute>::Params;
private:
@ -124,8 +91,7 @@ private:
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
Params params_;
bool skip_elementwise_;
public:
@ -133,9 +99,9 @@ public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationGeneric(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
params_ = params;
params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
params_.beta = (params.beta_ptr ? *params.beta_ptr : params.beta);
skip_elementwise_ = false;
}
@ -148,14 +114,14 @@ public:
if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0);
return params_.beta != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
params_.beta = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
@ -186,15 +152,15 @@ public:
if (Scale == ScaleType::NoBetaScaling) {
intermediate = converted_source;
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
} else if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = mul_add_source(params_.beta, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
}
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
@ -222,10 +188,10 @@ public:
if (Scale == ScaleType::Nothing) {
intermediate = converted_accumulator;
} else {
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
intermediate = mul_add_accumulator(params_.alpha, converted_accumulator); // D = alpha * Accum
}
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

View File

@ -0,0 +1,75 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
struct LinearCombinationParams {
uint64_t alpha_data[2];
uint64_t beta_data[2];
CUTLASS_HOST_DEVICE
LinearCombinationParams()
: alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
{ }
template <typename ElementCompute>
CUTLASS_HOST_DEVICE
LinearCombinationParams(ElementCompute alpha, ElementCompute beta)
: alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
{
#if defined(__CUDA_ARCH__)
reinterpret_cast<ElementCompute&>(alpha_data) = alpha;
reinterpret_cast<ElementCompute&>(beta_data) = beta;
#else
memcpy( alpha_data, &alpha, sizeof(ElementCompute) );
memcpy( beta_data, &beta, sizeof(ElementCompute) );
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,156 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/fast_math.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
int Rank
>
struct PredicatedTileIteratorAffineLayoutRankNParams {
using Layout = layout::AffineRankN<Rank>;
using TensorCoord = typename Layout::TensorCoord;
static bool const kBigEndian = false;
//
// Data members
//
Layout layout;
/// Stride in units of bytes along M modes
Coord<Layout::kRank/2, typename Layout::LongIndex> stride_m;
/// Stride in units of bytes along N modes
Coord<Layout::kRank/2, typename Layout::LongIndex> stride_n;
/// Fast divmod objects divided by tensor extents
FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)];
/// Fast divmod objects divided by tensor extents
FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)];
int64_t rank2_inc_col;
int64_t rank2_inc_row;
//
// Methods
//
CUTLASS_HOST_DEVICE
PredicatedTileIteratorAffineLayoutRankNParams() { }
CUTLASS_HOST_DEVICE
PredicatedTileIteratorAffineLayoutRankNParams(TensorCoord const &extent,
Layout const &layout_,
int64_t element_sizeof_bits)
: layout(layout_)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2; ++i) {
stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits);
stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits);
}
if (kBigEndian) {
// "Big Endian" scheme
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2 - 1; ++i) {
divmod_m[i] = FastDivmod(extent[i + 1]);
divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]);
}
}
else {
// "Little Endian" scheme
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2 - 1; ++i) {
divmod_m[i] = FastDivmod(extent[i]);
divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]);
}
}
#if 0
//
// Debug print statements to verify extents and strides are passed correctly.
//
printf("PredicatedTileIteratorAffine::Params() entered\n");
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank; ++i) {
printf(" extent[%d]: %d\n", i, extent[i]);
}
for (int i = 0; i < Layout::kRank; ++i) {
printf(" stride[%d]: %ld\n", i, layout_.stride()[i]);
}
printf("PredicatedTileIteratorAffine::Params() returning\n");
#endif
}
CUTLASS_HOST_DEVICE
PredicatedTileIteratorAffineLayoutRankNParams(Layout const &layout_,
int32_t threadmap_delta_kColumn,
int32_t threadmap_delta_kRow,
int64_t element_sizeof_bits)
: layout(layout_)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2; ++i) {
stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits);
stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits);
}
rank2_inc_col = threadmap_delta_kColumn * stride_n[0];
rank2_inc_row = threadmap_delta_kRow * stride_m[0];
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -561,6 +561,17 @@ CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) {
}
}
CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index, int64_t element_sizeof_bits) {
if (element_sizeof_bits >= 8) {
return index * (element_sizeof_bits / 8);
}
else {
int64_t const kElementsPerByte = ((8 / element_sizeof_bits) + ((element_sizeof_bits >= 8) ? 1 : 0));
return index / kElementsPerByte;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Min/Max
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -88,12 +88,12 @@ public:
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static int const kAlignmentA = GemmKernel::kAlignmentA;
static int const kAlignmentA = MapArguments::kAlignmentA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
static int const kAlignmentB = GemmKernel::kAlignmentB;
static int const kAlignmentB = MapArguments::kAlignmentB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename MapArguments::LayoutC;

View File

@ -40,7 +40,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/thread/unaryOp.h"
#include "cutlass/transform/thread/unary_op.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
@ -1273,6 +1273,40 @@ struct NumericArrayConverter<uint8_t, int, N, Round> {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Array<int8_t> <= Array<float>
/// Conversion is performed with saturation regardless of setting of
/// the `Round` template parameter.
template <
int N,
FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, float, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<float, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
// Convert float to int
Array<int32_t, N> temporary;
NumericArrayConverter<int, float, N, Round> compute_converter;
temporary = compute_converter(source);
// Convert int to int8_t
NumericArrayConverter<int8_t, int32_t, N, Round> destination_converter;
return destination_converter(temporary);
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))

View File

@ -189,3 +189,35 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
}
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(NumericConversion, f32x8_to_s8x8_rn) {
int const kN = 8;
using Source = float;
using Destination = int8_t;
dim3 grid(1, 1);
dim3 block(1, 1);
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = float(i);
}
source.sync_device();
test::core::kernel::convert<Destination, Source, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
);
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -16,7 +16,7 @@ math_inst = MathInstruction(
tile_description = TileDescription(
[128, 128, 8], 4, [2, 4, 1],
math_inst, 80, 80
math_inst
)
A = TensorDescription(
@ -31,10 +31,12 @@ C = TensorDescription(
cutlass.float32, cutlass.RowMajor, 1
)
epilogue_functor = LinearCombination(cutlass.float32, 1, cutlass.float32, cutlass.float32)
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=cutlass.float32,
epilogue_functor=EpilogueFunctor.LinearCombination,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -54,7 +56,7 @@ beta = 0.0
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(alpha, beta),
output_op=operation.epilogue_type(alpha, beta),
gemm_mode=cutlass.gemm.Mode.Gemm, split_k_slices=1
)
@ -68,6 +70,14 @@ assert torch.equal(tensor_D, tensor_D_ref)
```
PyCUTLASS also provides infrastructure for profiling, compiled artifact management, and a memory pool manager
## Supported Features
PyCUTLASS currently supports the following operations:
* GEMM with mode {Serial, Parallel Split K, Batched GEMM, Array GEMM}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor, Row/ColumnMajorInterleaved<32> for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, swizzling functions {IdentitySwizzle<1,2,4,8>, HorizontalSwizzle, BatchedIdentitySwizzle}, and epilogue {LinearCombination, LinearCombinationClamp}
* GEMM grouped with op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, scheduling mode {Host, Device}, and epilogue {LinearCombination, LinearCombinationClamp}.
* Conv2d with {Fprop, Dgrad, Wgrad}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {Tensor NHWC, TensorNC32HW32 and TensorC32RSK for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, split-k mode {Parallel, Serial}, and epilogue {LinearCombination, LinearCombinationClamp}
The tiling sizes of the above operations can also be customized.
## Installation
### Using Docker
@ -94,12 +104,19 @@ cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```
## Examples
Examples can be found in `$CUTLASS_PATH/examples/40_cutlass_py`
Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutlass_py)
## Test
The test cases are listed in `$CUTLASS_PATH/tools/library/scripts/pycutlass/test`. The unit tests and examples can be run with
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && run_all_example.sh
```
## Build Documentation
Run
```shell
bash build_doc.sh
```

View File

@ -1,2 +1,4 @@
python setup.py develop
pip install enum-tools
pip install sphinx-toolbox
pip install m2r2
sphinx-build -b html docs/source/ docs/build/html

View File

@ -50,7 +50,7 @@
# -- Project information -----------------------------------------------------
project = 'PyCutlass'
copyright = '2022, Andrew Kerr; Zhaodong Chen; Haicheng Wu; Szymon Migacz; Graham Markall'
copyright = '2022, Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall'
author = 'Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall'
@ -65,9 +65,12 @@ extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.intersphinx',
'enum_tools.autoenum',
'sphinx.ext.autosummary'
'sphinx.ext.autosummary',
'm2r2'
]
source_suffix = [".rst", ".md"]
autosummary_generate = True
autosummary_imported_members = True
@ -85,7 +88,7 @@ exclude_patterns = []
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'classic'
html_theme = 'bizstyle'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,

View File

@ -1,2 +1,100 @@
cutlass
=======
.. rubric:: Operator Classification
.. autoclass:: cutlass.OpClass
:members:
.. rubric:: GEMM Layout
.. autoclass:: cutlass.RowMajor
:members:
.. autoclass:: cutlass.ColumnMajor
:members:
.. autoclass:: cutlass.RowMajorInterleaved32
:members:
.. autoclass:: cutlass.ColumnMajorInterleaved32
:members:
.. rubric:: Conv Layout
.. autoclass:: cutlass.TensorNHWC
:members:
.. autoclass:: cutlass.TensorNC32HW32
:members:
.. autoclass:: cutlass.TensorC32RSK32
:members:
.. rubric:: Threadblock Swizzle
.. autoclass:: cutlass.dim3
:special-members:
:members:
.. autoclass:: cutlass.IdentitySwizzle1
:special-members:
:members:
.. autoclass:: cutlass.IdentitySwizzle2
:special-members:
:members:
.. autoclass:: cutlass.IdentitySwizzle4
:special-members:
:members:
.. autoclass:: cutlass.IdentitySwizzle8
:special-members:
:members:
.. autoclass:: cutlass.HorizontalSwizzle
:special-members:
:members:
.. autoclass:: cutlass.BatchedIdentitySwizzle
:special-members:
:members:
.. autoclass:: cutlass.StridedDgradIdentitySwizzle1
:special-members:
:members:
.. autoclass:: cutlass.StridedDgradIdentitySwizzle4
:special-members:
:members:
.. autoclass:: cutlass.StridedDgradHorizontalSwizzle
:special-members:
:members:
.. rubric:: Coordinates
.. autoclass:: cutlass.Tensor4DCoord
:special-members:
:members:
.. autoclass:: cutlass.Tensor3DCoord
:special-members:
:members:
.. autoclass:: cutlass.MatrixCoord
:special-members:
:members:
.. rubric:: Convolution
.. autoclass:: cutlass.conv.Operator
:members:
.. autoclass:: cutlass.conv.IteratorAlgorithm
:members:
.. autoclass:: cutlass.conv.StrideSupport
:members:

View File

@ -1,6 +0,0 @@
Descriptions
==============
.. autoclass:: pycutlass.TileDescription
:special-members:
:members:

View File

@ -1,5 +0,0 @@
Frontend
==============
.. autoclass:: pycutlass.NumpyFrontend
:members:

View File

@ -3,27 +3,29 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to PyCutlass's documentation!
CUTLASS Python Project Documentation
=====================================
.. mdinclude:: ../../README.md
.. toctree::
:maxdepth: 2
:caption: Contents:
Indices and tables
.. Indices and tables
.. ==================
.. * :ref:`genindex`
.. * :ref:`modindex`
.. * :ref:`search`
Indices
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
.. toctree::
types
cutlass
descriptor
frontend
user_guide
visitor_tree
gemm_op
conv2d_op
cutlass

View File

@ -0,0 +1,225 @@
# Epilogue Visitor Tree
The Epilogue Visitor Tree is an experimental feature that directly generates epilogues from user-provided Python functions.
## Usage
The Epilogue Visitor Tree supports many different operations.
### Unary functions
The Epilogue Visitor Tree supports unary functions such as activation functions. For example,
```python
class UnaryEpilogue_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
T = leaky_relu.numpy(accum, 0.2)
Z = alpha * T + beta * c
return Z
epilogue_functor = UnaryEpilogue_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
```
### Broadcast Operation
The Epilogue Visitor Tree supports broadcasting row and column vectors across the whole output matrix. To use broadcast, you only need to specify whether the source vector is a `row` vector or a `column` vector. Here is an example.
```python
class ColumnBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'column', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = leaky_relu.numpy(alpha * T, 0.2)
Z = scale_T + beta * c
return Z, T
epilogue_functor = ColumnBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
```
### Reduction Operation
The Epilogue Visitor Tree also supports row- and column-wise reductions within each threadblock tile. The syntax for a reduction is
```python
{reduction_output} = reduction_op({input_tensor}, {row|column}, {Add}, {threadblock_shape.n|threadblock_shape.m})
```
The `{row|column}` argument indicates whether `row` vectors or `column` vectors are reduced. The `{Add}` specifies the reduction operation. The `{threadblock_shape.n|threadblock_shape.m}` argument gives the reduction length.
**Constraints**
* The `{input_tensor}` can only be the name of a source tensor or an intermediate result. `reduction_op(A + B, ...)` will not work; use `C = A + B` followed by `reduction_op(C, ...)` instead.
* The `{reduction_output}` cannot be used elsewhere in the epilogue. It is written directly to global memory after the reduction is done.
```python
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
```
## Get output_op
As shown in the user guide, an `output_op` is required by the argument wrapper. We take `RowReduction_` as an example to show how to obtain the `output_op`.
```python
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, element_c))
# get output op
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
```
Like other epilogue functors such as `LinearCombination`, the output op for EpilogueVisitorTree is also created with `operation.epilogue_type(*)`. However, there are two differences:
* The arguments need to be passed as keyword arguments. The keywords are the argument names in `def __call__`.
* An additional `problem_size=[problem_size.m(), problem_size.n()]` is required.
## Add a new Unary Operation (e.g. an Activation Function)
To add an additional unary operation to the epilogue visitor tree, a new unary op
should be created for `VisitorOpUnary`. We take `tanh` as an example.
### Step 1: define TanhVisitor
The visitor defines the parameters and computation required by the unary operation.
The unary operations are registered in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h), but you can define yours in any header file and include that header in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h).
* Two template arguments are required:
* `T`: data type used to compute the unary operation
* `N`: compute fragment length
* We also need to provide the `Arguments` and `Params` structures. The `Arguments` will be assembled by [ctypes](https://docs.python.org/3/library/ctypes.html), and the `Params` will be generated from `Arguments` automatically. If the unary function takes no arguments, an integer like `int tmp` can be provided to ensure the correctness of ctypes.
* The constructor can only take the `params` as the single argument.
* The operation is defined in `Array<T, N> operator()(Array<T, N> const &frag) const`. One common way to do this is to first define a scalar computation and then apply it to the fragment with an unrolled for-loop.
* A guard function is required. If it returns `false`, all the child nodes of the unary node are disabled and zeros are returned to the parent node. This is very helpful for multiplication by a scalar when the scalar is `0`. For general cases, you can simply return `true`.
```c++
// T: data type used to compute the unary operation
// N: compute fragment length
template <typename T, int N>
struct TanhVisitor {
/// Argument
struct Arguments {
// a placeholder argument to ensure correctness of ctypes
int tmp;
CUTLASS_HOST_DEVICE
Arguments(): tmp(0) { };
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { };
};
/// Param
struct Params {
CUTLASS_HOST_DEVICE
Params(){ };
Params(Arguments const &args) { }
};
/// Constructor
CUTLASS_HOST_DEVICE
TanhVisitor(Params const &params) { }
// scalar operator
CUTLASS_HOST_DEVICE
T tanh_op(T const &scalar) const {
return fast_tanh(scalar);
}
/// vector operator
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
Array<T, N> y;
CUTLASS_PRAGMA_UNROLL
for (int i=0; i < N; ++i) {
y[i] = tanh_op(frag[i]);
}
return y;
}
// Guard
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
```
### Step 2: register Tanh function
After defining the function in C++, we need to register it in Python. The class below gives an example.
* The init function takes the data type `element_compute`, which will be the `T` in the C++ template.
In the init function, we also generate the `_Arguments` class as a `ctypes.Structure`. It includes all the data members of `TanhVisitor::Arguments`.
* The `_Arguments` class needs to be registered as `self.argument_type` of the `tanh` class.
* An `emit` function is required to emit the namespace and typename of `TanhVisitor`.
* A static method serving as the NumPy reference is required so that the parser can interpret the Python code.
The built-in functions are defined in [pycutlass/src/pycutlass/epilogue.py](tools/library/scripts/pycutlass/src/pycutlass/epilogue.py). You can define yours in any file as long as it can be found by [/pycutlass/src/pycutlass/parser.py](tools/library/scripts/pycutlass/src/pycutlass/parser.py).
```python
class tanh(ActivationFunctor):
def __init__(self, element_compute) -> None:
super().__init__()
class _Arguments(ctypes.Structure):
_fields_ = [
("tmp", ctypes.c_int)
]
def __init__(self, *args) -> None:
self.tmp = 0
self.argument_type = _Arguments
def emit(self):
return "cutlass::TanhVisitor"
@staticmethod
def numpy(x: np.ndarray):
return np.tanh(x)
```
### Step 3: Run the function
Now the new unary op is ready to use. An epilogue visitor tree can be built with
```python
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: NDArray['tensor', 'float32'], c: NDArray['tensor', 'float32'],
alpha: 'float32', beta: 'float32'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
```
## Limitations and Future work
Although the Epilogue Visitor Tree brings great flexibility to epilogue construction, formulating the epilogue as a single tree imposes several limitations.
* [Future Work] Serial and Parallel Split-K GEMM are not supported yet.
* To support serial split-K, an additional tree transformation pass is required to inject a `binaryOpNode(Add)` + `TensorInputNode` before each `TensorOutputNode` to fetch the partial sums back. The `semaphore` also needs to be passed into the epilogue.
* To support parallel split-K, a reduction-with-visitor kernel is required.
* [Future Work] Convolution and GEMM Grouped are not supported yet.
* To support Conv2d and GEMM Grouped, corresponding *_with_visitor kernels are required.
* [Limitation] If the same node is used by two operations (unless one of them is a reduction), the node and all of its descendants will be executed twice.
* [Limitation] The result of reduction can only be used as the return value.

View File

@ -0,0 +1,283 @@
# Basics of PyCUTLASS
PyCUTLASS handles the following tasks when launching CUTLASS kernels:
* Memory management
* Operation Description
* Code emission and compilation
* Arguments preprocessing
* Kernel launching
* Result Synchronization
## Memory management
PyCUTLASS uses [RMM](https://github.com/rapidsai/rmm) to manage device memory. At the beginning of the program, call
```python
pycutlass.get_memory_pool({init_pool_size_in_bytes}, {max_pool_size_in_bytes})
```
We also provide functions to query the allocated size.
```python
bytes = get_allocated_size()
```
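For instance, a minimal setup might look like the sketch below. The pool sizes (256 MB initial, 4 GB maximum) are illustrative assumptions, and the import location of `get_allocated_size` follows the snippets above.
```python
import pycutlass
from pycutlass import get_allocated_size

# Illustrative pool sizes (assumed values): 256 MB initial, 4 GB maximum
pycutlass.get_memory_pool(2**28, 2**32)

# Query how many bytes the pool has handed out so far
print(get_allocated_size())
```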
## Operation Description
PyCUTLASS provides operation descriptions for GEMM, GEMM Grouped, and Conv2d operations. These operation descriptions are assembled from four fundamental concepts:
* Math Instruction: math instruction executed in GPU cores
* Tile Description: tiling sizes and pipeline stages
* Operand Description: data type, layout, memory alignment
* Epilogue Functor: epilogue function
### Math Instruction
The math instruction is defined as follows:
```python
math_inst = MathInstruction(
{instruction_shape}, {element_a}, {element_b},
{element_acc}, {opclass}, {math_operation}
)
```
The `{instruction_shape}` and `{opclass}` define the instruction size and type. The table below lists valid combinations. `{element_a}` and `{element_b}` define the source operand data types for each instruction, and `{element_acc}` defines the accumulator type. The `{math_operation}` defines the math operation that is applied.
|Opclass | element_a/element_b | element_acc | instruction_shape | math_operation |
| -- | -- | -- | -- | -- |
| cutlass.OpClass.TensorOp | cutlass.float64 | cutlass.float64 | [8, 8, 4] | MathOperation.multiply_add|
| | cutlass.float32, cutlass.tfloat32, cutlass.float16, cutlass.bfloat16 | cutlass.float32 | [16, 8, 8] | MathOperation.multiply_add, MathOperation.multiply_add_fast_f32, MathOperation.multiply_add_fast_f16, MathOperation.multiply_add_fast_bf16 |
| | cutlass.float16 | cutlass.float16/cutlass.float32|[16, 8, 16]| MathOperation.multiply_add |
| | cutlass.bfloat16 | cutlass.float32 | [16, 8, 16] | MathOperation.multiply_add |
| | cutlass.int8 | cutlass.int32 | [16, 8, 32] | MathOperation.multiply_add_saturate|
|cutlass.OpClass.Simt| cutlass.float64 | cutlass.float64 | [1, 1, 1] | MathOperation.multiply_add |
| | cutlass.float32 | cutlass.float32 | [1, 1, 1] | MathOperation.multiply_add |
The `cutlass.OpClass.TensorOp` indicates that Tensor Cores are used, while `cutlass.OpClass.Simt` uses the SIMT cores.
The `multiply_add_fast_f32` option emulates a fast, accurate SGEMM kernel accelerated by Ampere Tensor Cores. More details can be found in [examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm](examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm).
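As a concrete illustration, the sketch below builds a Tensor Core math instruction for FP16 inputs with FP32 accumulation, matching the [16, 8, 16] row of the table above; it assumes the usual `from pycutlass import *` setup.
```python
# FP16 inputs, FP32 accumulator, Tensor Core m16n8k16 MMA, plain multiply-add
math_inst = MathInstruction(
    [16, 8, 16],                       # instruction_shape
    cutlass.float16, cutlass.float16,  # element_a, element_b
    cutlass.float32,                   # element_acc
    cutlass.OpClass.TensorOp,          # opclass
    MathOperation.multiply_add         # math_operation
)
```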
### Tile Description
The tile description describes the threadblock and warp tiling sizes, as well as the pipeline stages.
```python
tile_description = TileDescription(
{threadblock_shape}, {stages}, {warp_count},
math_inst
)
```
The `{threadblock_shape}` is a list of 3 integers `[Tile_M, Tile_N, Tile_K]` that defines the threadblock tiling size. `{stages}` defines the number of software pipeline stages ([detail](https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/)). `{warp_count}` defines the number of warps along `M`, `N`, and `K` dimension. I.e., with `{threadblock_shape}=[Tile_M, Tile_N, Tile_K]` and `{warp_count}=[W_M, W_N, W_K]`, the warp tile size would be `[Tile_M / W_M, Tile_N / W_N, Tile_K / W_K]`.
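For example, with the `math_inst` defined above, a 128x128x64 threadblock tile with 4 stages and a 2x2x1 warp layout yields a 64x64x64 warp tile; this is a sketch, not a tuned configuration.
```python
# Threadblock tile 128x128x64, 4 pipeline stages, 2x2x1 warps -> 64x64x64 warp tile
tile_description = TileDescription(
    [128, 128, 64],  # threadblock_shape [Tile_M, Tile_N, Tile_K]
    4,               # stages
    [2, 2, 1],       # warp_count [W_M, W_N, W_K]
    math_inst
)
```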
### Operand Description
The Operand Description defines the data type, layout, and memory alignment of the input tensors A, B, and C. The output tensor D shares the same attributes as C. The description is as follows:
```python
A = TensorDescription(
{element_a}, {layout_a}, {alignment_a}
)
B = TensorDescription(
{element_b}, {layout_b}, {alignment_b}
)
C = TensorDescription(
{element_c}, {layout_c}, {alignment_c}
)
```
The table below lists the supported layouts and data types for each operation:
| Operation | data type | layout |
| -- | -- | -- |
| GEMM, GEMM Grouped | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.RowMajor, cutlass.ColumnMajor |
| | cutlass.int8 | cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32|
| Conv2d Fprop, Dgrad, Wgrad | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.TensorNHWC |
| Conv2d Fprop | cutlass.int8 | cutlass.TensorNHWC, cutlass.TensorNC32HW32, cutlass.TensorC32RSK32|
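Continuing the FP16 Tensor Core GEMM example, one possible operand description is sketched below; the alignments (8 FP16 elements, 4 FP32 elements) correspond to 128-bit accesses and are assumptions, not the only valid choice.
```python
# Row-major A, column-major B, row-major C/D; alignments chosen for 128-bit accesses
A = TensorDescription(cutlass.float16, cutlass.RowMajor, 8)
B = TensorDescription(cutlass.float16, cutlass.ColumnMajor, 8)
C = TensorDescription(cutlass.float32, cutlass.RowMajor, 4)
```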
### Epilogue Functor
The epilogue functor defines the epilogue executed after the mainloop.
We expose the following epilogue functors:
| Epilogue Functor | Remark |
| -- | -- |
| LinearCombination | $D=\alpha \times Accum + \beta \times C$ |
| LinearCombinationClamp | $D=\alpha \times Accum + \beta \times C$, output is clamped to the numeric range of the output data type |
| FastLinearCombinationClamp | $D=\alpha \times Accum + \beta \times C$, only used for problem sizes $K\le 256$ with cutlass.int8, accumulator data type `cutlass.int32`, and epilogue compute data type `cutlass.float32` |
| LinearCombinationGeneric | $D = activation(\alpha \times Accum + \beta \times C)$, available activations include `relu`, `leaky_relu`, `tanh`, `sigmoid`, `silu`, `hardswish`, and `gelu` |
The epilogue functors can be created as follows
```python
# LinearCombination
epilogue_functor = LinearCombination(
element_C, alignment_c, element_acc, element_epilogue_compute
)
# LinearCombinationClamp
epilogue_functor = LinearCombinationClamp(
element_C, alignment_c, element_acc, element_epilogue_compute
)
# FastLinearCombinationClamp
epilogue_functor = FastLinearCombinationClamp(
element_C, alignment_c
)
# LinearCombinationGeneric
epilogue_functor = LinearCombinationGeneric(
relu(element_epilogue_compute), element_C, alignment_c,
element_acc, element_epilogue_compute
)
```
We also provide an experimental "Epilogue Visitor Tree" feature for GEMM operations. The details can be found in [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md).
### GEMM Operation
The GEMM Operation description can be created with
```python
operation = GemmOperationUniversal(
{compute_capability}, tile_description,
A, B, C, epilogue_functor,
{swizzling_functor}, {visitor}
)
```
* `{compute_capability}` is an integer indicating the compute capability of the GPU. For A100, it is 80.
* `{swizzling_functor}` describes how threadblocks are scheduled on the GPU. This is used to improve L2 locality ([detail](https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/)). Currently we support `cutlass.{IdentitySwizzle1|IdentitySwizzle2|IdentitySwizzle4|IdentitySwizzle8|BatchedIdentitySwizzle}`; the last one is used for batched or array GEMM.
* `{visitor}`: a bool indicating whether the epilogue visitor tree is used. A complete sketch is shown below.
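Assembling the pieces above, a minimal SM80 GEMM operation might be described as follows; the keyword names mirror the GEMM example in the PyCUTLASS README, and the epilogue functor shown here is just one choice.
```python
# Sketch: SM80 GEMM operation built from the descriptions above
epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)

operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1
)
```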
### GEMM Grouped Operation
The GEMM Grouped Operation description can be created with
```python
operation = GemmOperationGrouped(
compute_capability, tile_description,
A, B, C, epilogue_functor,
swizzling_functor, {precompute_mode}
)
```
* `{precompute_mode}`: It could be `SchedulerMode.Host` or `SchedulerMode.Device`. See [examples/24_gemm_grouped](examples/24_gemm_grouped) for more details.
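A sketch following the positional template above; the device-side scheduler is only an example choice, and the argument order is assumed to match the template as written.
```python
# Sketch: grouped GEMM operation with device-side problem scheduling
operation = GemmOperationGrouped(
    80, tile_description,
    A, B, C, epilogue_functor,
    cutlass.IdentitySwizzle1, SchedulerMode.Device
)
```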
### Conv2d Operation
The Conv2d Operation description can be created with
```python
operation = Conv2dOperation(
{conv_kind}, {iterator_algorithm},
compute_capability, tile_description,
A, B, C, {stride_support},
epilogue_functor, swizzling_functor
)
```
* `{conv_kind}` defines which convolution is executed. Available options include `fprop`, `dgrad`, and `wgrad`.
* `{iterator_algorithm}` specifies the iterator algorithm used by the implicit GEMM in convolution. The options are as follows:
* `analytic`: functionally correct in all cases but lower performance
* `optimized`: optimized for R <= 32, S <= 32 and unity-stride dgrad
* `fixed_channels`: analytic algorithm optimized for fixed channel count (C == AccessSize)
* `few_channels`: Analytic algorithm optimized for few channels (C divisible by AccessSize)
* `{stride_support}`: distinguishes among partial specializations that accelerate certain problems where the convolution stride is unity.
* `strided`: arbitrary convolution stride
* `unity`: unit convolution stride
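Putting these together, a sketch of a forward-propagation convolution on SM80 follows; it mirrors the Conv2d example elsewhere in this commit and assumes the tensor descriptions use a `cutlass.TensorNHWC` layout.
```python
# Sketch: SM80 fprop convolution with the optimized iterator algorithm
operation = Conv2dOperation(
    conv_kind=cutlass.conv.Operator.fprop,
    iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C,
    stride_support=StrideSupport.Strided,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1
)
```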
***
## Code Emission and Compilation
After implementing the operation description, the related host and device code can be compiled with
```python
import pycutlass
pycutlass.compiler.add_module([operation,])
```
Several operations can be compiled together. The `nvcc` at `$CUDA_INSTALL_PATH/bin` is used as the default compiler backend, but you can also switch to [CUDA Python](https://nvidia.github.io/cuda-python/overview.html)'s `nvrtc` with
```python
pycutlass.compiler.nvrtc()
```
We also have an internal compiled artifact manager that caches the compiled kernels both in memory and on disk. The `compiled_cache.db` file in your workspace is the database that contains the binary files; you can delete it if you want to recompile the kernels.
***
## Argument Processing
We provide argument wrappers to convert Python tensors into kernel parameters. Currently [torch.Tensor](https://pytorch.org/), [numpy.ndarray](https://numpy.org/), and [cupy.ndarray](https://cupy.dev/) are supported.
### GEMM Arguments
The Gemm arguments can be created with
```python
arguments = GemmArguments(
operation=operation, problem_size={problem_size},
A={tensor_A}, B={tensor_B}, C={tensor_C}, D={tensor_D},
output_op={output_op},
gemm_mode={gemm_mode},
split_k_slices={split_k_slices}, batch={batch}
)
```
* `problem_size` is a `cutlass.gemm.GemmCoord(M, N, K)` object that defines $M\times N\times K$ matrix multiplication.
* `tensor_X`: user-provided tensors.
* `output_op`: the params for the epilogue functor.
* `gemm_mode`, `split_k_slices`, and `batch`:
|gemm_mode| split_k_slices | batch | remark|
|--|--|--|--|
|cutlass.gemm.Mode.Gemm | number of split-K slices | - | the ordinary GEMM or GEMM with serial split-K|
|cutlass.gemm.Mode.GemmSplitKParallel | number of split-K slices | - | GEMM Split-K Parallel|
|cutlass.gemm.Mode.Batched | - | batch size | Batched GEMM |
|cutlass.gemm.Mode.Array | - | batch size | Array GEMM |
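For an ordinary (non-split-K) GEMM, the arguments can be assembled as sketched below; the problem size and the tensors `tensor_A` through `tensor_D` are placeholders assumed to have been created by the caller.
```python
# Sketch: ordinary GEMM, alpha=1, beta=0, no split-K
arguments = GemmArguments(
    operation=operation, problem_size=cutlass.gemm.GemmCoord(512, 256, 128),
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=operation.epilogue_type(1.0, 0.0),
    gemm_mode=cutlass.gemm.Mode.Gemm, split_k_slices=1
)
```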
### GEMM Grouped Arguments
The GEMM grouped arguments can be created with
```python
arguments = GemmGroupedArguments(
    operation, {problem_sizes_coord}, {tensor_As}, {tensor_Bs}, {tensor_Cs}, {tensor_Ds},
    output_op=output_op
)
```
* `problem_sizes_coord` is a list of `cutlass.gemm.GemmCoord(M, N, K)`, one per problem.
* `tensor_Xs` is a list of user-provided tensors.
* `output_op`: the params of the epilogue functor
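A sketch with two problems; the per-problem tensor lists are placeholders assumed to be prepared by the caller.
```python
# Sketch: two grouped GEMM problems sharing one operation
problem_sizes_coord = [
    cutlass.gemm.GemmCoord(256, 128, 64),
    cutlass.gemm.GemmCoord(512, 256, 128),
]
arguments = GemmGroupedArguments(
    operation, problem_sizes_coord,
    tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
    output_op=operation.epilogue_type(1.0, 0.0)
)
```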
### Conv2d Arguments
The Conv2d arguments can be created with
```python
arguments = Conv2dArguments(
operation, {problem_size}, {tensor_A},
{tensor_B}, {tensor_C}, {tensor_D},
{output_op},
{split_k_mode},
{split_k_slices}
)
```
* `problem_size`: it can be constructed with
```python
problem_size = cutlass.conv.Conv2dProblemSize(
cutlass.Tensor4DCoord(N, H, W, C),
cutlass.Tensor4DCoord(K, R, S, C),
cutlass.Tensor4DCoord(pad[0], pad[1], pad[2], pad[3]),
cutlass.MatrixCoord(stride[0], stride[1]),
cutlass.MatrixCoord(dilation[0], dilation[1]),
cutlass.conv.Mode.cross_correlation,
split_k_slices, 1
)
```
* `tensor_X`: user-provided tensors
* `output_op`: the params of the epilogue functor
* `split_k_mode`: currently we support `cutlass.conv.SplitKMode.Serial` and `cutlass.conv.SplitKMode.Parallel`.
* `split_k_slices`: number of split-K slices
For an ordinary conv2d, just use `cutlass.conv.SplitKMode.Serial` with `split_k_slices=1`.
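A sketch of a serial (non-split-K) launch using the `problem_size` constructed above; the keyword names follow the parameter list in the template and are assumptions, and the tensors are placeholders.
```python
# Sketch: serial conv2d arguments, alpha=1, beta=0
arguments = Conv2dArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=operation.epilogue_type(1.0, 0.0),
    split_k_mode=cutlass.conv.SplitKMode.Serial,
    split_k_slices=1
)
```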
### Getting output_op
The way to create the `output_op` is shown below:
```python
output_op = operation.epilogue_type(*([alpha, beta] + args.activation_args))
```
It is a list of arguments starting with the scaling factors `alpha` and `beta`.
The `output_op` of EpilogueVisitorTree is slightly different. Please check [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md) for details.
## Kernel Launching
With the arguments and operations, the kernel can be launched simply with
```python
operation.run(arguments)
```
## Sync results
We also provide a function to synchronize the kernel execution. If you use `numpy`, it also copies the results back to the host. To do so, run
```python
arguments.sync()
```
If you use EpilogueVisitorTree, please call
```python
output_op.sync()
```
## Reduction Kernel behind Parallel Split-K
If you use parallel split-K in GEMM or Conv2d, an additional reduction kernel is required. Please check [examples/40_cutlass_py](examples/40_cutlass_py) for details.

View File

@ -1,6 +0,0 @@
Types
========
.. autoenum:: pycutlass.OperationKind
:members:

View File

@ -0,0 +1,4 @@
User Guide
=====================================
.. mdinclude:: ./md/basic_idea.md

View File

@ -0,0 +1,4 @@
User Guide
=====================================
.. mdinclude:: ./md/EpilogueVisitorTree.md

View File

@ -32,6 +32,7 @@
from pycutlass import *
import pycutlass
from pycutlass.epilogue import LinearCombination
from pycutlass.test.conv2d_testbed import Conv2dLauncher
@ -62,15 +63,16 @@ if __name__ == "__main__":
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=4,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -49,7 +49,7 @@ if __name__ == '__main__':
tile_description = TileDescription(
threadblock_shape=[256, 128, 32],
stages=3, warp_count=[4, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -67,7 +67,7 @@ if __name__ == '__main__':
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)
swizzling_functor = cutlass.IdentitySwizzle1

View File

@ -43,7 +43,7 @@ try:
Pybind11Extension("cutlass",
["src/cpp/cutlass.cpp"],
include_dirs=include_dirs,
extra_compile_args=["-fpermissive"])
extra_compile_args=["-fpermissive", "-w"])
]
except ImportError:
pass
@ -69,7 +69,8 @@ setup(
'typeguard',
'bfloat16',
'typing',
'scikit-build'
'scikit-build',
'treelib'
],
cmdclass={
'rmm': BuildRMM

View File

@ -0,0 +1,225 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A file containing the epilogue visitor with CTA row-wise broadcast
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "epilogue_visitor_op/visitor_op_linear_combination.h"
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
#include "epilogue_visitor_op/visitor_op_accumulator.h"
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
#include "epilogue_visitor_op/visitor_op_unary.h"
#include "epilogue_visitor_op/visitor_op_binary.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic Epilogue Visitor.
template <
typename OutputOp_
>
class EpilogueVisitorGeneric {
public:
using OutputOp = OutputOp_;
using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
using ElementOutput = typename OutputOp::ElementOutput;
using OutputTileIterator = typename OutputOp::OutputTileIterator;
static int const kIterations = OutputTileIterator::kIterations;
///
/// End Epilogue Tree
///
/// Additional SMEM buffer is not required in the broadcast epilogue visitor
struct SharedStorage {
typename OutputOp::SharedStorage output_smem;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
public:
/// Argument structure
struct Arguments {
typename OutputOp::Arguments output_op_args;
//
// Methods
//
Arguments() { }
Arguments(
typename OutputOp::Arguments output_op_args
):
output_op_args(output_op_args)
{
}
};
struct Params {
typename OutputOp::Params output_op_params;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
output_op_params(args.output_op_args)
{
}
};
private:
OutputOp output_op;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueVisitorGeneric(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord threadblock_offset,
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
MatrixCoord problem_size
):
output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
{ }
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
output_op.set_batch_index(batch_idx);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
output_op.begin_epilogue();
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
output_op.begin_step(step_idx);
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
output_op.begin_row(row_idx);
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum) {
output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
}
/// Called at the start of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
output_op.end_row(row_idx);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
output_op.end_step(step_idx);
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
output_op.end_epilogue();
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,84 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A file containing the binary ops
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Elementwise vector addition
template <typename T, int N>
struct VectorAdd {
struct Arguments {
int tmp;
CUTLASS_HOST_DEVICE
Arguments():tmp(0){ }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
struct Params {
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
CUTLASS_HOST_DEVICE
VectorAdd(
Params const &params
) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
cutlass::plus<Array<T, N>> add_op;
return add_op(lhs, rhs);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,233 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the unary ops
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Scalar multiplication
template <typename T, int N>
struct Mult {
struct Arguments {
T alpha;
CUTLASS_HOST_DEVICE
Arguments():alpha(T(1.0)){ }
CUTLASS_HOST_DEVICE
Arguments(T alpha): alpha(alpha) { }
};
struct Params {
T alpha; ///< scales accumulators
CUTLASS_HOST_DEVICE
Params():alpha(T(1.0)){ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): alpha(args.alpha) { }
};
T alpha_;
CUTLASS_HOST_DEVICE
Mult(
Params const &params
):
alpha_(params.alpha)
{ }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &source) const {
cutlass::multiplies<Array<T, N>> multiply_op;
return multiply_op(source, alpha_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return alpha_ != T(0);
}
};
/// ReLU
template <typename T, int N>
struct ReLUVisitor {
struct Arguments {
T threshold;
CUTLASS_HOST_DEVICE
Arguments():threshold(T(0.0)) { }
CUTLASS_HOST_DEVICE
Arguments(T threshold): threshold(threshold) { }
};
struct Params {
T threshold;
CUTLASS_HOST_DEVICE
Params():threshold(T(0.0)) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): threshold(args.threshold) { }
};
T threshold_;
CUTLASS_HOST_DEVICE
ReLUVisitor(Params const &params):
threshold_(params.threshold) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
maximum<Array<T, N>> mx;
return mx(frag, threshold_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/// leakyReLU
template <typename T, int N>
struct LeakyReLUVisitor {
struct Arguments {
T leaky_alpha;
CUTLASS_HOST_DEVICE
Arguments():leaky_alpha(T(0.0)) { }
CUTLASS_HOST_DEVICE
Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
};
struct Params {
T leaky_alpha;
CUTLASS_HOST_DEVICE
Params():leaky_alpha(T(0.0)) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
};
T leaky_alpha_;
CUTLASS_HOST_DEVICE
LeakyReLUVisitor(Params const &params):
leaky_alpha_(params.leaky_alpha) { }
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
return leaky_op(frag, leaky_alpha_);
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/// Tanh
template <typename T, int N>
struct TanhVisitor {
/// Argument
struct Arguments {
// a placeholder argument to ensure correctness of ctypes
int tmp;
CUTLASS_HOST_DEVICE
Arguments(): tmp(0) { };
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { };
};
/// Param
struct Params {
CUTLASS_HOST_DEVICE
Params(){ };
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
/// Constructor
CUTLASS_HOST_DEVICE
TanhVisitor(Params const &params) { }
// scalar operator
CUTLASS_HOST_DEVICE
T tanh_op(T const &scalar) const {
return fast_tanh(scalar);
}
/// vector operator
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &frag) const {
Array<T, N> y;
CUTLASS_PRAGMA_UNROLL
for (int i=0; i < N; ++i) {
y[i] = tanh_op(frag[i]);
}
return y;
}
CUTLASS_HOST_DEVICE
bool guard() {
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
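A similarly hedged sketch for the unary functors, assuming this header and cutlass/array.h are included; float and a width of 4 are again arbitrary.

// Hypothetical illustration only: apply the ReLU functor above to a fragment.
inline cutlass::Array<float, 4> relu_example(cutlass::Array<float, 4> const &frag) {
  using ReLU = cutlass::ReLUVisitor<float, 4>;
  ReLU relu{ReLU::Params(ReLU::Arguments(0.0f))};   // threshold 0 gives the standard ReLU
  return relu(frag);                                // per-element max(frag[i], 0)
}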

View File

@ -0,0 +1,148 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op for the accumulator
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following Computation
///
/// ElementAccumulator accum;
/// return accum;
///
/// It can only be the leaf node of the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
int kElementsPerAccess_ ///< Number of elements computed per operation
>
class VisitorOpAccumulator{
public:
using ElementAccumulator = ElementAccumulator_;
static int const kElementsPerAccess = kElementsPerAccess_;
/// Fragment type for Accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type returned by this visitor
using VisitAccessType = AccumulatorAccessType;
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
// Note: ctypes has trouble with an empty Arguments structure, so a placeholder member is kept
int tmp;
CUTLASS_HOST_DEVICE
Arguments(): tmp(0) { }
CUTLASS_HOST_DEVICE
Arguments(int tmp): tmp(tmp) { }
};
/// Parameter structure
struct Params {
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args) { }
};
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpAccumulator(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) { }
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) { }
CUTLASS_DEVICE
void begin_row(int row_idx) { }
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
return accum;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
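For illustration only, the leaf above might be instantiated as follows; the accumulator type and access width are assumptions rather than values mandated by the file.

// Hypothetical alias: a leaf that forwards a 4-wide float accumulator fragment unchanged.
using AccumLeaf = cutlass::epilogue::threadblock::VisitorOpAccumulator<
    float,  // ElementAccumulator
    4>;     // kElementsPerAccess
// AccumLeaf::VisitAccessType is Array<float, 4>; visit(...) simply returns the accumulator fragment.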

View File

@ -0,0 +1,246 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that applies a binary op to two child visitors
*/
#pragma once
#include "cutlass/cutlass.h"
#include "binary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute C = BinaryOp(ElementCompute(Visitor_A), ElementCompute(Visitor_B));
/// Return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename VisitorA_, ///< Child node A
typename VisitorB_, ///< Child node B
template<typename T, int N> typename BinaryOp_
>
class VisitorOpBinary{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using VisitorA = VisitorA_;
using VisitorB = VisitorB_;
/// Fragment type returned from VisitorA.visit
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
using ElementA = typename VisitAccessTypeA::Element;
/// Fragment type returned from VisitorB.visit
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
using ElementB = typename VisitAccessTypeB::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Combination Op TODO: generalize this
using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename VisitorA::SharedStorage storage_a;
typename VisitorB::SharedStorage storage_b;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
typename BinaryOp::Arguments binary_arg;
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():binary_arg() { }
CUTLASS_HOST_DEVICE
Arguments(
typename BinaryOp::Arguments binary_arg,
typename VisitorA::Arguments visitor_a_arg,
typename VisitorB::Arguments visitor_b_arg
):
binary_arg(binary_arg),
visitor_a_arg(visitor_a_arg),
visitor_b_arg(visitor_b_arg)
{ }
};
/// Parameter structure
struct Params {
typename BinaryOp::Params binary_param;
typename VisitorA::Params visitor_a_param; ///< Param type for visitor_a
typename VisitorB::Params visitor_b_param; ///< Param type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
binary_param(args.binary_arg),
visitor_a_param(args.visitor_a_arg),
visitor_b_param(args.visitor_b_arg)
{ }
};
private:
//
// Data members
//
BinaryOp binary_op;
VisitorA visitor_a_op;
VisitorB visitor_b_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpBinary(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
binary_op(params.binary_param),
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void begin_epilogue() {
visitor_a_op.begin_epilogue();
visitor_b_op.begin_epilogue();
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
visitor_a_op.set_batch_index(batch_idx);
visitor_b_op.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_a_op.begin_step(step_idx);
visitor_b_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_a_op.begin_row(row_idx);
visitor_b_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor A and visitor B
VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
/// Type conversion
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
return binary_op(
source_converter_A(result_A),
source_converter_B(result_B)
);
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_a_op.end_row(row_idx);
visitor_b_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_a_op.end_step(step_idx);
visitor_b_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_a_op.end_epilogue();
visitor_b_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
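A hedged composition sketch follows, assuming both this header and the accumulator visitor header above are included; the element types, access width, and tree shape are illustrative choices, not prescribed usage.

namespace tb = cutlass::epilogue::threadblock;

// Hypothetical tree: D = accum + accum, a binary node over two accumulator leaves.
using AccumLeaf = tb::VisitorOpAccumulator<float, 4>;
using SumNode = tb::VisitorOpBinary<
    float,               // ElementAccumulator
    float,               // ElementCompute
    4,                   // kElementsPerAccess
    AccumLeaf,           // VisitorA
    AccumLeaf,           // VisitorB
    cutlass::VectorAdd>; // BinaryOp from binary_ops.h

// Host-side arguments mirror the tree: the binary op's arguments first, then one per child.
inline SumNode::Arguments make_sum_args() {
  return SumNode::Arguments(
      cutlass::VectorAdd<float, 4>::Arguments(0),
      AccumLeaf::Arguments(0),
      AccumLeaf::Arguments(0));
}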

View File

@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that broadcasts a vector to all columns
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementVector T[i][j] <- device-memory Td[i]
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementFragment_, ///< Data type used to cache vector in register
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpColumnBroadcast {
public:
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementAccumulator = ElementAccumulator_;
using ElementVector = typename InputTileIterator::Element;
using ElementFragment = ElementFragment_;
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
/// Thread map used by input tile iterators
using ThreadMap = typename InputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementFragment, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
// /// Number of iterations (accesses) the threadblock takes to reduce a row
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
};
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Methods
CUTLASS_HOST_DEVICE
Arguments():
broadcast_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementVector *broadcast_ptr,
int64_t batch_stride
):
broadcast_ptr(broadcast_ptr),
batch_stride(batch_stride) { }
};
/// Param structure
struct Params {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
broadcast_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
broadcast_ptr(args.broadcast_ptr),
batch_stride(args.batch_stride) { }
};
private:
ElementVector *broadcast_ptr;
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
MatrixCoord threadblock_offset_;
int thread_idx_;
MatrixCoord problem_size;
int thread_start_row_;
int state_[3];
int thread_offset_row_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpColumnBroadcast(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
broadcast_ptr(params.broadcast_ptr),
threadblock_offset_(threadblock_offset),
thread_idx_(thread_idx),
problem_size(problem_size),
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
batch_stride_(params.batch_stride)
{
state_[0] = state_[1] = state_[2] = 0;
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
broadcast_ptr += batch_idx * batch_stride_;
}
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) {}
CUTLASS_DEVICE
void begin_row(int row_idx) {}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
// get pointer
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));
broadcast_fragment.fill(broadcast_data);
return broadcast_fragment;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) {
// run operator ++
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
}
}
}
}
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
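An illustrative alias for this node is sketched below; the concrete tile iterator depends on the surrounding epilogue, so it is deliberately left as a template parameter here.

// Hypothetical alias: broadcast a per-row vector (one value per output row) across all columns.
template <typename InputTileIterator>
using ColumnBroadcastNode = cutlass::epilogue::threadblock::VisitorOpColumnBroadcast<
    float,              // ElementAccumulator
    float,              // ElementFragment cached in registers
    InputTileIterator>; // iterator type describing the vector being broadcast
// ColumnBroadcastNode<...>::Arguments is (ElementVector *broadcast_ptr, int64_t batch_stride).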

View File

@ -0,0 +1,342 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that reduces over columns within the CTA
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
/// device memory <- ElementReduction(R[j])
///
template <
typename ThreadblockShape_, /// Threadblock shape
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementReduction_, ///< Data type of the output reduction in device memory
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
typename OutputTileIterator_, ///< Tile Iterator type
typename Visitor_ ///< preceding visitor op
>
class VisitorOpColumnReduction {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementReductionAccumulator = ElementReductionAccumulator_;
using ElementReduction = ElementReduction_;
using OutputTileIterator = OutputTileIterator_;
using ThreadblockShape = ThreadblockShape_;
using Visitor = Visitor_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
// TODO: generalize the reduction op
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
using ElementOutput = typename OutputTileIterator::Element;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of reduction
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Used for the reduction
struct ReductionDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
static int const kThreadRows = kThreadCount / kThreadsPerRow;
/// Number of iterations (accesses) the threadblock takes to reduce a row
static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);
using StorageShape = MatrixShape<
kThreadRows,
ThreadblockShape::kN
>;
};
using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;
/// Shared storage
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Argument structure
struct Arguments {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Arguments(): reduction_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementReduction *reduction_ptr,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
reduction_ptr(reduction_ptr),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Params visitor_param; ///< Param type of visitor
/// Method
CUTLASS_HOST_DEVICE
Params(): reduction_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
reduction_ptr(args.reduction_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory
ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction
Visitor visitor_; ///< visitor
int thread_idx_;
MatrixCoord threadblock_offset;
MatrixCoord problem_size_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpColumnReduction(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor,
thread_idx, threadblock_offset, problem_size),
reduction_smem_ptr_(shared_storage.reduction.data()),
reduction_output_ptr_(params.reduction_ptr),
thread_idx_(thread_idx),
threadblock_offset(threadblock_offset),
problem_size_(problem_size),
batch_stride_(params.batch_stride)
{ }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
reduction_output_ptr_ += batch_idx * batch_stride_;
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
// clear the reduction fragment
reduction_fragment.clear();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
ReductionOp reduction_op;
ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast<ReductionAccumulatorAccessType*>(&reduction_fragment);
reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result));
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
//
// Store the partially reduced value to SMEM
//
// Guard against uses of the existing SMEM tile
__syncthreads();
using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;
//
// Determine a compact thread arrangement to store to SMEM
//
MatrixCoord thread_offset(
thread_idx_ / ReductionDetail::kThreadsPerRow,
(thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
);
//
// Each thread stores its fragment to SMEM
//
AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
&reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
&reduction_fragment
);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;
aligned_reduction_ptr[col_idx] = frag_ptr[column];
}
__syncthreads();
//
// Now, threads are assigned several columns of the output. They fetch over all rows from
// the compacted SMEM tile and perform a reduction.
//
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;
ReductionOpScalar reduction_op;
ElementReductionAccumulator reduction_element = ElementReductionAccumulator();
int output_column_idx = threadblock_offset.column() + column_idx;
if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
if (row) {
auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
reduction_element = reduction_op(reduction_element, frag);
}
else {
reduction_element = reduction_smem_ptr_[column_idx];
}
}
// Store
reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
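Under the same caveat, a sketch of aliasing the column-reduction node; the threadblock shape is an example value, and the iterator and child visitor are placeholders supplied by the concrete epilogue.

#include "cutlass/gemm/gemm.h"  // cutlass::gemm::GemmShape

// Hypothetical alias: sum each output column and write one value per column to device memory.
template <typename OutputTileIterator, typename ChildVisitor>
using ColumnSumNode = cutlass::epilogue::threadblock::VisitorOpColumnReduction<
    cutlass::gemm::GemmShape<128, 128, 32>,  // ThreadblockShape (example)
    float,                                   // ElementAccumulator
    float,                                   // ElementReduction in device memory
    float,                                   // ElementReductionAccumulator
    OutputTileIterator,
    ChildVisitor>;
// ColumnSumNode<...>::Arguments is (ElementReduction *reduction_ptr, int64_t batch_stride, ChildVisitor::Arguments).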

View File

@ -0,0 +1,266 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that performs a linear combination
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute alpha;
/// ElementCompute beta;
/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
/// Return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename VisitorA_, ///< Child node A
typename VisitorB_ ///< Child node B
>
class VisitorOpLinearCombination{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using VisitorA = VisitorA_;
using VisitorB = VisitorB_;
/// Fragment type returned from VisitorA.visit
using VisitAccessTypeA = typename VisitorA::VisitAccessType;
using ElementA = typename VisitAccessTypeA::Element;
/// Fragment type returned from VisitorB.visit
using VisitAccessTypeB = typename VisitorB::VisitAccessType;
using ElementB = typename VisitAccessTypeB::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Combination Op TODO: generalize this
using CombinationOp = cutlass::plus<VisitAccessType>;
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename VisitorA::SharedStorage storage_a;
typename VisitorB::SharedStorage storage_b;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a
typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():
alpha(ElementCompute(1)),
beta(ElementCompute(0))
{ }
CUTLASS_HOST_DEVICE
Arguments(
ElementCompute alpha,
ElementCompute beta,
typename VisitorA::Arguments visitor_a_arg,
typename VisitorB::Arguments visitor_b_arg
):
alpha(alpha),
beta(beta),
visitor_a_arg(visitor_a_arg),
visitor_b_arg(visitor_b_arg)
{ }
};
/// Parameter structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
typename VisitorA::Params visitor_a_param; ///< Param type for visitor_a
typename VisitorB::Params visitor_b_param; ///< Param type for visitor_b
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
alpha(args.alpha),
beta(args.beta),
visitor_a_param(args.visitor_a_arg),
visitor_b_param(args.visitor_b_arg)
{ }
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
VisitorA visitor_a_op;
VisitorB visitor_b_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpLinearCombination(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
alpha_(params.alpha),
beta_(params.beta),
visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void begin_epilogue() {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor A and visitor B
VisitAccessTypeA result_A;
VisitAccessTypeB result_B;
if (alpha_ != ElementCompute(0)) {
result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
// Fill the result A with zeros
result_A.clear();
}
if (beta_ != ElementCompute(0)) {
result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
// Fill the result B with zeros
result_B.clear();
}
/// Type conversion
NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;
CombinationOp combination_op;
cutlass::multiplies<VisitAccessType> multiply_op;
return combination_op(
multiply_op(alpha_, source_converter_A(result_A)),
multiply_op(beta_, source_converter_B(result_B))
);
}
CUTLASS_DEVICE
void end_row(int row_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
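A minimal self-contained sketch of an alpha/beta combination built from two accumulator leaves; in practice the second child would usually visit a source tensor, so this is illustrative only and assumes the accumulator visitor header above is included.

namespace tb = cutlass::epilogue::threadblock;

// Hypothetical tree: D = alpha * accum + beta * accum.
using Leaf = tb::VisitorOpAccumulator<float, 4>;
using LinCombNode = tb::VisitorOpLinearCombination<float, float, 4, Leaf, Leaf>;

inline LinCombNode::Arguments make_lincomb_args(float alpha, float beta) {
  // Scalars first, then one Arguments object per child, mirroring the tree.
  return LinCombNode::Arguments(alpha, beta, Leaf::Arguments(0), Leaf::Arguments(0));
}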

View File

@ -0,0 +1,258 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that broadcasts a vector to all rows
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementVector T[i][j] <- device-memory Td[j]
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementFragment_, ///< Data type used to cache vector in register
typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpRowBroadcast {
public:
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementAccumulator = ElementAccumulator_;
using ElementVector = typename InputTileIterator::Element;
using ElementFragment = ElementFragment_;
using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;
/// Thread map used by input tile iterators
using ThreadMap = typename InputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementFragment,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
// /// Number of iterations (accesses) the threadblock takes to reduce a row
// static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
};
// using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Methods
CUTLASS_HOST_DEVICE
Arguments():
broadcast_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementVector *broadcast_ptr,
int64_t batch_stride
):
broadcast_ptr(broadcast_ptr),
batch_stride(batch_stride) { }
};
/// Param structure
struct Params {
ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
broadcast_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
broadcast_ptr(args.broadcast_ptr),
batch_stride(args.batch_stride) { }
};
private:
ElementVector *broadcast_ptr;
BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment
MatrixCoord threadblock_offset_;
int thread_idx_;
MatrixCoord problem_size;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpRowBroadcast(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
threadblock_offset_(threadblock_offset),
thread_idx_(thread_idx),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
broadcast_ptr += batch_idx * batch_stride_;
}
CUTLASS_DEVICE
void begin_epilogue() {
// load broadcast fragment
load_broadcast_fragment_();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {}
CUTLASS_DEVICE
void begin_row(int row_idx) {}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
VisitAccessType* broadcast_fragment_ = reinterpret_cast<VisitAccessType*>(&broadcast_fragment);
return broadcast_fragment_[column_idx];
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
private:
CUTLASS_DEVICE
void load_broadcast_fragment_() {
broadcast_fragment.clear();
// If no pointer is supplied, set with all zeros and avoid memory accesses
if (!broadcast_ptr) {
return;
}
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
broadcast_ptr += thread_initial_column;
NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;
AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
AccessType loaded;
loaded.clear();
if (thread_column_idx < problem_size.column()) {
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
}
AccessFragmentType cvt = converter(loaded);
frag_ptr[j] = cvt;
thread_column_idx += ThreadMap::Delta::kColumn;
broadcast_ptr += ThreadMap::Delta::kColumn;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
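As with the column-broadcast node, an illustrative alias with the tile iterator left open; the element types are assumptions.

// Hypothetical alias: broadcast a row vector (e.g. a per-column bias) to all rows.
template <typename InputTileIterator>
using RowBroadcastNode = cutlass::epilogue::threadblock::VisitorOpRowBroadcast<
    float,              // ElementAccumulator
    float,              // ElementFragment cached in registers
    InputTileIterator>; // iterator type describing the vector being broadcast
// RowBroadcastNode<...>::Arguments is (ElementVector *broadcast_ptr, int64_t batch_stride).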

View File

@ -0,0 +1,320 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor op that reduces over rows within the CTA
*/
#pragma once
#include "cutlass/cutlass.h"
#include "stdio.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementReductionAccumulator R[i] = \sum_j ElementReductionAccumulator(T[i][j])
/// device memory <- ElementReduction(R[i])
///
template <
typename ThreadblockShape_, /// Threadblock shape
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementReduction_, ///< Data type of the output reduction in device memory
typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register
typename OutputTileIterator_, ///< Tile Iterator type
typename Visitor_ ///< preceding visitor op
>
class VisitorOpRowReduction {
public:
using ElementAccumulator = ElementAccumulator_;
using ElementReductionAccumulator = ElementReductionAccumulator_;
using ElementReduction = ElementReduction_;
using OutputTileIterator = OutputTileIterator_;
using ThreadblockShape = ThreadblockShape_;
using Visitor = Visitor_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
// TODO: generalize the reduction op
using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
using ElementOutput = typename OutputTileIterator::Element;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of reduction
using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Used for the reduction
struct ReductionDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = ThreadMap::kThreads;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;
/// Half the number of threads per row, used for the cross-thread shuffle reduction
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
static int const kThreadRows = kThreadCount / kThreadsPerRow;
};
/// Shared storage
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Method
CUTLASS_HOST_DEVICE
Arguments(): reduction_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementReduction *reduction_ptr,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
reduction_ptr(reduction_ptr),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory
int64_t batch_stride;
typename Visitor::Params visitor_param; ///< Param type of visitor
/// Method
CUTLASS_HOST_DEVICE
Params(): reduction_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
reduction_ptr(args.reduction_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory
ElementReductionAccumulator reduction_accum;
Visitor visitor_; ///< visitor
int thread_idx_;
MatrixCoord threadblock_offset;
MatrixCoord problem_size_;
int thread_start_row_; /// used to identify the starting row handled by this thread
int state_[3]; /// used to track the row iterator state across steps
int thread_offset_row_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpRowReduction(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor,
thread_idx, threadblock_offset, problem_size),
reduction_output_ptr_(params.reduction_ptr),
thread_idx_(thread_idx),
threadblock_offset(threadblock_offset),
problem_size_(problem_size),
thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
batch_stride_(params.batch_stride)
{
state_[0] = state_[1] = state_[2] = 0;
}
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
reduction_output_ptr_ += batch_idx * batch_stride_;
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
reduction_accum = ElementReductionAccumulator(0);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();
ReductionOpScalar reduction_op;
ElementReductionAccumulator reduction_accum_ = reduction(result);
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
CUTLASS_PRAGMA_UNROLL
for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
}
reduction_accum = reduction_op(reduction_accum, reduction_accum_);
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();
ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;
arch::global_store<ElementReduction, sizeof(ElementReduction)>(
output_converter(reduction_accum),
(void *)curr_ptr_reduction,
is_write_thread);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
// run operator ++
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
}
}
}
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
}
private:
CUTLASS_DEVICE
ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) {
ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);
ReductionOpScalar reduction_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
sum_ = reduction_op(sum_, result[i]);
}
return sum_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
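/////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (not part of this commit): the cross-thread step in
// VisitorOpRowReduction::visit() above is a standard XOR-butterfly warp reduction. Each
// iteration folds in the partial sum held by the lane whose id differs by `i`, halving the
// stride until every lane of the row group holds the complete per-row partial; lane 0 of the
// group then stores it in end_row(). Assumes threads_per_row is a power of two and that all
// 32 lanes of the warp participate in the shuffle.
__device__ float reduce_row_partial_sketch(float partial, int threads_per_row) {
  for (int i = threads_per_row >> 1; i > 0; i >>= 1) {
    partial += __shfl_xor_sync(0xFFFFFFFF, partial, i);
  }
  return partial;
}
/////////////////////////////////////////////////////////////////////////////////////////////////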

View File

@ -0,0 +1,188 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with Tensor Input
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementInput C <- device memory
///
/// It can only be a leaf node in the epilogue tree
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename InputTileIterator_ ///< Tile iterator type to read the tensor
>
class VisitorOpTensorInput {
public:
using ElementAccumulator = ElementAccumulator_;
using InputTileIterator = InputTileIterator_;
static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
using ElementInput = typename InputTileIterator::Element;
using VisitAccessType = Array<ElementInput, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
struct SharedStorage {
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementInput *input_ptr; ///< Pointer to the input tensor in device memory
int ldt; ///< Leading dimension of the input tensor operand
int64_t batch_stride; ///< batch stride for batched GEMM
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): input_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementInput *input_ptr,
int ldt, int64_t batch_stride
):
input_ptr(input_ptr),
ldt(ldt),
batch_stride(batch_stride)
{ }
};
/// Param structure
struct Params {
typename InputTileIterator::Params params_input;
ElementInput *input_ptr;
int64_t batch_stride;
/// Method
CUTLASS_HOST_DEVICE
Params():
input_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
params_input(args.ldt),
input_ptr(args.input_ptr),
batch_stride(args.batch_stride)
{ }
};
private:
InputTileIterator iterator_T_;
typename InputTileIterator::Fragment fragment_T_;
MatrixCoord problem_size;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpTensorInput(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
iterator_T_(
InputTileIterator(
params.params_input,
params.input_ptr,
problem_size,
thread_idx,
threadblock_offset
)
),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
}
CUTLASS_DEVICE
void begin_epilogue() { }
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_T_.clear();
iterator_T_.load(fragment_T_);
++iterator_T_;
}
CUTLASS_DEVICE
void begin_row(int row_idx) { }
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
return source;
}
CUTLASS_DEVICE
void end_row(int row_idx) { }
CUTLASS_DEVICE
void end_step(int step_idx) { }
CUTLASS_DEVICE
void end_epilogue() { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,240 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with Tensor Output
*/
#pragma once
#include "cutlass/cutlass.h"
#include "stdio.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementOutput T = ElementOutput(Visitor)
/// T -> device memory
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename OutputTileIterator_, ///< Tile iterator type to write the tensor
typename Visitor_ ///< Child visitor that produces the output tensor
>
class VisitorOpTensorOutput {
public:
using ElementAccumulator = ElementAccumulator_;
using OutputTileIterator = OutputTileIterator_;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using Visitor = Visitor_;
/// Fragment type returned from Visitor
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisitor = typename VisitAccessTypeVisitor::Element;
using VisitAccessType = VisitAccessTypeVisitor;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Fragment type of output
using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;
static_assert(kElementsPerAccess == VisitAccessTypeVisitor::kElements, "kElementsPerAccess does not match Visitor");
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
/// Host-constructable Argument structure
struct Arguments {
ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory
int ldt; ///< Leading dimension of the output tensor operand
int64_t batch_stride; ///< batch stride
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor
/// Methods
CUTLASS_HOST_DEVICE
Arguments(): output_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Arguments(
ElementOutput *output_ptr,
int ldt,
int64_t batch_stride,
typename Visitor::Arguments visitor_arg
):
output_ptr(output_ptr),
ldt(ldt),
batch_stride(batch_stride),
visitor_arg(visitor_arg)
{ }
};
/// Param structure
struct Params {
typename OutputTileIterator::Params params_output;
ElementOutput *output_ptr;
int64_t batch_stride;
typename Visitor::Params visitor_param;
/// Method
CUTLASS_HOST_DEVICE
Params():
output_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
params_output(args.ldt),
output_ptr(args.output_ptr),
batch_stride(args.batch_stride),
visitor_param(args.visitor_arg)
{ }
};
private:
OutputTileIterator iterator_T_;
typename OutputTileIterator::Fragment fragment_T_;
MatrixCoord problem_size;
Visitor visitor_;
int64_t batch_stride_;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpTensorOutput(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
iterator_T_(
OutputTileIterator(
params.params_output,
params.output_ptr,
problem_size,
thread_idx,
threadblock_offset
)
),
problem_size(problem_size),
batch_stride_(params.batch_stride) { }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
visitor_.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
visitor_.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_T_.clear();
visitor_.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
visitor_.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from visitor
VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
// Column guard
MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
bool column_guard = (thread_offset_.column() < problem_size.column());
if (column_guard) {
NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
output = output_converter(result);
}
return result;
}
CUTLASS_DEVICE
void end_row(int row_idx) {
visitor_.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
visitor_.end_step(step_idx);
iterator_T_.store(fragment_T_);
++iterator_T_;
}
CUTLASS_DEVICE
void end_epilogue() {
visitor_.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
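/////////////////////////////////////////////////////////////////////////////////////////////////
// Hedged composition sketch (comments only; the iterator names below are placeholders, not part
// of this commit): visitor ops compose into a tree that the epilogue walks once per accumulator
// fragment. VisitorOpTensorInput is a leaf that loads C, while VisitorOpTensorOutput is an
// interior node that converts and stores whatever its child visitor returns. A minimal
// "stream C through to T" tree could therefore be spelled roughly as
//
//   using LeafC  = VisitorOpTensorInput<float, SomeInputTileIterator>;
//   using StoreT = VisitorOpTensorOutput<float, SomeOutputTileIterator, LeafC>;
//
// with SomeInputTileIterator / SomeOutputTileIterator standing in for the concrete predicated
// tile iterators selected by the kernel configuration.
/////////////////////////////////////////////////////////////////////////////////////////////////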

View File

@ -0,0 +1,226 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains the epilogue visitor Op with a Unary operation
*/
#pragma once
#include "cutlass/cutlass.h"
#include "unary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
/// ElementCompute C = UnaryOp(ElementCompute(Visitor))
/// Return C;
///
template <
typename ElementAccumulator_, ///< Data type of the Accumulator
typename ElementCompute_, ///< Data type used to compute linear combination
int kElementsPerAccess_, ///< Number of elements computed per operation
typename Visitor_, ///< Child node
template<typename T, int N> typename UnaryOp_
>
class VisitorOpUnary{
public:
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = kElementsPerAccess_;
using Visitor = Visitor_;
/// Fragment type returned from Visitor.visit
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
using ElementVisit = typename VisitAccessTypeVisitor::Element;
/// Fragment type returned by this visitor
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
/// Fragment type of accumulator
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
/// Unary Op TODO: generalize this
using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;
static_assert(kElementsPerAccess == VisitAccessTypeVisitor::kElements, "kElementsPerAccess does not match Visitor");
/// SMEM buffer class required in the epilogue visitor
struct SharedStorage {
typename Visitor::SharedStorage storage_visitor;
CUTLASS_HOST_DEVICE
SharedStorage() {}
};
/// Host-constructable Arguments structure
struct Arguments {
typename UnaryOp::Arguments unary_arg;
typename Visitor::Arguments visitor_arg; ///< Argument type for visitor
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments():unary_arg() { }
CUTLASS_HOST_DEVICE
Arguments(
typename UnaryOp::Arguments unary_arg,
typename Visitor::Arguments visitor_arg
):
unary_arg(unary_arg),
visitor_arg(visitor_arg)
{ }
};
/// Parameter structure
struct Params {
typename UnaryOp::Params unary_param;
typename Visitor::Params visitor_param; ///< Argument type for visitor
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():unary_param() { }
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
unary_param(args.unary_arg),
visitor_param(args.visitor_arg)
{ }
};
private:
//
// Data members
//
UnaryOp unary_op;
Visitor visitor_op;
public:
/// Constructs the function object
CUTLASS_HOST_DEVICE
VisitorOpUnary(
Params const &params,
SharedStorage &shared_storage,
int thread_idx,
MatrixCoord threadblock_offset,
MatrixCoord problem_size
):
unary_op(params.unary_param),
visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
{ }
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
visitor_op.set_batch_index(batch_idx);
}
CUTLASS_DEVICE
void begin_epilogue() {
if (unary_op.guard()) visitor_op.begin_epilogue();
}
CUTLASS_DEVICE
void begin_step(int step_idx) {
if (unary_op.guard()) visitor_op.begin_step(step_idx);
}
CUTLASS_DEVICE
void begin_row(int row_idx) {
if (unary_op.guard()) visitor_op.begin_row(row_idx);
}
CUTLASS_DEVICE
VisitAccessType visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorAccessType const &accum
) {
/// Get result from the child visitor
VisitAccessTypeVisitor result;
if (unary_op.guard()){
result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
} else {
result.clear();
}
/// Type conversion
NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;
return unary_op(source_converter(result));
}
CUTLASS_DEVICE
void end_row(int row_idx) {
if (unary_op.guard()) visitor_op.end_row(row_idx);
}
CUTLASS_DEVICE
void end_step(int step_idx) {
if (unary_op.guard()) visitor_op.end_step(step_idx);
}
CUTLASS_DEVICE
void end_epilogue() {
if (unary_op.guard()) visitor_op.end_epilogue();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
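/////////////////////////////////////////////////////////////////////////////////////////////////
// Hedged sketch (hypothetical op; not the contents of unary_ops.h): from the way VisitorOpUnary
// uses its UnaryOp_ parameter above, a conforming op is a template over (element type, vector
// width) that exposes Arguments, Params constructible from Arguments, a constructor from Params,
// a guard() predicate letting the parent skip the whole child subtree, and operator() over one
// fragment. Assumes cutlass/array.h and cutlass/functional.h are in scope.
template <typename T, int N>
struct ScaleOpSketch {
  struct Arguments {
    T scale;
    CUTLASS_HOST_DEVICE Arguments(): scale(T(1)) { }
    CUTLASS_HOST_DEVICE Arguments(T scale_): scale(scale_) { }
  };
  struct Params {
    T scale;
    CUTLASS_HOST_DEVICE Params(): scale(T(1)) { }
    CUTLASS_HOST_DEVICE Params(Arguments const &args): scale(args.scale) { }
  };
  T scale_;
  CUTLASS_HOST_DEVICE ScaleOpSketch(Params const &params): scale_(params.scale) { }
  // When the scale is zero the result is zero regardless of the child subtree, so it may be skipped.
  CUTLASS_HOST_DEVICE bool guard() const { return scale_ != T(0); }
  CUTLASS_HOST_DEVICE cutlass::Array<T, N> operator()(cutlass::Array<T, N> const &frag) const {
    cutlass::multiplies<cutlass::Array<T, N>> mul;
    return mul(frag, scale_);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////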

View File

@ -0,0 +1,481 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains all functioning classes needed by GemmLayernorm.
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
+ lightweight full reduction kernel (ApplyFinalReduction)
+ GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ThreadblockShape_,
int ThreadCount,
typename OutputTileIterator_,
typename AccumulatorTile_,
typename ElementAccumulator_,
typename ElementVariance_,
typename ElementMean_,
typename ElementLayernormCompute_,
typename ElementwiseFunctor_,
bool IsShiftedVariance_ = false
>
class EpilogueVisitorLayerNorm {
public:
using ElementVariance = ElementVariance_;
using ElementMean = ElementMean_;
using ElementLayernormCompute = ElementLayernormCompute_;
using AccumulatorTile = AccumulatorTile_;
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;
static int const kThreads = OutputTileIterator::ThreadMap::kThreads;
static bool const kIsShiftedVariance = IsShiftedVariance_;
using ElementOutput = typename OutputTileIterator::Element;
static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;
/// Array type used in Shift-K Layernorm
static int const kRowAccessCount = kIterations * kRowIterations;
using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;
// Conducts manual transpose externally (already supported) for column major
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static int const kThreadsInColumn = kThreads / kThreadsPerRow;
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
ElementVariance *ptr_Variance;
ElementMean *ptr_Mean;
ElementOutput *ptr_Shifted_K;
MatrixCoord extent;
//
// Methods
//
Arguments():
ptr_Variance(nullptr),
ptr_Mean(nullptr),
ptr_Shifted_K(nullptr)
{
}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
ElementVariance *ptr_Variance,
ElementMean *ptr_Mean_,
ElementOutput *ptr_Shifted_K_ = nullptr,
MatrixCoord extent = MatrixCoord(0, 0)
):
elementwise(elementwise_),
ptr_Variance(ptr_Variance),
ptr_Mean(ptr_Mean_),
ptr_Shifted_K(ptr_Shifted_K_),
extent(extent)
{
}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
ElementVariance *ptr_Variance;
ElementMean *ptr_Mean;
ElementOutput *ptr_Shifted_K;
MatrixCoord extent;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_Variance(nullptr),
ptr_Mean(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
elementwise(args.elementwise),
ptr_Variance(args.ptr_Variance),
ptr_Mean(args.ptr_Mean),
ptr_Shifted_K(args.ptr_Shifted_K),
extent(args.extent)
{
}
};
/// Shared storage
struct SharedStorage {
};
private:
Params const & params_;
SharedStorage & shared_storage_;
MatrixCoord extent_;
ElementwiseFunctor elementwise_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator alpha_;
ElementAccumulator beta_;
ConvertedShiftFragment shift_k_frag_;
ElementLayernormCompute accum_sum_square_;
ElementLayernormCompute accum_sum_element_;
int thread_idx_;
MatrixCoord thread_offset_;
gemm::GemmCoord threadblock_tile_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorLayerNorm(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord threadblock_offset,
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator ///< Tile iterator for source
):
params_(params),
shared_storage_(shared_storage),
elementwise_(params.elementwise),
extent_(params.extent),
iterator_C_(source_iterator),
iterator_D_(destination_iterator),
threadblock_tile_offset_(threadblock_tile_offset),
thread_idx_(thread_idx)
{
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
// If shift-K feature is enabled, we load shift-k fragment
// at the very beginning of an epilogue
if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
shift_k_frag_.clear();
int thread_offset_row_base = iterator_D_.thread_start_row();
CUTLASS_PRAGMA_UNROLL
for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
CUTLASS_PRAGMA_UNROLL
for (int rid = 0; rid < kRowIterations; ++rid) {
int row_step_offset = rid * kDeltaRow;
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
bool is_load = (row_offset < extent_.row());
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
}
}
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
fragment_C_.clear();
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
/// set the accumulator to 0
accum_sum_element_ = ElementLayernormCompute(0);
accum_sum_square_ = ElementLayernormCompute(0);
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const &accum) {
using Mul = cutlass::multiplies<ElementLayernormCompute>;
using Minus = cutlass::minus<ElementLayernormCompute>;
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
Minus minus;
Mul mul;
Exp exponential;
LayernormFragment result;
thread_offset_ =
iterator_D_.thread_start() +
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
bool column_guard = (thread_offset_.column() < extent_.column());
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
result = source_converter(elementwise_(accum));
}else{
result = source_converter(elementwise_(accum, source_vector));
}
ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());
// Fragment is cleared for non-reachable columns so no need to check against column guard
ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);
// Square sum is different: under shift-K, cleared non-reachable columns would still contribute
// (0 - k)^2 terms, so the column guard must exclude them to avoid adding spurious k^2 into the square sum.
ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);
if (column_guard) {
accum_sum_square_tmp = (kIsShiftedVariance) ? \
square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
square_sum_accumulator_(result);
}
accum_sum_element_tmp *= inv_scalar;
accum_sum_square_tmp *= inv_scalar;
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
CUTLASS_PRAGMA_UNROLL
for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
}
accum_sum_element_ += accum_sum_element_tmp;
accum_sum_square_ += accum_sum_square_tmp;
// Convert to the output
NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;
ConvertVarianceOutput convert_variance_output;
ConvertMeanOutput convert_mean_output;
bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();
ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;
arch::global_store<ElementVariance, sizeof(ElementVariance)>(
convert_variance_output(accum_sum_square_),
(void *)curr_ptr_sum_square,
is_write_thread);
arch::global_store<ElementMean, sizeof(ElementMean)>(
convert_mean_output(accum_sum_element_),
(void *)curr_ptr_element_sum,
is_write_thread);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
}
private:
CUTLASS_DEVICE
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
ConvertShiftK convert_shift_k;
ElementOutput shift_k_val;
// Computes the address to load shift_k element
ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset;
// Conditionally loads from global memory
arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
// Converts data type to return
ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);
return converted_shift_k_val;
}
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
auto accum_ = accum[i];
sum_ += accum_ * accum_;
}
return sum_;
}
CUTLASS_DEVICE
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
auto accum_ = accum[i] - shift_k_val;
sum_ += accum_ * accum_;
}
return sum_;
}
CUTLASS_DEVICE
ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < LayernormFragment::kElements; ++i) {
sum_ += accum[i];
}
return sum_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
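/////////////////////////////////////////////////////////////////////////////////////////////////
// Hedged host-side sketch (not part of this commit) of the shift-K idea the visitor above
// implements: subtracting a per-row shift K (ptr_Shifted_K) from every element before squaring
// keeps E[(x-K)^2] and (E[x-K])^2 at similar, small magnitudes, avoiding the catastrophic
// cancellation that the plain E[x^2] - E[x]^2 formula can suffer. The result is mathematically
// identical to the unshifted variance.
#include <vector>

inline double shifted_variance_sketch(std::vector<double> const &x, double shift_k) {
  double sum = 0.0;
  double sum_sq = 0.0;
  for (double v : x) {
    double d = v - shift_k;          // shift each element, e.g. by the row's first element
    sum += d;
    sum_sq += d * d;
  }
  double n = static_cast<double>(x.size());
  double mean_shifted = sum / n;
  return sum_sq / n - mean_shifted * mean_shifted;   // Var(x) = E[(x-K)^2] - (E[x-K])^2
}
/////////////////////////////////////////////////////////////////////////////////////////////////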

View File

@ -0,0 +1,692 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmUniversalwithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueVisitor::Arguments epilogue_visitor;
void const * ptr_A;
void const * ptr_B;
void const * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
typename LayoutC::Stride stride_c;
typename LayoutC::Stride stride_d;
typename LayoutA::Stride::LongIndex lda;
typename LayoutB::Stride::LongIndex ldb;
typename LayoutC::Stride::LongIndex ldc;
typename LayoutC::Stride::LongIndex ldd;
int const * ptr_gather_A_indices;
int const * ptr_gather_B_indices;
int const * ptr_scatter_D_indices;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueVisitor::Arguments epilogue_visitor,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride stride_a,
typename LayoutB::Stride stride_b,
typename LayoutC::Stride stride_c,
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueVisitor::Arguments epilogue_visitor,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride::LongIndex lda,
typename LayoutB::Stride::LongIndex ldb,
typename LayoutC::Stride::LongIndex ldc,
typename LayoutC::Stride::LongIndex ldd,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.stride_a, args.stride_b);
std::swap(args.batch_stride_A, args.batch_stride_B);
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
return args;
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
typename EpilogueVisitor::Params epilogue_visitor;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
void * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;
int *semaphore;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C(0),
batch_stride_D(0),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr),
semaphore(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
epilogue_visitor(args.epilogue_visitor),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_D(args.ptr_D),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
semaphore(static_cast<int *>(workspace)) {
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
epilogue_visitor = args.epilogue_visitor;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmUniversalwithEpilogueVisitor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");
static int const kAlignmentA = (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<LayoutB,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutB,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
__syncthreads();
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A,
params.ptr_gather_A_indices);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B,
params.ptr_gather_B_indices);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Epilogue
//
// EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
//
// Fetch pointers based on mode.
//
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// if (params.mode == GemmUniversalMode::kGemm) {
// // TODO: fix this order
// // If performing a reduction via split-K, fetch the initial synchronization
// if (params.grid_tiled_shape.k() > 1) {
// // Fetch the synchronization lock initially but do not block.
// semaphore.fetch();
// // Indicate which position in a serial reduction the output operator is currently updating
// output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
// }
// }
// Tile iterator loading from source tensor.
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.visitor,
threadblock_offset,
threadblock_tile_offset,
thread_idx,
params.problem_size.mn()
);
// if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
// ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
// }
if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
// TODO: ???
// if (threadblock_tile_offset.k()) {
// iterator_C = iterator_D;
// }
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
//
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -50,7 +50,13 @@ void bind_tensor_coord(py::module &m) {
R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
.def(py::init<int, int, int, int>(),
py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc");
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
.def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
py::arg("dim"),
R"pbdoc(Gets the index of a given Coord element)pbdoc")
.def("size", [](const cutlass::Tensor4DCoord & coord) {
return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);},
R"pbdoc(The size of the tensor coord)pbdoc");
py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")

View File

@ -1,7 +1,24 @@
from pycutlass.type import *
import re
def SubstituteTemplate(template, values):
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = "\\$\\{%s\\}" % key
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
from pycutlass.type_hint import *
from pycutlass.tensor_ref import *
from pycutlass.operation import *
from pycutlass.epilogue import *
from pycutlass.parser import *
from pycutlass.compiler import ArtifactManager
from pycutlass.memory_manager import *
from pycutlass.arguments import *
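As a quick illustration (not part of the module above), SubstituteTemplate keeps replacing ${key} placeholders until the text stops changing, so values that themselves contain placeholders also resolve:
# Hedged usage sketch of the helper defined above.
template = "using ${name} = Gemm<${element}, ${element}>;"
code = SubstituteTemplate(template, {"name": "GemmF16", "element": "cutlass::half_t"})
# -> "using GemmF16 = Gemm<cutlass::half_t, cutlass::half_t>;"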

View File

@ -60,6 +60,13 @@ class ArgumentBase:
C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
**kwargs) -> None:
# tensor_C is interpreted as the bias when bias=True is passed in the keyword args
if "bias" in kwargs.keys():
self.bias = kwargs["bias"]
else:
# by default, tensor_C is not bias
self.bias = False
# preprocessing input tensors
if isinstance(A, np.ndarray):
@ -72,21 +79,28 @@ class ArgumentBase:
self.ptr_B = self.buffer_B.ptr
self.ptr_C = self.buffer_C.ptr
self.ptr_D = self.buffer_D.ptr
# number of elements in C
self.tensor_c_numel = C.size
elif torch_available and isinstance(A, torch.Tensor):
self.ptr_A = TorchFrontend.argument(A)
self.ptr_B = TorchFrontend.argument(B)
self.ptr_C = TorchFrontend.argument(C)
self.ptr_D = TorchFrontend.argument(D)
# number of elements in C
self.tensor_c_numel = C.numel()
elif isinstance(A, cuda.CUdeviceptr):
self.ptr_A = A
self.ptr_B = B
self.ptr_C = C
self.ptr_D = D
elif cupy_available and isinstance(A, cp.ndarray):
self.ptr_A = CupyFrontend.argument(A)
self.ptr_B = CupyFrontend.argument(B)
self.ptr_C = CupyFrontend.argument(C)
self.ptr_D = CupyFrontend.argument(D)
# number of elements in C
self.tensor_c_numel = C.size
else:
raise TypeError(
"Unsupported frontend: only numpy, torch, cupy arrays and raw CUdeviceptr are supported")

View File

@ -63,22 +63,9 @@ dtype2ctype = {
}
def get_epilogue_output_op(element_compute_):
element_compute = dtype2ctype[element_compute_]
def get_gemm_arguments(epilogue_functor):
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", element_compute),
("beta", element_compute),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p)
]
return _EpilogueOutputOpParams
def get_gemm_arguments(element_compute_):
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GemmArguments(ctypes.Structure):
_fields_ = [
@ -116,8 +103,8 @@ def get_gemm_arguments(element_compute_):
# include/cutlass/gemm/kernel/gemm_grouped.h
def get_gemm_grouped_arguments(element_compute_):
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
def get_gemm_grouped_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _GEMMGroupedArguments(ctypes.Structure):
_fields_ = [
@ -214,8 +201,8 @@ class TensorRef2D_(ctypes.Structure):
# include/cutlass/conv/kernel/implicit_gemm_convolution.h
# split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4
def get_conv2d_arguments(element_compute_):
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
def get_conv2d_arguments(epilogue_functor):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
class _Conv2dArguments(ctypes.Structure):
_fields_ = [
@ -236,8 +223,8 @@ def get_conv2d_arguments(element_compute_):
############################################################################################
def get_reduction_params(element_compute_):
_EpilogueOutputParams = get_epilogue_output_op(element_compute_)
def get_reduction_params(epilogue_functor):
_EpilogueOutputParams = epilogue_functor.epilogue_type
class _ReductionParams(ctypes.Structure):
_fields_ = [
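Across these helpers the new contract is that each epilogue functor object carries its own ctypes parameter layout via epilogue_type, so the argument builders no longer derive it from element_compute. A minimal sketch of that contract follows (hypothetical class names; the real functors live in pycutlass):
import ctypes

# Hedged sketch: the structure mirrors the removed _EpilogueOutputOpParams fields.
class _LinearCombinationParamsSketch(ctypes.Structure):
    _fields_ = [("alpha", ctypes.c_float), ("beta", ctypes.c_float),
                ("alpha_ptr", ctypes.c_void_p), ("beta_ptr", ctypes.c_void_p)]

class LinearCombinationSketch:
    epilogue_type = _LinearCombinationParamsSketch

output_op = LinearCombinationSketch.epilogue_type(1.0, 0.0, 0, 0)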

View File

@ -1,366 +0,0 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from pycutlass import *
from pycutlass.library import SubstituteTemplate
import cutlass
from cuda import cuda
from cuda import nvrtc
import tempfile
import os
import ctypes
#
import json
import sqlite3
IncludeTemplate = r'''#include "${include}"
'''
#
class CompilationOptions:
'''
Compilation options.
'''
#
def __init__(self, architectures = [80], include_paths = []):
self.includes = []
self.include_paths = include_paths
self.flags = ['-std=c++11', '-default-device']
self.architectures = architectures
#
def get(self):
options = []
for flag in self.flags:
options.append(bytes(str.encode(flag)))
for incl in self.include_paths:
options.append(bytes(str.encode('--include-path=%s' % incl)))
arch_list = "-arch="
for idx, arch in enumerate(self.architectures):
if idx:
arch_list += ","
arch_list += "sm_%d" % arch
options.append(bytes(str.encode(arch_list)))
return options
def convertToBinaryData(filename):
with open(filename, 'rb') as file:
blobData = file.read()
return blobData
def CDLLBin(host_binary):
tempfile.tempdir = "./"
temp_so = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True)
with open(temp_so.name, 'wb') as file:
file.write(host_binary)
host_lib = ctypes.CDLL(temp_so.name)
return host_lib
class ArtifactManager:
"""
Artifact manager
"""
def __init__(self) -> None:
try:
connection = sqlite3.connect("./compiled_cache.db")
cursor = connection.cursor()
sqlite_create_table_query = """CREATE TABLE compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)"""
cursor.execute(sqlite_create_table_query)
connection.commit()
cursor.close()
except:
pass
self.compiled_cache_device = cutlass.CompileCache()
self.compiled_cache_host = cutlass.CompileCache()
def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
connection = sqlite3.connect("./compiled_cache.db")
cursor = connection.cursor()
sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""
hostbin = convertToBinaryData(hostfile)
data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))
cursor.execute(sqlite_insert_blob_query, data_tuple)
connection.commit()
cursor.close()
def load_operation(self, op_key):
connection = sqlite3.connect("./compiled_cache.db")
cursor = connection.cursor()
sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
# try:
cursor.execute(sqlite_fetch_blob_query, (op_key, ))
record = cursor.fetchall()
if len(record) == 0:
return False
for row in record:
key, cubin_image, host_binary, operation_name, op_attr = row
op_attr = json.loads(op_attr)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError('Cuda Error: {}'.format(err))
err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
self.compiled_cache_device.insert(key, kernel)
compiled_host_fns = {}
host_lib = CDLLBin(host_binary)
func_name = operation_name + '_get_params'
func = getattr(host_lib, func_name)
func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
compiled_host_fns['get_args'] = func
func_name = operation_name + '_shared_memory_size'
func = getattr(host_lib, func_name)
compiled_host_fns['shared_memory_capacity'] = func()
for attr in op_attr:
if isinstance(attr, str):
func_name = operation_name + '_' + attr
func = getattr(host_lib, func_name)
compiled_host_fns[attr] = func
self.compiled_cache_host.insert(key, compiled_host_fns)
return True
def emit_compile_(self, operation_list, compilation_options):
"""
Compile a list of kernels and store them into the database
"""
source_buffer_device = ""
source_buffer_host = ""
# 1. include
includes = []
for operation in operation_list:
for incl in operation.emitter.includes:
if incl not in includes:
includes.append(incl)
includes_host = [
"builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
for incl in includes:
source_buffer_device += SubstituteTemplate(IncludeTemplate, {'include': incl})
for incl in includes_host:
if "/device/" not in incl:
source_buffer_host += SubstituteTemplate(IncludeTemplate, { 'include': incl} )
# 2. Operations
for operation in operation_list:
source_buffer_device += operation.emit()
source_buffer_host += operation.emit()
values = {
'operation_name': operation.name(),
'operation_suffix': operation.emitter.operation_suffix
}
source_buffer_device += SubstituteTemplate(operation.KernelTemplate, values)
source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)
# 3. compile
err, program = nvrtc.nvrtcCreateProgram(
str.encode(source_buffer_device),
bytes(str.encode("module.cu")),
0, [], [])
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError('NVRTC Error: {}'.format(err))
# Compile program
options = compilation_options.get()
err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
error_string = 'NVRTC Error: {}\n'.format(err)
# Get log from compilation
err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError('NVRTC Error: {}'.format(err))
log = b' ' * logSize
err, = nvrtc.nvrtcGetProgramLog(program, log)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError('NVRTC Error: {}'.format(err))
raise RuntimeError(error_string + log.decode() + source_buffer_device)
# Get data from compilation
err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError('NVRTC Error: {}'.format(err))
cubin_image = b' ' * dataSize
err, = nvrtc.nvrtcGetCUBIN(program, cubin_image)
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError('NVRTC Error: {}'.format(err))
# compile the host code
options = compilation_options.get()
cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host
for opt in options:
opt = opt.decode("utf-8")
if opt not in ['-default-device', '-std=c++11', '-arch=sm_80']:
if '--include-path=' in opt:
cmd += " " + opt.replace('--include-path=', '-I')
else:
cmd += " "+ opt
tempfile.tempdir = "./"
temp = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True)
cmd += ' - -shared -o %s' % temp.name
os.system(cmd)
host_lib = ctypes.CDLL(temp.name)
return cubin_image, host_lib, temp
def add_module(self, operations, compile_options=None):
"""
Insert a new compiled device module
"""
if compile_options is None:
cutlass_path = os.getenv('CUTLASS_PATH')
assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
architectures = []
for operation in operations:
if hasattr(operation, "tile_description"):
cc = operation.tile_description.minimum_compute_capability
if cc not in architectures:
architectures.append(cc)
include_paths = [
cuda_install_path + '/include',
cutlass_path + '/include',
cutlass_path + '/tools/util/include',
]
compile_options = CompilationOptions(architectures, include_paths)
# save the cubin
operation_key = []
operation_list = []
for operation in operations:
# step 1: get kernel string as key
key = operation.rt_module.emit() + operation.procedural_name()
# step 2: check if the operation is in the cache
compiled_kernel = self.compiled_cache_device.at(key)
if compiled_kernel is None:
hit = self.load_operation(key)
if hit:
compiled_kernel = self.compiled_cache_device.at(key)
assert compiled_kernel is not None
if compiled_kernel is not None:
operation.rt_module.kernel = compiled_kernel
compiled_host_fns = self.compiled_cache_host.at(key)
assert compiled_host_fns is not None
for key in compiled_host_fns.keys():
setattr(operation.rt_module, key, compiled_host_fns[key])
operation.rt_module.initialize()
else:
operation_list.append(operation.rt_module)
operation_key.append(key)
if len(operation_list) > 0:
cubin_image, host_lib, host_file = self.emit_compile_(operation_list, compile_options)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError('Cuda Error: {}'.format(err))
operation_name = []
operation_attr = []
for operation, key in zip(operation_list, operation_key):
# get device kernels
err, operation.kernel = cuda.cuModuleGetFunction(
module,
bytes(str.encode(operation.name()))
)
operation_name.append(operation.name())
self.compiled_cache_device.insert(key, operation.kernel)
# get host functions
compiled_host_fns = {}
op_attr = []
# get param size
func_name = operation.name() + '_get_param_size'
func = getattr(host_lib, func_name)
param_size = func()
func_name = operation.name() + '_get_params'
func = getattr(host_lib, func_name)
func.argtype = operation.argtype
func.restype = ctypes.POINTER(ctypes.c_char * param_size)
setattr(operation, 'get_args', func)
compiled_host_fns['get_args'] = func
# set shared memory size
func_name = operation.name() + '_shared_memory_size'
func = getattr(host_lib, func_name)
setattr(operation, 'shared_memory_capacity', func())
compiled_host_fns['shared_memory_capacity'] = func()
# set the maximum dynamic shared size
operation.initialize()
# get extra functions
op_attr.append(param_size)
if hasattr(operation, "extra_funcs"):
for suffix in operation.extra_funcs:
func_name = operation.name() + '_' + suffix
func = getattr(host_lib, func_name)
setattr(operation, suffix, func)
compiled_host_fns[suffix] = func
op_attr.append(suffix)
operation_attr.append(op_attr)
self.compiled_cache_host.insert(key, compiled_host_fns)
for key, operation_name, operation_attr in zip(operation_key, operation_name, operation_attr):
self.insert_operation(key, cubin_image, host_file.name, operation_name, operation_attr)
artifact_manager = ArtifactManager()

View File

@ -30,7 +30,6 @@
#
#################################################################################################
from pycutlass import *
from pycutlass.library import SubstituteTemplate
import cutlass
from cuda import cuda
from cuda import nvrtc
@ -132,13 +131,15 @@ class ArtifactManager:
except:
pass
self.nvcc()
self.compiled_cache_device = cutlass.CompileCache()
self.compiled_cache_host = cutlass.CompileCache()
def nvrtc(self):
self.backend = "nvrtc"
self.default_compile_options = [
'-std=c++11', '-default-device',
]
self.compiled_cache_device = cutlass.CompileCache()
self.compiled_cache_host = cutlass.CompileCache()
def nvcc(self):
self.backend = "nvcc"
self.default_compile_options = [
@ -335,13 +336,14 @@ class ArtifactManager:
architectures = []
for operation in operations:
if hasattr(operation, "tile_description"):
cc = operation.tile_description.minimum_compute_capability
cc = operation.arch
if cc not in architectures:
architectures.append(cc)
include_paths = [
cuda_install_path + '/include',
cutlass_path + '/include',
cutlass_path + '/tools/util/include',
cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include'
]
compile_options = CompilationOptions(
self.default_compile_options, architectures, include_paths)
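For context, a hedged sketch of how the updated constructor appears to be called (argument order inferred from the call above; all paths are placeholders):
# Hypothetical example; replace the paths with CUDA_INSTALL_PATH / CUTLASS_PATH.
default_compile_options = ['-std=c++11', '-default-device']
include_paths = [
    '/usr/local/cuda/include',
    '/path/to/cutlass/include',
    '/path/to/cutlass/tools/util/include',
    '/path/to/cutlass/tools/library/scripts/pycutlass/src/cpp/include',
]
compile_options = CompilationOptions(default_compile_options, [80], include_paths)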

View File

@ -48,6 +48,9 @@ class Conv2dArguments(ArgumentBase):
:param operation: the Conv2d operation to take the argument
:type operation: :class:`pycutlass.Conv2dOperation`
:param problem_size: the Conv2d problem size
:type problem_size: :class:`cutlass.conv.Conv2dProblemSize`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
@ -78,6 +81,7 @@ class Conv2dArguments(ArgumentBase):
split_k_mode: 'cutlass.conv.SplitKMode'
= cutlass.conv.SplitKMode.Serial, **kwargs) -> None:
self.operation = operation
#: convolution kind
self.conv_kind: cutlass.conv.Operator = operation.conv_kind
self.layout_A: cutlass.layout = operation.A.layout
@ -93,15 +97,12 @@ class Conv2dArguments(ArgumentBase):
super().__init__(A, B, C, D, **kwargs)
# preprocessing output ops
if "output_op" in kwargs.keys() and \
if 'output_op' in kwargs.keys() and \
split_k_mode != cutlass.conv.SplitKMode.Parallel:
self.alpha = kwargs["output_op"].alpha
self.beta = kwargs["output_op"].beta
self.output_op = kwargs['output_op']
else:
self.alpha = 1.0
self.beta = 0.0
self.element_compute = operation.element_epilogue
self.output_op = self.operation.epilogue_type(1.0, 0.0)
if "split_k_slices" in kwargs.keys():
self.split_k_mode = split_k_mode
@ -114,7 +115,12 @@ class Conv2dArguments(ArgumentBase):
self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size
self.problem_size.split_k_slices = self.split_k_slices
self.operation = operation
if hasattr(self, "tensor_c_numel"):
c_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
self.conv_kind, problem_size)
if (self.tensor_c_numel == c_coord.at(3) and
self.tensor_c_numel < c_coord.size()):
self.bias = True
#
# initialize the argument
@ -159,6 +165,9 @@ class Conv2dArguments(ArgumentBase):
self.conv_kind, problem_size)
else:
raise ValueError("unknown operand: " + operand)
# Zero stride trick
if operand == "c" and self.bias:
tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0)
layout = tensor_layout.packed(tensor_coord)
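Taken together, the two changes above first detect a per-channel bias and then broadcast it via a packed layout of extent zero. A self-contained sketch of the detection check with illustrative numbers, assuming an Fprop output extent of (N, P, Q, K):
# Hedged illustration only; mirrors the c_coord.at(3) / c_coord.size() comparison above.
n, p, q, k = 2, 7, 7, 64
tensor_c_numel = 64                       # a K-element vector was passed as C
is_bias = (tensor_c_numel == k) and (tensor_c_numel < n * p * q * k)
assert is_bias                            # enables the zero-stride layout above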
@ -174,24 +183,9 @@ class Conv2dArguments(ArgumentBase):
ref_D = TensorRef_(self.get_tensor_ref(
self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))
if self.element_compute == cutlass.float16:
alpha = cutlass.float16(self.alpha).storage
beta = cutlass.float16(self.beta).storage
elif self.element_compute == cutlass.int32:
alpha = int(self.alpha)
beta = int(self.beta)
else:
alpha = self.alpha
beta = self.beta
argument_type, epilogue_type = get_conv2d_arguments(
self.operation.element_epilogue)
output_op = epilogue_type(alpha, beta, 0, 0)
self.c_arguments = argument_type(
self.c_arguments = self.operation.argument_type(
Conv2DProblemSize(self.problem_size),
ref_A, ref_B, ref_C, ref_D, output_op, self.split_k_mode
ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode
)
self.semaphore = semaphore
@ -296,9 +290,8 @@ extern "C" {
def __init__(self, operation: 'Conv2dOperation'):
super().__init__(operation)
self.argtype = [ctypes.POINTER(get_conv2d_arguments(
operation.element_epilogue)[0]), ctypes.c_void_p]
self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
self.conv_kind = operation.conv_kind
self.operation: Conv2dOperation = operation
@ -410,9 +403,7 @@ class Conv2dOperation:
iterator_algorithm: cutlass.conv.IteratorAlgorithm,
arch: int, tile_description: TileDescription,
A: TensorDescription, B: TensorDescription, C: TensorDescription,
element_epilogue: Union[cutlass.int8, cutlass.int32, cutlass.float16,
cutlass.bfloat16, cutlass.float32, cutlass.float64],
stride_support, epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support, epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1):
self.operation_kind: OperationKind = OperationKind.Conv2d
@ -422,13 +413,14 @@ class Conv2dOperation:
self.A: TensorDescription = A
self.B: TensorDescription = B
self.C: TensorDescription = C
self.element_epilogue = element_epilogue
self.epilogue_functor = epilogue_functor
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor()
self.rt_module: Conv2dRT = Conv2dRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
"""
@ -577,12 +569,7 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${epilogue_functor},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
@ -629,8 +616,7 @@ struct ${operation_name}${operation_suffix}:
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'epilogue_vector_length': str(epilogue_vector_length),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': operation.epilogue_functor.emit(),
'swizzling_functor': operation.swizzling_functor.tag(),
'stages': str(operation.tile_description.stages),
'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],

File diff suppressed because it is too large

View File

@ -116,6 +116,12 @@ class GemmArguments(ArgumentBase):
else:
self.problem_size = cutlass.gemm.GemmCoord(
problem_size.m(), problem_size.n(), problem_size.k())
# if the number of elements in C equals problem_size.n(),
# C is treated as the bias vector
if hasattr(self, "tensor_c_numel"):
if (self.tensor_c_numel == self.problem_size.n() and
self.problem_size.m() != 1): self.bias = True
# get the leading dimension
self.lda = operation.A.layout.packed(self.problem_size.mk()).stride()
@ -123,27 +129,69 @@ class GemmArguments(ArgumentBase):
self.ldc = operation.C.layout.packed(self.problem_size.mn()).stride()
self.ldd = self.ldc
# stride 0 trick
if self.bias:
self.ldc = 0
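The "stride 0 trick" above presents the N-element bias as an M x N source whose rows all alias the same values. A small NumPy-only illustration (not pycutlass code) of why a leading dimension of zero achieves this:
import numpy as np

# With ld = 0, element (i, j) of a row-major matrix maps to buffer[j] for every row i.
def row_major_view(buffer, m, n, ld):
    return np.array([[buffer[i * ld + j] for j in range(n)] for i in range(m)])

bias = np.arange(4, dtype=np.float32)             # N = 4 bias values
assert (row_major_view(bias, 3, 4, 0) == np.tile(bias, (3, 1))).all()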
if 'output_op' in kwargs.keys() and \
gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel:
self.alpha = kwargs['output_op'].alpha
self.beta = kwargs['output_op'].beta
self.output_op = kwargs['output_op']
else:
self.alpha = 1.0
self.beta = 0.0
self.output_op = self.operation.epilogue_type(1.0, 0.0)
# get number of slices on k dimension
self.gemm_mode = gemm_mode
if 'split_k_slices' in kwargs.keys():
self.split_k_slices = kwargs['split_k_slices']
else:
self.split_k_slices = 1
if gemm_mode in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]:
if 'split_k_slices' in kwargs.keys():
self.batch_count = kwargs['split_k_slices']
else:
self.batch_count = 1
self.split_k_slices = self.batch_count
self.batch_count = self.split_k_slices
if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]:
if 'batch' in kwargs.keys():
self.batch_count = kwargs['batch']
else:
self.batch_count = 1
self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
if self.bias:
self.batched_stride_C = self.problem_size.n()
# support GEMM Mode Array
if gemm_mode == cutlass.gemm.Mode.Array:
self.ptr_A_array = []
self.ptr_B_array = []
self.ptr_C_array = []
self.ptr_D_array = []
ptr_A_addr = int(self.ptr_A)
ptr_B_addr = int(self.ptr_B)
ptr_C_addr = int(self.ptr_C)
ptr_D_addr = int(self.ptr_D)
stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
for _ in range(self.batch_count):
self.ptr_A_array.append(ptr_A_addr)
self.ptr_B_array.append(ptr_B_addr)
self.ptr_C_array.append(ptr_C_addr)
self.ptr_D_array.append(ptr_D_addr)
ptr_A_addr += stride_A
ptr_B_addr += stride_B
ptr_C_addr += stride_C
ptr_D_addr += stride_D
self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)
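The Array-mode path above materializes one device pointer per batch by stepping a byte offset from each base allocation. A hedged sketch of the arithmetic (illustrative addresses; the stride is given in elements and converted to bytes exactly as above):
# Hypothetical helper mirroring the loop above.
def batch_pointers(base_addr, batch_count, stride_elements, element_bits):
    stride_bytes = stride_elements * element_bits // 8
    return [base_addr + i * stride_bytes for i in range(batch_count)]

# e.g. three half-precision A matrices of 128 x 64 elements each
ptrs = batch_pointers(0x700000000, 3, 128 * 64, 16)
assert ptrs[1] - ptrs[0] == 128 * 64 * 2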
if isinstance(self.operation, GemmOperationUniversal):
self.initialize()
@ -195,28 +243,28 @@ class GemmArguments(ArgumentBase):
self.grid_tiled_shape.z
)
)
argument_type, epilogue_type = get_gemm_arguments(
self.operation.element_epilogue)
if self.operation.element_epilogue == cutlass.float16:
self.alpha = cutlass.float16(self.alpha).storage
self.beta = cutlass.float16(self.beta).storage
elif self.operation.element_epilogue == cutlass.int32:
self.alpha = int(self.alpha)
self.beta = int(self.beta)
output_op = epilogue_type(self.alpha, self.beta, 0, 0)
arguments = argument_type(
self.gemm_mode, problem_size_, self.batch_count, output_op,
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
self.batched_stride_D,
self.lda, self.ldb, self.ldc, self.ldd,
self.lda, self.ldb, self.ldc, self.ldd,
0, 0, 0
)
if self.gemm_mode == cutlass.gemm.Mode.Array:
arguments = self.operation.argument_type(
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
int(self.ptr_A_array_buffer.ptr),
int(self.ptr_B_array_buffer.ptr),
int(self.ptr_C_array_buffer.ptr),
int(self.ptr_D_array_buffer.ptr),
0, 0, 0, 0,
self.lda, self.ldb, self.ldc, self.ldd,
self.lda, self.ldb, self.ldc, self.ldd,
0, 0, 0
)
else:
arguments = self.operation.argument_type(
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
self.batched_stride_D,
self.lda, self.ldb, self.ldc, self.ldd,
self.lda, self.ldb, self.ldc, self.ldd,
0, 0, 0
)
self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size
@ -381,13 +429,12 @@ class GemmGroupedArguments:
else:
self.alpha = 1.0
self.beta = 0.0
if 'output_op' in kwargs.keys():
self.output_op = kwargs['output_op']
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
if self.operation.element_epilogue == cutlass.float16:
self.alpha = cutlass.float16(self.alpha).storage
self.beta = cutlass.float16(self.beta).storage
elif self.operation.element_epilogue == cutlass.int32:
self.alpha = int(self.alpha)
self.beta = int(self.beta)
# get host problem size
self.host_problem_size_ptr = np.array(
@ -398,12 +445,7 @@ class GemmGroupedArguments:
self.initialize()
def get_arguments(self):
argument_type, epilogue_type = get_gemm_grouped_arguments(
self.operation.element_epilogue)
self.output_op = epilogue_type(self.alpha, self.beta, 0, 0)
return argument_type(
return self.operation.argument_type(
self.problem_size_buffer.ptr, self.problem_count, self.total_tiles,
self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr,
self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr,
@ -492,16 +534,6 @@ ${operation_name}(${operation_name}${operation_suffix}::Params params) {
#: number of threads per threadblock
self.threads: int = operation.tile_description.num_threads
if (operation.epilogue_functor in
[
EpilogueFunctor.LinearCombination,
EpilogueFunctor.FastLinearCombinationClamp,
EpilogueFunctor.LinearCombinationClamp
]):
self.output_op = LinearCombinationFunctor()
else:
raise ValueError("unknown epilogue functor")
#
def emit(self):
return self.emitter.emit(self.operation)
@ -568,9 +600,11 @@ extern "C" {
def __init__(self, operation: 'GemmOperation'):
super(GemmRTUniversal, self).__init__(operation)
self.emitter = EmitGemmUniversalInstance(
'_type', operation.direct_store)
'_type', operation.direct_store, operation.visitor)
self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
self.argtype = [
ctypes.POINTER(get_gemm_arguments(operation.element_epilogue)[0]),
ctypes.POINTER(self.argument_type),
ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
]
@ -673,8 +707,8 @@ class GemmRTGrouped(GemmRTbase):
self.extra_funcs = ['precompute']
self.emitter = EmitGemmGroupedInstance('_type')
self.argtype = [ctypes.POINTER(get_gemm_grouped_arguments(
operation.element_epilogue)[0]), ctypes.c_int, ctypes.c_void_p]
self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]
def host_precompute(self, arguments, workspace_bytes):
self.precompute.argtype = [
@ -717,7 +751,7 @@ class GemmOperationBase:
def __init__(
self, gemm_kind, arch, tile_description: TileDescription,
A: TensorDescription, B: TensorDescription, C: TensorDescription,
element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination,
epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
#: operation kind
@ -749,7 +783,7 @@ class GemmOperationBase:
#: Operand C
self.C: TensorDescription = copy.deepcopy(C)
self.switched = False
self.element_epilogue = element_epilogue
self.epilogue_functor = epilogue_functor
self.swizzling_functor = swizzling_functor()
@ -757,6 +791,11 @@ class GemmOperationBase:
self.direct_store = kwargs["direct_store"]
else:
self.direct_store = False
if "visitor" in kwargs:
self.visitor = kwargs["visitor"]
else:
self.visitor = False
def run(self, arguments: GemmArguments) -> cuda.CUresult:
"""
@ -895,22 +934,26 @@ class GemmOperationBase:
class GemmOperationUniversal(GemmOperationBase):
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue,
epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description,
A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs)
A, B, C, epilogue_functor, swizzling_functor, **kwargs)
self.rt_module = GemmRTUniversal(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
class GemmOperationGrouped(GemmOperationBase):
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue,
epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs)
A, B, C, epilogue_functor, swizzling_functor, **kwargs)
assert "precompute_mode" in kwargs.keys(
), "missing keyword arguement 'precompute_mode'."
self.precompute_mode = kwargs["precompute_mode"]
self.rt_module = GemmRTGrouped(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
###################################################################################################
#
@ -918,228 +961,14 @@ class GemmOperationGrouped(GemmOperationBase):
#
###################################################################################################
#
class EmitGemmInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix
self.includes = []
self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
false,
${math_operation}
${residual}
>;
"""
self.gemm_complex_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${transform_a},
${transform_b},
${math_operation}
${residual}
>;
"""
#
def instance_template(self):
return """
${compile_guard_start}
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
"""
#
def emit(self, operation):
warp_shape = [operation.tile_description.threadblock_shape[idx] //
operation.tile_description.warp_count[idx] for idx in range(3)]
epilogue_vector_length = int(min(
operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = ''
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
'layout_a': LayoutTag[operation.A.layout],
'element_b': DataTypeTag[operation.B.element],
'layout_b': LayoutTag[operation.B.layout],
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
'warp_shape_m': str(warp_shape[0]),
'warp_shape_n': str(warp_shape[1]),
'warp_shape_k': str(warp_shape[2]),
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'swizzling_functor': operation.swizzling_functor.tag(),
'stages': str(operation.tile_description.stages),
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
'transform_a': ComplexTransformTag[operation.A.complex_transform],
'transform_b': ComplexTransformTag[operation.B.complex_transform],
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
'residual': residual
}
template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
return SubstituteTemplate(template, values)
###################################################################################################
class EmitSparseGemmInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix
self.includes = []
self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
false,
${math_operation}
${residual}
>;
"""
#
def instance_template(self):
return """
${compile_guard_start}
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
"""
#
def emit(self, operation):
warp_shape = [operation.tile_description.threadblock_shape[idx] //
operation.tile_description.warp_count[idx] for idx in range(3)]
epilogue_vector_length = int(min(
operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = ''
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
'layout_a': LayoutTag[operation.A.layout],
'element_b': DataTypeTag[operation.B.element],
'layout_b': LayoutTag[operation.B.layout],
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
'warp_shape_m': str(warp_shape[0]),
'warp_shape_n': str(warp_shape[1]),
'warp_shape_k': str(warp_shape[2]),
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'swizzling_functor': operation.swizzling_functor.tag(),
'stages': str(operation.tile_description.stages),
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
'transform_a': ComplexTransformTag[operation.A.complex_transform],
'transform_b': ComplexTransformTag[operation.B.complex_transform],
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
'residual': residual
}
template = self.gemm_template
return SubstituteTemplate(template, values)
###################################################################################################
#
class EmitGemmUniversalInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self, operation_suffix='', direct_store=False):
def __init__(self, operation_suffix='', direct_store=False, visitor=False):
self.operation_suffix = operation_suffix
self.direct_store = direct_store
self.visitor = visitor
self.includes = [
"cutlass/cutlass.h",
"cutlass/numeric_types.h",
@ -1150,46 +979,15 @@ class EmitGemmUniversalInstance:
"cutlass/gemm/device/gemm_universal_adapter.h",
"cutlass/gemm/kernel/default_gemm_universal.h",
]
if self.visitor:
self.includes += [
"gemm/gemm_universal_with_visitor.h",
"epilogue/epilogue_visitor_with_layernorm.h",
"epilogue/epilogue_visitor_generic.h"
]
if self.direct_store:
self.includes.append(
"cutlass/epilogue/threadblock/default_epilogue_direct_store.h")
self.builtin_epilogue_functor_template = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
"""
self.builtin_epilogue_functor_template_clamp = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length}
>
"""
self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operation}
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
"""
self.gemm_template_interleaved = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
@ -1241,6 +1039,42 @@ using ${operation_name}_base =
${operation_name}_default::ThreadblockSwizzle
>;
// Define named type
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
"""
self.gemm_template_visitor = """
// Gemm operator ${operation_name}
using ${operation_name}_default =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${elementwise_epilogue_functor},
${swizzling_functor},
${stages},
${math_operation}
>::GemmKernel;
${epilogue_visitor}
using ${operation_name}_Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
${operation_name}_EpilogueVisitor,
typename ${operation_name}_default::Epilogue>::Epilogue;
using ${operation_name}_base =
cutlass::gemm::kernel::GemmUniversalwithEpilogueVisitor<
${operation_name}_default::Mma,
${operation_name}_Epilogue,
${operation_name}_default::ThreadblockSwizzle
>;
// Define named type
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
@ -1284,32 +1118,12 @@ ${compile_guard_end}
(operation.A.layout, operation.B.layout, operation.C.layout)
if self.direct_store:
gemm_template = self.gemm_template_direct_store
elif self.visitor:
gemm_template = self.gemm_template_visitor
else:
gemm_template = self.gemm_template_interleaved
#
# Support built-in epilogue functors or user-defined functions
if isinstance(operation.epilogue_functor, enum.Enum):
epilogue_vector_length = \
min(operation.C.alignment *
DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
values = {
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
}
if operation.epilogue_functor == EpilogueFunctor.FastLinearCombinationClamp:
epilogue_functor = SubstituteTemplate(
self.builtin_epilogue_functor_template_clamp, values)
else:
epilogue_functor = SubstituteTemplate(
self.builtin_epilogue_functor_template, values)
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
#
values = {
'operation_name': operation.procedural_name(),
'operation_suffix': self.operation_suffix,
@ -1331,7 +1145,6 @@ ${compile_guard_end}
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'epilogue_functor': epilogue_functor,
'swizzling_functor': operation.swizzling_functor.tag(),
'stages': str(operation.tile_description.stages),
'align_a': str(operation.A.alignment),
@ -1341,6 +1154,12 @@ ${compile_guard_end}
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
}
if self.visitor:
values['epilogue_visitor'] = operation.epilogue_functor.emit(operation)
values['elementwise_epilogue_functor'] = operation.epilogue_functor.elementwise_functor.emit()
else:
values['epilogue_functor'] = operation.epilogue_functor.emit()
return SubstituteTemplate(gemm_template, values)
###################################################################################################
@ -1348,185 +1167,6 @@ ${compile_guard_end}
#
class EmitGemmPlanarComplexInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix
self.includes = []
self.template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
${element_c}, cutlass::layout::RowMajor,
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
${element_c},
${alignment_c},
${element_accumulator},
${element_epilogue}
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
${stages},
${math_operator}
>::GemmKernel;
struct ${operation_name} :
public Operation_${operation_name} { };
"""
#
def instance_template(self):
return """
${compile_guard_start}
manifest.append(new ${gemm_kind}<
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
>("${operation_name}"));
${compile_guard_end}
"""
#
def emit(self, operation):
warp_shape = [operation.tile_description.threadblock_shape[idx] //
operation.tile_description.warp_count[idx] for idx in range(3)]
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
transposed_layout_A = TransposedLayout[operation.A.layout]
transposed_layout_B = TransposedLayout[operation.B.layout]
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.B.element],
'layout_a': LayoutTag[transposed_layout_B],
'transform_a': ComplexTransformTag[operation.B.complex_transform],
'alignment_a': str(operation.B.alignment),
'element_b': DataTypeTag[operation.A.element],
'layout_b': LayoutTag[transposed_layout_A],
'transform_b': ComplexTransformTag[operation.A.complex_transform],
'alignment_b': str(operation.A.alignment),
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
'warp_shape_m': str(warp_shape[0]),
'warp_shape_n': str(warp_shape[1]),
'warp_shape_k': str(warp_shape[2]),
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'alignment_c': str(operation.C.alignment),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'stages': str(operation.tile_description.stages),
'math_operator': 'cutlass::arch::OpMultiplyAdd'
}
return SubstituteTemplate(self.template, values)
###################################################################################################
#
class EmitGemmPlanarComplexArrayInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix
self.includes = []
self.template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
${element_c}, cutlass::layout::RowMajor,
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
${element_c},
${alignment_c},
${element_accumulator},
${element_epilogue}
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
${stages},
${math_operator}
>::GemmArrayKernel;
struct ${operation_name} : public Operation_${operation_name} { };
"""
#
def instance_template(self):
return """
${compile_guard_start}
manifest.append(new ${gemm_kind}<
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
>("${operation_name}"));
${compile_guard_end}
"""
#
def emit(self, operation):
warp_shape = [operation.tile_description.threadblock_shape[idx] //
operation.tile_description.warp_count[idx] for idx in range(3)]
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
transposed_layout_A = TransposedLayout[operation.A.layout]
transposed_layout_B = TransposedLayout[operation.B.layout]
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.B.element],
'layout_a': LayoutTag[transposed_layout_B],
'transform_a': ComplexTransformTag[operation.B.complex_transform],
'alignment_a': str(operation.B.alignment),
'element_b': DataTypeTag[operation.A.element],
'layout_b': LayoutTag[transposed_layout_A],
'transform_b': ComplexTransformTag[operation.A.complex_transform],
'alignment_b': str(operation.A.alignment),
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
'warp_shape_m': str(warp_shape[0]),
'warp_shape_n': str(warp_shape[1]),
'warp_shape_k': str(warp_shape[2]),
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'alignment_c': str(operation.C.alignment),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'stages': str(operation.tile_description.stages),
'math_operator': 'cutlass::arch::OpMultiplyAdd'
}
return SubstituteTemplate(self.template, values)
###################################################################################################
#
class EmitGemmGroupedInstance:
''' Responsible for emitting a CUTLASS template definition'''
@ -1541,14 +1181,6 @@ class EmitGemmGroupedInstance:
"cutlass/gemm/kernel/gemm_grouped.h",
"cutlass/gemm/kernel/default_gemm_grouped.h"
]
self.builtin_epilogue_functor_template = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
"""
self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
@ -1598,23 +1230,8 @@ ${compile_guard_end}
#
# Support built-in epilogue functors or user-defined functions
if isinstance(operation.epilogue_functor, enum.Enum):
epilogue_vector_length = \
min(operation.C.alignment *
DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
values = {
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
}
epilogue_functor = SubstituteTemplate(
self.builtin_epilogue_functor_template, values)
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
#
epilogue_functor = operation.epilogue_functor.emit()
values = {
'operation_name': operation.procedural_name(),
'operation_suffix': self.operation_suffix,

View File

@ -478,27 +478,6 @@ SharedMemPerCC = {
###################################################################################################
#
def SubstituteTemplate(template, values):
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = "\\$\\{%s\\}" % key
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
###################################################################################################
#
class GemmKind(enum.Enum):
Gemm = enum_auto()
Sparse = enum_auto()
@ -557,22 +536,6 @@ SymmKindNames = {
#
class EpilogueFunctor(enum.Enum):
LinearCombination = enum_auto()
LinearCombinationClamp = enum_auto()
FastLinearCombinationClamp = enum_auto()
#
EpilogueFunctorTag = {
EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
EpilogueFunctor.FastLinearCombinationClamp: 'cutlass::epilogue::thread::FastLinearCombinationClamp'
}
#
class SwizzlingFunctor(enum.Enum):
Identity1 = enum_auto()
Identity2 = enum_auto()
@ -700,7 +663,7 @@ class MathInstruction:
class TileDescription:
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):
def __init__(self, threadblock_shape, stages, warp_count, math_instruction):
self.threadblock_shape = threadblock_shape
#: number of pipeline stages
@ -710,11 +673,6 @@ class TileDescription:
self.warp_count: list[int] = warp_count
self.math_instruction = math_instruction
#: minimum compute capability
self.minimum_compute_capability: int = min_compute
#: maximum compute capability
self.maximum_compute_capability: int = max_compute
#: number threads per threadblock
self.num_threads: int = 32
for cnt in self.warp_count:

View File

@ -0,0 +1,619 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from typing import Generic, TypeVar
from treelib import Tree
import numpy as np
from pycutlass import *
import pycutlass
import ast
import textwrap
import inspect
################################################################################
# Type annotation for input arguments
################################################################################
Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")
class NDArray(np.ndarray, Generic[Ttype, Dtype]):
pass
################################################################################
# Operations
################################################################################
operators = {
ast.Add: "Add",
ast.Div: "Div",
ast.Eq: "Equal",
ast.Mult: "Mult"
}
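# How the table above is used (illustrative sketch): BinOpNode below resolves a binary
# expression in the epilogue source by looking up operators[type(node.op)] and then
# instantiating the matching pycutlass vector op, e.g. getattr(pycutlass, "VectorMult").
# A minimal, standalone illustration of the lookup itself:
#
#   import ast
#   expr = ast.parse("alpha * accum", mode="eval").body   # an ast.BinOp
#   assert operators[type(expr.op)] == "Mult"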
################################################################################
# AST Node abstractions
################################################################################
class UnaryNode:
cnt = 0
# Concept: created from an ast.Call in the epilogue source, or lowered from a BinOpNode when one operand is a scalar
def __init__(self,
element_accumulator, element_compute, elements_per_access,
node, args) -> None:
if isinstance(node, BinOpNode):
self.op = node.op
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
self.op = node.func.id
elif isinstance(node.func, ast.Attribute):
self.op = node.func.value.id
else:
raise TypeError
else:
raise TypeError
self.tag = "Unary" + self.op + str(UnaryNode.cnt)
self.id = self.op + str(UnaryNode.cnt)
self.args = args
UnaryNode.cnt += 1
self.type = "tensor"
self.epilogue_op = getattr(pycutlass, self.op)(element_compute)
# data types
self.element_accumulator = element_accumulator
self.element_compute = element_compute
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = UnaryOp(
self.element_accumulator, self.element_compute,
self.elements_per_access, *visitors, self.epilogue_op)
def get_argument(self, visitor_args, kwargs):
epilogue_ops = []
for arg in self.args:
try:
epilogue_ops.append(kwargs[arg])
except KeyError:
epilogue_ops.append(arg) # direct arguments such as constants
self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(*epilogue_ops), *visitor_args)
class BinOpNode:
cnt = 0
# Concept: this is created by the BinOp Node in python ast
def __init__(self,
element_accumulator, element_compute, elements_per_access,
node) -> None:
self.op = operators[type(node.op)]
self.tag = "Binary" + self.op + str(BinOpNode.cnt)
self.id = self.op + str(BinOpNode.cnt)
self.args = None
BinOpNode.cnt += 1
self.type = "tensor"
self.epilogue_op = getattr(pycutlass, "Vector"+self.op)(element_compute)
# data types
self.element_accumulator = element_accumulator
self.element_compute = element_compute
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = BinaryOp(
self.element_accumulator, self.element_compute,
self.elements_per_access, *visitors, self.epilogue_op)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(self.args), *visitor_args)
class NameNode:
# Concept: this is created by the Name Node in python ast
def __init__(self, node) -> None:
try:
self.id = node.id
except AttributeError:
self.id = node.targets[0].id
self.tag = self.id
class ScalarInputNode(NameNode):
# Concept: scalar
def __init__(self, node) -> None:
super().__init__(node)
self.tag = "Scalar:" + self.tag
self.type = "scalar"
class AccumulatorNode(NameNode):
# Concept: VisitorOpAccumulator
def __init__(self,
element_accumulator, elements_per_access, node) -> None:
super().__init__(node)
self.tag = "Accum:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
def get_epilogue_node(self, visitors):
self.epilogue_node = AccumulatorOp(
self.element_accumulator, self.elements_per_access)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type()
class TensorInputNode(NameNode):
# Concept: VisitorOpTensorInput
def __init__(self, element_accumulator, node) -> None:
super().__init__(node)
self.tag = "TensorInput:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
def get_epilogue_node(self, *args):
self.epilogue_node = TensorInputOp(self.element_accumulator)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(
kwargs[self.id + "_ptr"], kwargs["problem_size"][1],
kwargs["problem_size"][0] * kwargs["problem_size"][1])
class RowBroadcastNode(NameNode):
# Concept: VisitorOpRowBroadcast
def __init__(self, element_accumulator, element_fragment, node) -> None:
super().__init__(node)
#
self.tag = "RowBroadcast:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_fragment = element_fragment
def get_epilogue_node(self, *args):
self.epilogue_node = RowBroadcastOp(
self.element_accumulator, self.element_fragment)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1])
class ColumnBroadcastNode(NameNode):
# Concept: VisitorOpColumnBroadcast
def __init__(self, element_accumulator, element_fragment, node) -> None:
super().__init__(node)
self.tag = "ColumnBroadcast:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_fragment = element_fragment
def get_epilogue_node(self, *args):
self.epilogue_node = ColumnBroadcastOp(
self.element_accumulator, self.element_fragment)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][0])
class TensorOutputNode(NameNode):
# Concept: VisitorOpTensorOutput
def __init__(self, element_accumulator, node) -> None:
super().__init__(node)
self.tag = "TensorOutput:" + self.tag
self.type = "tensor"
self.element_accumulator = element_accumulator
def get_epilogue_node(self, visitors):
self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1], *visitor_args, kwargs["problem_size"][0] * kwargs["problem_size"][1])
class RowReductionNode:
# Concept: RowReductionOp
def __init__(self, element_accumulator, element_reduction,
element_reduction_accumulator, id, factor) -> None:
#
self.id = id
self.tag = "RowReduction:" + self.id
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_reduction = element_reduction
self.element_reduction_accumulator = element_reduction_accumulator
self.factor = factor
def get_epilogue_node(self, visitors):
self.epilogue_node = RowReductionOp(
self.element_accumulator, self.element_reduction,
self.element_reduction_accumulator, *visitors)
def get_batch_stride(self, problem_size):
return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], *visitor_args, self.get_batch_stride(kwargs["problem_size"]))
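# Worked example for get_batch_stride above (values are illustrative): for a row
# reduction with problem_size = [M, N] = [512, 1024] and factor = threadblock tile
# N = 128, the per-batch stride is 512 * ((1024 + 127) // 128) = 512 * 8 = 4096 elements.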
class ColumnReductionNode:
# Concept: ColumnReductionOp
def __init__(self, element_accumulator, element_reduction,
element_reduction_accumulator, id, factor) -> None:
#
self.id = id
self.tag = "ColumnReduction:" + self.id
self.type = "tensor"
self.element_accumulator = element_accumulator
self.element_reduction = element_reduction
self.element_reduction_accumulator = element_reduction_accumulator
self.factor = factor
def get_epilogue_node(self, visitors):
self.epilogue_node = ColumnReductionOp(
self.element_accumulator, self.element_reduction,
self.element_reduction_accumulator, *visitors)
def get_batch_stride(self, problem_size):
return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)
def get_argument(self, visitor_args, kwargs):
self.argument = self.epilogue_node.argument_type(kwargs[self.id + '_ptr'], *visitor_args, self.get_batch_stride(kwargs["problem_size"]))
################################################################################
# Epilogue parser function
################################################################################
class EpilogueAST(ast.NodeVisitor):
def __init__(self, epilogue,
tile_description,
element_accumulator, elements_per_access,
element_compute, element_output) -> None:
#
self.tile_description = tile_description
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
self.element_compute = element_compute
self.element_output = element_output
self.epilogue = epilogue
self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
self.ast_tree = ast.parse(self.source)
self.epilogue_tree = Tree()
# print(ast.dump(self.ast_tree, indent=4)) # for debug purposes
# input arguments
self.input_args = {}
# return nodes
self.returns = []
# reduction source nodes
self.reduction_source = {}
# stack used to keep the parent node id
self.stack = []
# visit the AST
self.visit(self.ast_tree)
# visit the name node
def visit_Name(self, node):
# append the return ids into self.returns
if self.stack[-1] == "return":
self.returns.append(node.id)
else:
# accum is produced from accumulator node
if node.id == "accum":
name_node = AccumulatorNode(
self.element_accumulator, self.elements_per_access, node)
else:
# for input nodes
if node.id in self.input_args.keys():
type = self.input_args[node.id][0]
if type == "tensor":
name_node = TensorInputNode(self.element_accumulator, node)
elif type == "row":
name_node = RowBroadcastNode(self.element_accumulator, self.element_compute, node)
elif type == "column":
name_node = ColumnBroadcastNode(self.element_accumulator, self.element_compute, node)
elif type == "scalar":
name_node = ScalarInputNode(node)
else:
raise ValueError(type)
# for output nodes
else:
name_node = TensorOutputNode(self.element_accumulator, node)
self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node, parent=self.stack[-1])
def visit_Assign(self, node):
pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
if pre_assign_node is None:
# the assignment target is new, so it becomes a root (output) node
# skip the reduction nodes
if isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Name):
func_type = node.value.func.id
elif isinstance(node.value.func, ast.Attribute):
func_type = node.value.func.value.id
else:
raise TypeError
if func_type == 'reduction_op':
self.reduction_source[node.value.args[0].id] = [node.value.args[1].value, node.value.args[2].value, node.targets[0].id]
return
name_node = TensorOutputNode(self.element_accumulator, node)
self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node)
self.stack.append(name_node.id)
else:
if node.targets[0].id in self.returns or node.targets[0].id in self.reduction_source.keys():
self.stack.append(node.targets[0].id)
else:
self.stack.append(pre_assign_node.predecessor(self.epilogue_tree.identifier))
self.epilogue_tree.remove_node(node.targets[0].id)
# visit the assigned value to build its subtree under the current node
self.visit(node.value)
self.stack.pop()
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
func_type = node.func.id
elif isinstance(node.func, ast.Attribute):
func_type = node.func.value.id
else:
raise TypeError
if func_type == "reduction_op":
self.visit(node.args[0])
else:
arg_list = []
for idx, arg in enumerate(node.args):
if idx == 0: continue
if isinstance(arg, ast.Constant):
arg_list.append(arg.value)
elif isinstance(arg, ast.Name):
arg_list.append(arg.id)
else:
raise TypeError
unary_node = UnaryNode(self.element_accumulator, self.element_compute, self.elements_per_access, node, arg_list)
self.epilogue_tree.create_node(unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node)
self.stack.append(unary_node.id)
self.visit(node.args[0])
self.stack.pop()
def visit_BinOp(self, node):
binop = BinOpNode(self.element_accumulator, self.element_compute,
self.elements_per_access, node)
self.epilogue_tree.create_node(binop.tag, binop.id, data=binop, parent=self.stack[-1])
self.stack.append(binop.id)
self.visit(node.left)
self.visit(node.right)
self.stack.pop()
def visit_Return(self, node):
self.stack.append("return")
self.visit(node.value)
self.stack.pop()
# A function definition
def visit_FunctionDef(self, node: ast.FunctionDef):
# visit args
for arg in node.args.args:
if arg.arg == "self": continue
if isinstance(arg.annotation, ast.Constant):
self.input_args[arg.arg] = [arg.annotation.value, ]
# visit the assignments in reverse order, so the return statement is processed first
for idx in range(len(node.body)):
self.visit(node.body[-1-idx])
#
# Tree optimization pass
#
# pass 1: lower Binary to Unary
def pass_binary_2_unary(self, tree, nid):
node = tree.get_node(nid)
if isinstance(node.data, BinOpNode):
lhs_node = tree.get_node(node.successors(tree.identifier)[0])
left_type = lhs_node.data.type
rhs_node = tree.get_node(node.successors(tree.identifier)[1])
right_type = rhs_node.data.type
if left_type == "scalar" and right_type == "tensor":
node.data = UnaryNode(
self.element_accumulator, self.element_compute,
self.elements_per_access,
node.data, [lhs_node.data.id,])
node.tag = node.data.tag
tree.remove_node(lhs_node.data.id)
self.pass_binary_2_unary(tree, rhs_node.data.id)
elif left_type == "tensor" and right_type == "scalar":
node.data = UnaryNode(
self.element_accumulator, self.element_compute,
self.elements_per_access,
node.data, [rhs_node.data.id,])
node.tag = node.data.tag
tree.remove_node(rhs_node.data.id)
self.pass_binary_2_unary(tree, lhs_node.data.id)
else:
self.pass_binary_2_unary(tree, lhs_node.data.id)
self.pass_binary_2_unary(tree, rhs_node.data.id)
else:
for child in node.successors(tree.identifier):
self.pass_binary_2_unary(tree, child)
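# Illustrative effect of pass_binary_2_unary (hypothetical expression): for `alpha * accum`
# with a scalar `alpha` on the left and the accumulator on the right, the BinOpNode is
# rewritten in place as a UnaryNode with op "Mult" carrying args=[alpha's id], the scalar
# leaf is removed, and the accumulator remains as the unary node's only child.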
# pass 2: inject reduction nodes
def pass_inject_reduction(self, tree, nid):
node = tree.get_node(nid)
if isinstance(node.data, TensorOutputNode):
if node.data.id in self.reduction_source.keys():
direction = self.reduction_source[node.data.id][0]
target = self.reduction_source[node.data.id][-1]
if direction == 'row':
reduction_node = RowReductionNode(
self.element_accumulator, self.element_output,
self.element_accumulator, target, self.tile_description.threadblock_shape[1])
elif direction == "column":
reduction_node = ColumnReductionNode(
self.element_accumulator, self.element_output,
self.element_accumulator, target, self.tile_description.threadblock_shape[0])
else:
raise ValueError(direction)
child_nid = node.successors(tree.identifier)[0]
# if this output node is injected only for reduction
if node.data.id not in self.returns:
# get reduction config from the recorded reduction source
node.data = reduction_node
node.tag = reduction_node.tag
self.pass_inject_reduction(tree, child_nid)
# if this output node is also a tensor output, inject reduction as its children
else:
# get child node
tree.create_node(reduction_node.tag, reduction_node.id, data=reduction_node, parent=node.data.id)
tree.move_node(child_nid, reduction_node.id)
child = tree.get_node(child_nid)
for grand_child in child.successors(tree.identifier):
self.pass_inject_reduction(tree, grand_child)
else:
for child in node.successors(tree.identifier):
self.pass_inject_reduction(tree, child)
else:
for child in node.successors(tree.identifier):
self.pass_inject_reduction(tree, child)
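# Placement rule of pass_inject_reduction above, summarized: if the assigned tensor is
# only a reduction source (not returned), its TensorOutputNode is replaced by the
# Row/ColumnReductionNode; if it is also returned, the reduction node is inserted between
# the output node and its existing child. The reduction factor is the threadblock tile N
# for row reductions and tile M for column reductions.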
def pass_inject_epilogue_op(self, tree, nid):
node = tree.get_node(nid)
visitors = []
for child in node.successors(tree.identifier):
visitors.append(self.pass_inject_epilogue_op(tree, child))
node.data.get_epilogue_node(visitors)
return node.data.epilogue_node
def get_arguments(self, tree, nid, kwargs):
node = tree.get_node(nid)
visitor_args = []
for child in node.successors(tree.identifier):
visitor_args.append(self.get_arguments(tree, child, kwargs))
node.data.get_argument(visitor_args, kwargs)
return node.data.argument
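# Illustrative epilogue definition that EpilogueAST above can parse (hypothetical names;
# a sketch of the expected contract, not code from this commit): argument annotations are
# the strings 'tensor', 'row', 'column' or 'scalar', `accum` denotes the accumulator,
# assigned-and-returned names become tensor outputs, and reduction_op(tensor, direction, ...)
# marks a reduction on `tensor` where direction is 'row' or 'column'; the remaining
# constants are recorded, while the reduction factor itself comes from the tile description.
#
#   class LinearCombWithRowReduce:
#       def __call__(self, accum, c: 'tensor', alpha: 'scalar', beta: 'scalar', bias: 'row'):
#           d = alpha * accum + beta * c + bias
#           d_row = reduction_op(d, "row", "Add", 1)
#           return d, d_row
#
# visit_FunctionDef records c/alpha/beta/bias in input_args, visit_Assign builds the
# expression tree rooted at the output `d`, and pass_inject_reduction later attaches a
# RowReductionNode producing `d_row`.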
class EpilogueVisitTree:
KernelTemplate = """
${visitor}
using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""
def __init__(self, elementwise_functor, tile_description,
element_accumulator, elements_per_access,
element_compute, element_output) -> None:
#
# data types
self.tile_description = tile_description
self.element_accumulator = element_accumulator
self.elements_per_access = elements_per_access
self.element_compute = element_compute
self.element_output = element_output
# TODO: deprecate this
self.elementwise_functor = elementwise_functor
pass
def initialize(self):
function = EpilogueAST(self, self.tile_description,
self.element_accumulator, self.elements_per_access,
self.element_compute, self.element_output)
#
tree = function.epilogue_tree
self.tree = tree
# self.tree.show() # for debug
function.pass_binary_2_unary(self.tree, self.tree.root)
# self.tree.show() # for debug
function.pass_inject_reduction(self.tree, self.tree.root)
# self.tree.show() # for debug
function.pass_inject_epilogue_op(self.tree,self.tree.root)
visitor = self.tree.get_node(self.tree.root).data.epilogue_node
self.visitor = visitor
class _Argument(ctypes.Structure):
_fields_ = [
("visitor_arg", visitor.argument_type)
]
def __init__(self, **kwargs) -> None:
# process input args
_kwargs = {}
for input_key in function.input_args.keys():
if input_key == "accum":
continue
if function.input_args[input_key][0] == "scalar":
# _kwargs[input_key] = kwargs[input_key]
continue
# tensor input
else:
setattr(self, "buffer_tensor_" + input_key, NumpyFrontend.argument(kwargs[input_key], False))
setattr(self, input_key + "_ptr", int(getattr(self, "buffer_tensor_" + input_key).ptr))
_kwargs[input_key+"_ptr"] = getattr(self, input_key + "_ptr")
# process the return args
for ret in function.returns:
setattr(self, "buffer_tensor_" + ret, NumpyFrontend.argument(kwargs[ret], True))
setattr(self, ret + "_ptr", int(getattr(self, "buffer_tensor_" + ret).ptr))
_kwargs[ret+"_ptr"] = getattr(self, ret + "_ptr")
setattr(self, "host_tensor_" + ret, kwargs[ret])
_kwargs.update(kwargs)
function.get_arguments(tree, tree.root, _kwargs)
self.visitor_arg = tree.get_node(tree.root).data.argument
def sync(self, stream_sync=True):
if stream_sync:
err, = cudart.cudaDeviceSynchronize()
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
for ret in function.returns:
err, = cuda.cuMemcpyDtoH(
getattr(self, "host_tensor_" + ret), cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
getattr(self, "host_tensor_" + ret).size * getattr(self, "host_tensor_" + ret).itemsize
)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
pass
self.epilogue_type = _Argument
def emit(self, operation):
values = {
'visitor': self.visitor.emit(operation),
'operation_name': operation.procedural_name(),
'visitor_name': self.visitor.instance_name
}
return SubstituteTemplate(self.KernelTemplate, values)
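# Usage sketch for EpilogueVisitTree (an assumption; the surrounding operation plumbing
# is not shown in this hunk): after initialize(), self.epilogue_type is the _Argument
# structure defined above. It is constructed with keyword arguments named after the
# epilogue's inputs and outputs plus a two-element problem_size, following the
# hypothetical epilogue sketched earlier:
#
#   visitor.initialize()
#   output_op = visitor.epilogue_type(
#       c=tensor_c, bias=tensor_bias, d=tensor_d, d_row=tensor_d_row,
#       alpha=1.0, beta=0.5, problem_size=[M, N])
#
# Tensor inputs and outputs are host numpy arrays; device buffers are allocated through
# NumpyFrontend.argument, and results are copied back to the host arrays by sync().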

View File

@ -58,6 +58,13 @@ class ReductionArguments:
destination: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
source: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', **kwargs) -> None:
# tensor_C can be interpreted as the bias with bias=True in keyword args
if "bias" in kwargs.keys():
self.bias = kwargs["bias"]
else:
# by default, tensor_C is not bias
self.bias = False
self.operation = operation
#: pointer to the workspace
self.ptr_workspace = workspace
@ -89,11 +96,9 @@ class ReductionArguments:
problem_size[1] * DataTypeSize[operation.C.element] // 8
if "output_op" in kwargs.keys():
self.alpha = kwargs["output_op"].alpha
self.beta = kwargs["output_op"].beta
self.output_op = kwargs['output_op']
else:
self.alpha = 1.0
self.beta = 0.0
self.output_op = self.operation.epilogue_type(1.0, 0.0)
# get arguments
self.get_arguments()
@ -109,31 +114,25 @@ class ReductionArguments:
ref_workspace = ReductionArguments.get_tensor_ref(
extent=[self.problem_size.row, self.problem_size.column],
device_ptr=self.ptr_workspace, layout=cutlass.RowMajor)
ref_source = ReductionArguments.get_tensor_ref(
extent=[self.problem_size.row, self.problem_size.column],
device_ptr=self.ptr_source, layout=cutlass.RowMajor)
if self.bias:
ref_source = ReductionArguments.get_tensor_ref(
extent=[0, 0],
device_ptr=self.ptr_source, layout=cutlass.RowMajor)
else:
ref_source = ReductionArguments.get_tensor_ref(
extent=[self.problem_size.row, self.problem_size.column],
device_ptr=self.ptr_source, layout=cutlass.RowMajor)
ref_destination = ReductionArguments.get_tensor_ref(
extent=[self.problem_size.row, self.problem_size.column],
device_ptr=self.ptr_destination, layout=cutlass.RowMajor)
argument_type, epilogue_type = get_reduction_params(
self.operation.element_compute)
if self.operation.element_compute == cutlass.float16:
self.alpha = cutlass.float16(self.alpha).storage
self.beta = cutlass.float16(self.beta).storage
elif self.operation.element_compute == cutlass.int32:
self.alpha = int(self.alpha)
self.beta = int(self.beta)
output_op = epilogue_type(self.alpha, self.beta, 0, 0)
self.c_arguments = argument_type(
self.c_arguments = self.operation.argument_type(
self.problem_size, self.partitions,
self.partition_stride, ref_workspace,
ref_destination, ref_source,
output_op
self.output_op
)
params_ = self.operation.rt_module.get_args(
@ -210,8 +209,8 @@ extern "C" {
self.emitter = EmitReductionInstance('_type')
self.elements_per_access = self.operation.count
self.argtype = [ctypes.POINTER(
get_reduction_params(operation.element_compute)[0])]
self.argument_type, self.epilogue_type = get_reduction_params(operation.epilogue_functor)
self.argtype = [ctypes.POINTER(self.argument_type)]
def emit(self):
return self.emitter.emit(self.operation)
@ -247,14 +246,14 @@ class ReductionOperation:
def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription,
element_accumulator, element_workspace=None,
element_compute=None, epilogue_functor: EpilogueFunctor = EpilogueFunctor.LinearCombination,
element_compute=None, epilogue_functor=None,
count: int = 1, partitions_per_stage: int = 4) -> None:
""" Constructor
"""
self.shape = shape
#: epilogue functor (default: LinearCombination)
self.epilogue_functor: EpilogueFunctor = epilogue_functor
self.epilogue_functor = epilogue_functor
#: datatype of accumulator
self.element_accumulator = element_accumulator
@ -285,6 +284,8 @@ class ReductionOperation:
self.partitions_per_stage: int = partitions_per_stage
self.rt_module: ReductionRT = ReductionRT(self)
self.argument_type = self.rt_module.argument_type
self.epilogue_type = self.rt_module.epilogue_type
#
def extended_name(self):
@ -363,12 +364,7 @@ class EmitReductionInstance:
using ${operation_name}_base =
typename cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
${epilogue_functor}<
${element_output},
${epilogue_vector_length},
${element_accumulator},
${element_compute}
>,
${epilogue_functor},
cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_output},
@ -389,7 +385,7 @@ struct ${operation_name}${operation_suffix}:
'operation_suffix': self.operation_suffix,
'shape_row': str(operation.shape.row()),
'shape_column': str(operation.shape.column()),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'epilogue_functor': operation.epilogue_functor.emit(),
'element_output': DataTypeTag[operation.element_output],
'epilogue_vector_length': str(epilogue_vector_length),
'element_accumulator': DataTypeTag[operation.element_accumulator],

View File

@ -68,4 +68,3 @@ class TensorRef:
# the dtype(0) is used to overload between different data types
# with the same layout
self.tensor_ref = cutlass.get_tensor_ref(int(ptr), dtype(0), layout)

View File

@ -124,7 +124,7 @@ class Conv2dLauncher:
self.reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.element_epilogue,
element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment
)
@ -183,7 +183,7 @@ class Conv2dLauncher:
# Get the host reference function
#
self.element_compute = operation.element_epilogue
self.element_compute = operation.epilogue_functor.element_epilogue
self.host_conv2d = cutlass.test.conv.host.conv2d
@ -441,7 +441,7 @@ class Conv2dLauncher:
arguments = Conv2dArguments(
operation=self.operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = LinearCombinationFunctorArguments(alpha, beta),
output_op = self.operation.epilogue_type(alpha, beta),
split_k_slices=problem_size.split_k_slices,
split_k_mode=split_k_mode
)
@ -454,7 +454,7 @@ class Conv2dLauncher:
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op = LinearCombinationFunctorArguments(alpha, beta)
output_op = self.reduction_operation.epilogue_type(alpha, beta)
)
self.operation.run(arguments)

View File

@ -68,7 +68,7 @@ class TestbedGrouped:
self.scope_min = -8
#: compute type
self.compute_type = operation.element_epilogue
self.compute_type = operation.epilogue_functor.element_epilogue
self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
@ -176,7 +176,7 @@ class TestbedGrouped:
arguments = GemmGroupedArguments(
operation=self.operation, problem_sizes=problem_sizes,
A=tensor_As, B=tensor_Bs, C=tensor_Cs, D=tensor_Ds,
output_op=LinearCombinationFunctorArguments(alpha, beta)
output_op=self.operation.epilogue_type(alpha, beta)
)
self.operation.run(arguments)

View File

@ -143,7 +143,7 @@ class GemmUniversalLauncher:
self.reduction_operation: ReductionOperation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_compute=operation.element_epilogue,
element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
count=operation.C.alignment
)
@ -200,7 +200,7 @@ class GemmUniversalLauncher:
self.interleaved = interleaved
#: compute type
self.compute_type = operation.element_epilogue
self.compute_type = operation.epilogue_functor.element_epilogue
self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
def print_problem_size(self, p, mode, batch_count):
@ -391,7 +391,7 @@ class GemmUniversalLauncher:
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(alpha, beta),
output_op=self.operation.epilogue_type(alpha, beta),
gemm_mode=mode, split_k_slices=batch_count
)
@ -403,7 +403,7 @@ class GemmUniversalLauncher:
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op=LinearCombinationFunctorArguments(alpha, beta)
output_op=self.reduction_operation.epilogue_type(alpha, beta)
)
self.operation.run(arguments)

View File

@ -34,6 +34,7 @@ import numpy as np
import cutlass
from pycutlass.library import TensorDescription
from typing import Union
from bfloat16 import bfloat16
try:
import torch
torch_available = True
@ -46,7 +47,7 @@ class ReferenceModule:
self.layout_B = B.layout
self.layout_C = C.layout
def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0):
def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0, bias=False, batch=1):
"""
Compute the reference result on CPU
Args:
@ -57,27 +58,38 @@ class ReferenceModule:
M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
if isinstance(A, np.ndarray):
if self.layout_A == cutlass.RowMajor:
A_row = np.reshape(A, newshape=(M, K))
A_row = np.reshape(A, newshape=(batch, M, K))
else:
A_col = np.reshape(A, newshape=(K, M))
A_row = np.transpose(A_col, axes=(1, 0))
A_col = np.reshape(A, newshape=(batch, K, M))
A_row = np.transpose(A_col, axes=(0, 2, 1))
if self.layout_B == cutlass.RowMajor:
B_row = np.reshape(B, newshape=(K, N))
B_row = np.reshape(B, newshape=(batch, K, N))
else:
B_col = np.reshape(B, newshape=(N, K))
B_row = np.transpose(B_col, axes=(1, 0))
B_col = np.reshape(B, newshape=(batch, N, K))
B_row = np.transpose(B_col, axes=(0, 2, 1))
if self.layout_C == cutlass.RowMajor:
C_row = np.reshape(C, newshape=(M, N))
if bias:
C_row = np.reshape(C, newshape=(batch, 1, N))
else:
C_row = np.reshape(C, newshape=(batch, M, N))
else:
C_col = np.reshape(C, newshape=(N, M))
C_row = np.transpose(C_col, axes=(1, 0))
if bias:
C_row = np.reshape(C, newshape=(batch, M, 1))
else:
C_col = np.reshape(C, newshape=(batch, N, M))
C_row = np.transpose(C_col, axes=(0, 2, 1))
out_row = np.matmul(A_row, B_row) * alpha + C_row * beta
if A_row.dtype == bfloat16:
# numpy's einsum doesn't support bfloat16
out_row = np.einsum("bik,bkj->bij", A_row.astype(np.float32), B_row.astype(np.float32)) * alpha + C_row * beta
out_row = out_row.astype(C_row.dtype)
else:
out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta
if self.layout_C == cutlass.ColumnMajor:
out = np.transpose(out_row, axes=(1, 0))
out = np.transpose(out_row, axes=(0, 2, 1))
else:
out = out_row
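# Broadcast behavior of the bias path above (row-major C, illustrative shapes): with
# bias=True, C is reshaped to (batch, 1, N), so the einsum result of shape (batch, M, N)
# receives the same bias row on every one of its M rows via numpy broadcasting; e.g. a
# (1, 1, 4) bias added to a (1, 3, 4) product is repeated across all 3 rows.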
@ -128,7 +140,7 @@ if torch_available:
def run(self,
A: Union[np.ndarray, torch.Tensor],
B: Union[np.ndarray, torch.Tensor],
C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0) -> np.ndarray:
C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0, bias=False) -> np.ndarray:
"""
Compute the reference result on CPU
"""
@ -184,7 +196,10 @@ if torch_available:
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass.TensorNHWC:
C_nhwc = C.view((k, r, s, c))
if bias:
C_nhwc = C.view((1, 1, 1, c))
else:
C_nhwc = C.view((k, r, s, c))
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
elif self.kind == cutlass.conv.Operator.dgrad:
if self.layout_A == cutlass.TensorNHWC:
@ -196,7 +211,10 @@ if torch_available:
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass.TensorNHWC:
C_nhwc = C.view((n, h, w, c))
if bias:
C_nhwc = C.view((1, 1, 1, c))
else:
C_nhwc = C.view((n, h, w, c))
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
else:
if self.layout_A == cutlass.TensorNHWC:
@ -208,7 +226,10 @@ if torch_available:
B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))
if self.layout_C == cutlass.TensorNHWC:
C_nhwc = C.view((n, p, q, k))
if bias:
C_nhwc = C.view((1, 1, 1, k))
else:
C_nhwc = C.view((n, p, q, k))
C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
if self.kind == cutlass.conv.Operator.fprop:

View File

@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -106,15 +112,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -156,15 +165,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=4,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -105,15 +111,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -143,15 +152,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=4,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
tile_description = TileDescription(
threadblock_shape=[128, 128, 8], stages=4,
warp_count=[4, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
tile_description = TileDescription(
threadblock_shape=[128, 128, 8], stages=4,
warp_count=[2, 4, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -97,15 +97,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -135,15 +138,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=2,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -79,15 +79,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -117,15 +120,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -155,15 +161,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -105,15 +111,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -173,15 +182,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -241,15 +253,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float16)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -30,15 +30,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
tile_description = TileDescription(
threadblock_shape=[128, 128, 8], stages=4,
warp_count=[4, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle2
)
@ -68,15 +71,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
tile_description = TileDescription(
threadblock_shape=[128, 128, 8], stages=4,
warp_count=[2, 4, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
tile_description = TileDescription(
threadblock_shape=[128, 256, 64], stages=3,
warp_count=[2, 4, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
)
@ -105,15 +111,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
)
@ -155,15 +164,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
)
@ -193,15 +205,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
)

View File

@ -29,15 +29,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment, math_inst.element_accumulator,
cutlass.float16
)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +71,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 64], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment, math_inst.element_accumulator,
cutlass.float16
)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -105,15 +111,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[64, 256, 32], stages=3,
warp_count=[1, 4, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -143,15 +152,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -193,15 +205,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -30,15 +30,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
tile_description = TileDescription(
threadblock_shape=[128, 128, 8], stages=4,
warp_count=[2, 4, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -68,15 +71,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
tile_description = TileDescription(
threadblock_shape=[128, 128, 8], stages=4,
warp_count=[2, 4, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
tile_description = TileDescription(
threadblock_shape=[128, 128, 16], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
tile_description = TileDescription(
threadblock_shape=[128, 128, 32], stages=3,
warp_count=[2, 2, 1],
math_instruction=math_inst,
min_compute=80, max_compute=80
math_instruction=math_inst
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=80, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
epilogue_functor=EpilogueFunctor.LinearCombination,
stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)

View File

@ -0,0 +1,80 @@
pushd $CUTLASS_PATH/examples/40_cutlass_py/
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
popd

View File

@ -49,7 +49,7 @@ class Test_Frontend(unittest.TestCase):
tile_description = TileDescription(
[128, 128, 8], 4, [2, 4, 1],
math_inst, 80, 80
math_inst
)
A = TensorDescription(
@ -64,10 +64,14 @@ class Test_Frontend(unittest.TestCase):
cutlass.float32, cutlass.RowMajor, 1
)
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
self.operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=cutlass.float32,
epilogue_functor=EpilogueFunctor.LinearCombination,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor,
swizzling_functor=cutlass.IdentitySwizzle1
)
@ -89,7 +93,7 @@ class Test_Frontend(unittest.TestCase):
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(alpha, beta),
output_op=self.operation.epilogue_type(alpha, beta),
gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
)
@ -119,7 +123,7 @@ class Test_Frontend(unittest.TestCase):
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(alpha, beta),
output_op=self.operation.epilogue_type(alpha, beta),
gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
)
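Read together, the hunks in this file describe a single pattern change: the operation is now constructed with a concrete LinearCombination functor, and the argument pack later asks that same operation for its epilogue_type instead of importing LinearCombinationFunctorArguments. A condensed sketch, under the assumption that tile_description, the tensor descriptions, the device tensors, problem_size, alpha, and beta are prepared as in the test:

    epilogue_functor = LinearCombination(
        C.element, C.alignment,
        math_inst.element_accumulator, cutlass.float32)
    operation = GemmOperationUniversal(
        arch=80, tile_description=tile_description,
        A=A, B=B, C=C,
        epilogue_functor=epilogue_functor,
        swizzling_functor=cutlass.IdentitySwizzle1
    )
    # alpha/beta are wrapped by the operation's own epilogue_type, so the
    # output_op always matches whatever functor the operation was built with.
    arguments = GemmArguments(
        operation=operation, problem_size=problem_size,
        A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
        output_op=operation.epilogue_type(alpha, beta),
        gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
    )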

View File

@ -17,7 +17,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 128, 64],
stages=4, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -33,15 +33,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
alignment=4
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -58,7 +58,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 128, 32],
stages=6, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -74,15 +74,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
alignment=8
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, cutlass.float32)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

View File

@ -18,7 +18,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 32],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -36,13 +36,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.BatchedIdentitySwizzle
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
direct_store=True
)
@ -60,7 +62,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 64],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -78,13 +80,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -101,7 +105,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 256, 64],
stages=3, warp_count=[2, 4, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -119,13 +123,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -142,7 +148,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[256, 128, 64],
stages=3, warp_count=[4, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -160,13 +166,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -183,7 +191,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 64, 64],
stages=3, warp_count=[2, 1, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -201,13 +209,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float16
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -224,7 +234,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 64, 32],
stages=10, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -242,13 +252,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float16
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -265,7 +277,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[256, 128, 64],
stages=3, warp_count=[4, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -283,13 +295,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -306,7 +320,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 64, 64],
stages=3, warp_count=[2, 1, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -324,13 +338,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -347,7 +363,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 256, 64],
stages=3, warp_count=[2, 4, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -365,13 +381,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -388,7 +406,7 @@ class GemmF16Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 256, 64],
stages=3, warp_count=[2, 4, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -406,13 +424,15 @@ class GemmF16Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

View File

@ -19,7 +19,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 32],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -37,13 +37,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -61,7 +63,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 32],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -79,13 +81,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -102,7 +106,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 64, 32],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -120,13 +124,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

View File

@ -17,7 +17,7 @@ class GemmF64TensorOpSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[32, 32, 16],
stages=4, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
# alignment 1 restricted for double
@ -36,13 +36,15 @@ class GemmF64TensorOpSm80(unittest.TestCase):
element_epilogue = cutlass.float64
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -59,7 +61,7 @@ class GemmF64TensorOpSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 64, 16],
stages=4, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
# alignment 1 restricted for double
@ -78,13 +80,15 @@ class GemmF64TensorOpSm80(unittest.TestCase):
element_epilogue = cutlass.float64
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

View File

@ -18,7 +18,7 @@ class GemmGroupedSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 32],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -37,14 +37,15 @@ class GemmGroupedSm80(unittest.TestCase):
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.BatchedIdentitySwizzle
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
operation = GemmOperationGrouped(
tile_description.minimum_compute_capability,
80,
tile_description, A, B, C,
element_epilogue,
epilogue_functor, swizzling_functor,
precompute_mode=precompute_mode
)
@ -64,7 +65,7 @@ class GemmGroupedSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 64, 16],
stages=4, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -83,14 +84,15 @@ class GemmGroupedSm80(unittest.TestCase):
)
element_epilogue = cutlass.float64
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.BatchedIdentitySwizzle
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
operation = GemmOperationGrouped(
tile_description.minimum_compute_capability,
80,
tile_description, A, B, C,
element_epilogue,
epilogue_functor, swizzling_functor,
precompute_mode=precompute_mode
)
@ -110,7 +112,7 @@ class GemmGroupedSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 64, 8],
stages=4, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -129,14 +131,15 @@ class GemmGroupedSm80(unittest.TestCase):
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.BatchedIdentitySwizzle
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
operation = GemmOperationGrouped(
tile_description.minimum_compute_capability,
80,
tile_description, A, B, C,
element_epilogue,
epilogue_functor, swizzling_functor,
precompute_mode=precompute_mode
)
@ -156,7 +159,7 @@ class GemmGroupedSm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 32],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -175,14 +178,15 @@ class GemmGroupedSm80(unittest.TestCase):
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
swizzling_functor = cutlass.BatchedIdentitySwizzle
for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
operation = GemmOperationGrouped(
tile_description.minimum_compute_capability,
80,
tile_description, A, B, C,
element_epilogue,
epilogue_functor, swizzling_functor,
precompute_mode=precompute_mode
)
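The grouped-GEMM constructor call changes shape in two ways in these hunks: the architecture is passed directly as 80 rather than via tile_description.minimum_compute_capability, and element_epilogue disappears in favour of the constructed epilogue functor. A sketch of the resulting loop, assuming math_inst, tile_description, A, B, and C are set up as in the tests above:

    epilogue_functor = LinearCombination(
        C.element, C.alignment,
        math_inst.element_accumulator, cutlass.float32)
    swizzling_functor = cutlass.BatchedIdentitySwizzle
    for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
        # arch is given positionally; element_epilogue is no longer an argument
        operation = GemmOperationGrouped(
            80,
            tile_description, A, B, C,
            epilogue_functor, swizzling_functor,
            precompute_mode=precompute_mode
        )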

View File

@ -1,5 +1,6 @@
import pycutlass
from pycutlass import *
from pycutlass.epilogue import LinearCombinationClamp
from pycutlass.test import *
import unittest
@ -17,7 +18,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[64, 64, 64],
stages=6, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -33,15 +34,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
alignment=8
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
epilogue_functor = FastLinearCombinationClamp(
C.element, C.alignment
)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -58,7 +59,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 128],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -74,15 +75,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
alignment=16
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
epilogue_functor = FastLinearCombinationClamp(
C.element, C.alignment
)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -99,7 +100,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 128],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -115,15 +116,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
alignment=16
)
element_epilogue = cutlass.float32
epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
epilogue_functor = FastLinearCombinationClamp(
C.element, C.alignment
)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -140,7 +141,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 128],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -158,13 +159,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
element_epilogue = cutlass.int32
epilogue_functor = EpilogueFunctor.LinearCombinationClamp
epilogue_functor = LinearCombinationClamp(
C.element, C.alignment, math_inst.element_accumulator,
element_epilogue
)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@ -181,7 +185,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
tile_description = TileDescription(
threadblock_shape=[128, 128, 128],
stages=3, warp_count=[2, 2, 1],
math_instruction=math_inst, min_compute=80, max_compute=80
math_instruction=math_inst
)
A = TensorDescription(
@ -199,13 +203,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
element_epilogue = cutlass.int32
epilogue_functor = EpilogueFunctor.LinearCombinationClamp
epilogue_functor = LinearCombinationClamp(
C.element, C.alignment, math_inst.element_accumulator,
element_epilogue
)
swizzling_functor = cutlass.IdentitySwizzle1
operation = GemmOperationUniversal(
arch=80, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
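The int8 GEMM tests in this file swap the EpilogueFunctor enum values for concrete clamp functors: FastLinearCombinationClamp is built from just the output element and alignment, while LinearCombinationClamp (now imported from pycutlass.epilogue) additionally takes the accumulator and epilogue element types. A hedged sketch of the two constructions as they appear in the added lines:

    from pycutlass.epilogue import LinearCombinationClamp

    # float32 epilogue path: the fast clamp needs only the output element
    # type and its alignment.
    epilogue_functor = FastLinearCombinationClamp(
        C.element, C.alignment
    )

    # int32 epilogue path: the full clamp also takes the accumulator type
    # and the epilogue element type.
    epilogue_functor = LinearCombinationClamp(
        C.element, C.alignment, math_inst.element_accumulator,
        cutlass.int32
    )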

View File

@ -348,3 +348,16 @@ conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1
conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 164942943 4259285988 984016853 888753301
conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2823094147 1681845497 4242738907 3244428635
conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 s8nhwc_s8nhwc_inhwc_i_i 4060010502 2881035321 3927119619 3311661122
conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 1030377090 3211227145
conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1479940693 2379046159 2482639965
conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 1871463331 2718290800 1797658305
conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 664160900 3954982568 985899371
conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1924855848 1728786974 3821277575
conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 1764715518 3998637379 2782670608
conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 666906244 2107859856 831363691
conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1575210381 2486552517 3268706408
conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 2316839359 1729888024 2308314800
conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 544154978
conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 3191247524
conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 554790212 956712535 1281779197
conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3184127693 835105643 4011933753 3207244654

View File

@ -42,8 +42,7 @@ import unittest
#
def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False,
epilogue_functor = EpilogueFunctor.LinearCombination,
swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
"""
Test GEMM Operation based on configuration
"""
@ -68,7 +67,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
tile_description = TileDescription(
tiling[0], tiling[1], tiling[2],
math_inst, arch, arch
math_inst
)
A = TensorDescription(
@ -84,11 +83,15 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
)
element_epilogue = data_type[3]
if epilogue_functor is None:
epilogue_functor = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, element_epilogue)
if gemm_kind == GemmKind.Universal:
operation = GemmOperationUniversal(
arch=arch, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
if A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]:
@ -99,7 +102,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
elif gemm_kind == GemmKind.Grouped:
operation = GemmOperationGrouped(
arch, tile_description, A, B, C,
element_epilogue, epilogue_functor, swizzling_functor,
epilogue_functor, swizzling_functor,
precompute_mode=kwargs["precompute_mode"]
)
testbed = TestbedGrouped(operation=operation)
@ -110,7 +113,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
def TestConv2dOperator(math_inst, alignment, tiling, arch,
stride_supports=[StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided],
epilogue_functor=EpilogueFunctor.LinearCombination,
epilogue_functor=None,
swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs):
"""
Test Conv2d Operation based on configurations
@ -167,20 +170,24 @@ def TestConv2dOperator(math_inst, alignment, tiling, arch,
tile_description = TileDescription(
threadblock_shape=tiling[0], stages=tiling[1],
warp_count=tiling[2],
math_instruction=math_inst,
min_compute=arch, max_compute=arch
math_instruction=math_inst
)
if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = cutlass.StridedDgradIdentitySwizzle1
else:
swizzling_functor = default_swizzling_functor
if epilogue_functor is None:
epilogue_functor_ = LinearCombination(
C.element, C.alignment,
math_inst.element_accumulator, data_type[3])
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=arch, tile_description=tile_description, A=A, B=B, C=C,
element_epilogue=data_type[3], stride_support=stride_support,
epilogue_functor=epilogue_functor,
stride_support=stride_support,
epilogue_functor=epilogue_functor_,
swizzling_functor=swizzling_functor
)
@ -369,7 +376,11 @@ class Test_SM80(unittest.TestCase):
tiling = ([256, 64, 64], 4, [4, 1, 1])
data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32]
self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=EpilogueFunctor.FastLinearCombinationClamp))
epilogue_functor = FastLinearCombinationClamp(
data_type_mixed[2], alignment_mixed[2]
)
self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=epilogue_functor))
stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32]
results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True)
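With epilogue_functor defaulting to None in the helper signatures above, most callers let the helper build a LinearCombination from C.element, C.alignment, and the accumulator type; callers that need something else, such as the mixed int8 case in this hunk, now construct the functor themselves. A sketch of the explicit call style, reusing the names from this file (the other arguments are assumed to be set up as in the surrounding SM80 cases):

    # Most SM80 cases simply omit epilogue_functor; the helper then creates a
    # LinearCombination internally. The mixed int8 case instead passes an
    # explicit fast-clamp functor built from the output type and alignment.
    epilogue_functor = FastLinearCombinationClamp(
        data_type_mixed[2], alignment_mixed[2]
    )
    TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed,
                     tiling, 80, False, data_type=data_type_mixed,
                     epilogue_functor=epilogue_functor)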
@ -378,59 +389,59 @@ class Test_SM80(unittest.TestCase):
def SM80_SparseTensorOp_16832(self):
pass
def test_SM80_PlanarComplexTensorOp_16816(self):
def SM80_PlanarComplexTensorOp_16816(self):
pass
def test_SM80_SparseTensorOp_16816_fast_math(self):
def SM80_SparseTensorOp_16816_fast_math(self):
pass
def test_SM80_TensorOp_1688_complex(self):
def SM80_TensorOp_1688_complex(self):
pass
def test_SM80_TensorOp_1688_fast_fp32_math_complex(self):
def SM80_TensorOp_1688_fast_fp32_math_complex(self):
pass
def test_SM80_TensorOp_1688_rank_k(self):
def SM80_TensorOp_1688_rank_k(self):
pass
def test_SM80_TensorOp_1688_rank_k_complex(self):
def SM80_TensorOp_1688_rank_k_complex(self):
pass
def test_SM80_TensorOp_1688_trmm(self):
def SM80_TensorOp_1688_trmm(self):
pass
def test_SM80_TensorOp_1688_trmm_complex(self):
def SM80_TensorOp_1688_trmm_complex(self):
pass
def test_SM80_TensorOp_1688_symm(self):
def SM80_TensorOp_1688_symm(self):
pass
def test_SM80_TensorOp_1688_symm_complex(self):
def SM80_TensorOp_1688_symm_complex(self):
pass
def test_SM80_TensorOp_884_complex(self):
def SM80_TensorOp_884_complex(self):
pass
def test_SM80_TensorOp_884_complex_gaussian(self):
def SM80_TensorOp_884_complex_gaussian(self):
pass
def test_SM80_TensorOp_884_rank_k(self):
def SM80_TensorOp_884_rank_k(self):
pass
def test_SM80_TensorOp_884_rank_k_complex(self):
def SM80_TensorOp_884_rank_k_complex(self):
pass
def test_SM80_TensorOp_884_rank_k_complex_gaussian(self):
def SM80_TensorOp_884_rank_k_complex_gaussian(self):
pass
def test_SM80_TensorOp_884_trmm(self):
def SM80_TensorOp_884_trmm(self):
pass
def test_SM80_TensorOp_884_trmm_complex(self):
def SM80_TensorOp_884_trmm_complex(self):
pass
def test_SM80_TensorOp_884_trmm_complex_gaussian(self):
def SM80_TensorOp_884_trmm_complex_gaussian(self):
pass
def test_SM80_TensorOp_884_symm(self):
def SM80_TensorOp_884_symm(self):
pass
def test_SM80_TensorOp_884_symm_complex(self):
def SM80_TensorOp_884_symm_complex(self):
pass
def test_SM80_TensorOp_884_symm_complex_gaussian(self):
def SM80_TensorOp_884_symm_complex_gaussian(self):
pass
def test_SM80_SparseTensorOp_16864_TN(self):
def SM80_SparseTensorOp_16864_TN(self):
pass
def test_SM80_TensorOp_16864_TN(self):
def SM80_TensorOp_16864_TN(self):
pass
def test_SM80_SparseTensorOp_168128_TN(self):
def SM80_SparseTensorOp_168128_TN(self):
pass
def test_SM80_TensorOp_16864_Interleaved(self):
def SM80_TensorOp_16864_Interleaved(self):
pass
def test_SM80_TensorOp_168256(self):
def SM80_TensorOp_168256(self):
pass
def test_SM80_Simt_complex(self):
def SM80_Simt_complex(self):
pass

View File

@ -1195,13 +1195,13 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
break;
case NumericTypeID::kBF16:
{
float tmp = *reinterpret_cast<bfloat16_t *>(bytes.data());;
float tmp = *reinterpret_cast<bfloat16_t *>(bytes.data());
ss << tmp;
}
break;
case NumericTypeID::kTF32:
{
float tmp = *reinterpret_cast<tfloat32_t *>(bytes.data());;
float tmp = *reinterpret_cast<tfloat32_t *>(bytes.data());
ss << tmp;
}
break;

View File

@ -183,7 +183,7 @@ __global__ void GemmPlanarComplex(
ComplexC d_ij;
d_ij.real() = convert_op(result.real());
d_ij.imag() = convert_op(result.imag());;
d_ij.imag() = convert_op(result.imag());
tensor_d.at(coord) = d_ij;
}

View File

@ -172,7 +172,7 @@ void GemmPlanarComplex(
complex<ScalarType> result = alpha * acc + beta * src;
d_ij.real() = convert_op(result.real());
d_ij.imag() = convert_op(result.imag());;
d_ij.imag() = convert_op(result.imag());
tensor_d.at(coord) = d_ij;
}