CUTLASS 2.10 updates (#622)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>

CHANGELOG.md

@@ -1,11 +1,19 @@
# NVIDIA CUTLASS Changelog

 ## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
-* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
-* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
-* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
-* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention)
-* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
+* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, and Group GEMM for different data types as well as different epilogue flavours.
+* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel: the threadblock scheduling is improved, and some computation can be moved to the host side where applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added as well.
+* Optimizations for [GEMM+Softmax](examples/35_gemm_softmax): all of the reduction computation is fused into the preceding GEMM, and more template arguments are provided to fine-tune performance.
+* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention). This grouped-GEMM-based MHA does not require all GEMMs to share the same sequence length, which makes it especially useful for natural language processing.
+* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts that are fused into the GEMMs before and after it, respectively. In addition to using the square sum to compute the layernorm variance, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided in case the square sum raises numerical issues.
+* [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can apply a user-provided permutation layout mapping in the GEMM epilogue.
+* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one. The restrictions are: 1) the input and output channel counts must be multiples of the group count, and 2) split-K is not supported. The implementation has two modes:
+  * kSingleGroup: the output channel count per group is a multiple of the threadblock tile N dimension.
+  * kMultipleGroup: the threadblock tile N dimension is a multiple of the output channel count per group.
+* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now. The restrictions are: 1) SIMT only, 2) no split-K, and 3) the input channel count, output channel count, and group count are all equal.
+* Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
+* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM's K dimension be a multiple of the threadblock tile K dimension.
+* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
+* Updates and bugfixes from the community (thanks!)

 * **Deprecation announcement:** CUTLASS plans to deprecate the following:
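
The two grouped-convolution modes come down to a divisibility relationship between the per-group output channel count and the threadblock tile N dimension. A minimal sketch of that rule (the helper `grouped_conv_mode` is hypothetical, not a CUTLASS API):

```python
# Hypothetical helper restating the kSingleGroup/kMultipleGroup rule above.
def grouped_conv_mode(k_per_group: int, tile_n: int) -> str:
    if k_per_group % tile_n == 0:
        return "kSingleGroup"    # output channels per group are a multiple of tile N
    if tile_n % k_per_group == 0:
        return "kMultipleGroup"  # tile N is a multiple of output channels per group
    raise ValueError("unsupported tile/group combination")

assert grouped_conv_mode(k_per_group=128, tile_n=64) == "kSingleGroup"
assert grouped_conv_mode(k_per_group=32, tile_n=64) == "kMultipleGroup"
```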

@@ -47,7 +55,7 @@
 * New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
 * [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix, which helps launch the intended number of threadblocks to fully occupy the GPUs.
 * [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
 * Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads)
 * Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
 * Updates and bugfixes from the community (thanks!)

README.md

@@ -39,11 +39,16 @@ supported at each level of the execution model hierarchy.
 # What's New in CUTLASS 2.10

 CUTLASS 2.10 is an update to CUTLASS adding:
-- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
-- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
-- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
-- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention)
-- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
+- [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, and Group GEMM for different data types as well as different epilogue flavours.
+- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel, which can move some scheduling to the host side where applicable.
+- Optimizations for [GEMM+Softmax](examples/35_gemm_softmax).
+- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) is a general MHA that does not require equal sequence lengths in every GEMM.
+- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) can fuse the layernorm into the GEMMs before and after it.
+- [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can permute the GEMM output before storing it.
+- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one.
+- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now.
+- Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
+- [Back-to-back GEMM](examples/13_two_tensor_op_fusion) enhancements.
 - Updates and bugfixes from the community (thanks!)
 - **Deprecation announcement:** CUTLASS plans to deprecate the following:
   - Maxwell and Pascal GPU architectures
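
Most of this commit orbits that CUTLASS Python bullet. A minimal, hedged sketch of the end-to-end flow, with names and signatures taken from the `gemm.py` diff later on this page (tile and problem sizes are arbitrary, and the `MathInstruction` argument order is assumed from the pycutlass examples, not verified):

```python
import numpy as np
import cutlass
import pycutlass
from pycutlass import *

# Minimal pycutlass GEMM sketch assembled from the example scripts in this
# diff; assumes the pycutlass package shipped with CUTLASS 2.10.
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

math_inst = MathInstruction(
    [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float32,
    cutlass.OpClass.TensorOp, MathOperation.multiply_add)
tile_description = TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst)
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, 8)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, 8)
C = TensorDescription(cutlass.float32, cutlass.RowMajor, 4)
epilogue_functor = LinearCombination(
    C.element, C.alignment, math_inst.element_accumulator, cutlass.float32)

operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description, A=A, B=B, C=C,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1)
pycutlass.compiler.add_module([operation])

problem_size = cutlass.gemm.GemmCoord(512, 256, 128)
tensor_A = np.random.uniform(-2, 2, size=(512 * 128,)).astype(np.float16)
tensor_B = np.random.uniform(-2, 2, size=(128 * 256,)).astype(np.float16)
tensor_C = np.random.uniform(-2, 2, size=(512 * 256,)).astype(np.float32)
tensor_D = np.zeros_like(tensor_C)

arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=operation.epilogue_type(1.0, 0.0),
    gemm_mode=cutlass.gemm.Mode.Gemm, split_k_slices=1)
operation.run(arguments)
arguments.sync()
```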
@@ -1528,6 +1528,9 @@ int main(int argc, char const **args) {
     cutlass::epilogue::thread::LinearCombination<
         ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
         ElementAccumulator, ElementAccumulator>,
+    // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
+    // This parameter is passed in at present to match the APIs of other kernels. The parameter
+    // is unused within the kernel.
     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
     4>::GemmKernel;
@@ -187,8 +187,8 @@ struct Options {
 // elements in input matrices.
 using ElementAccumulator = float;                  // <- data type of accumulator
 using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
-using ElementInputA = cutlass::half_t;;            // <- data type of elements in input matrix A
-using ElementInputB = cutlass::half_t;;            // <- data type of elements in input matrix B
+using ElementInputA = cutlass::half_t;             // <- data type of elements in input matrix A
+using ElementInputB = cutlass::half_t;             // <- data type of elements in input matrix B
 using ElementOutput = float;                       // <- data type of elements in output matrix D

 // The code section below describes matrix layout of input and output matrices.
@@ -252,7 +252,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
     SwizzleThreadBlock,
     NumStages,
     8, /*alignmentA*/
-    8, /*alignmengB*/
+    8, /*alignmentB*/
     cutlass::arch::OpMultiplyAdd,
     cutlass::ComplexTransform::kNone,
     cutlass::ComplexTransform::kNone,
@@ -1366,7 +1366,12 @@ int main(int argc, char const **args) {
   using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
       ElementOutput, 1,
       ElementAccumulator, ElementAccumulator>;

+  // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
+  // This parameter is passed in at present to match the APIs of other kernels. The parameter
+  // is unused within the kernel.
   using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

   const int kStages = 4;
   const bool kSplitKSerial = false;
   using Operator = cutlass::arch::OpMultiplyAdd;
@@ -92,25 +92,35 @@ Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```

### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
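Batched GEMM runs `-batch` copies of one problem size with fixed strides between the batch entries, while Array GEMM gives each entry its own allocation (an array of pointers). A hedged numpy analogy of the two semantics, independent of pycutlass:

```python
import numpy as np

# Batched GEMM: one problem size, fixed strides between batch entries.
A = np.random.rand(3, 512, 128).astype(np.float32)   # batch of 3
B = np.random.rand(3, 128, 256).astype(np.float32)
D_batched = A @ B                                    # 3 GEMMs of identical shape

# Array GEMM: same shapes here, but each entry comes from an independent
# allocation rather than one strided tensor.
A_list = [np.random.rand(512, 128).astype(np.float32) for _ in range(2)]
B_list = [np.random.rand(128, 256).astype(np.float32) for _ in range(2)]
D_array = [a @ b for a, b in zip(A_list, B_list)]
```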
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.

Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
-python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
+python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
-python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
+python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
-python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
+python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
-python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
+python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
The `-sw` flag is dropped from these grouped commands because threadblock swizzling is currently unsupported by CUTLASS's grouped kernels, as the updated `gemm_grouped.py` help text later in this diff notes.
***
## Conv2d Example

@@ -160,3 +170,61 @@ Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nh
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
## Epilogue

### Bias
To replace C with a bias vector, add the `-bias` flag.

### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: Leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
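What these flags compute, as a hedged numpy model (the real work happens inside the fused kernel; the bias broadcast direction follows the layout-dependent sizing of `tensor_C` in the scripts below):

```python
import numpy as np

def epilogue_reference(accum, bias, alpha, beta, activation=lambda x: x):
    # accum: (M, N) GEMM accumulator; bias: (N,) vector standing in for C
    # when -bias is given (broadcast across rows for a RowMajor output).
    return activation(alpha * accum + beta * bias[np.newaxis, :])

accum = np.random.rand(512, 256).astype(np.float32)
bias = np.random.rand(256).astype(np.float32)
D = epilogue_reference(accum, bias, 1.0, 0.5, lambda x: np.maximum(x, 0))  # relu
```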
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```
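A hedged numpy model of two of the `-epv` visitor patterns (it mirrors, in simplified form, the `EpilogueVisitTree` definitions in the `gemm.py` diff further down; the script's `RowReduction_` additionally applies `tanh` to `beta * c`):

```python
import numpy as np

def row_reduction(accum, c, alpha, beta, cta_n):
    D = alpha * accum + beta * c
    # one partial row-sum per threadblock-wide column slice of width cta_n
    num_cta_n = (D.shape[1] + cta_n - 1) // cta_n
    parts = [D[:, i * cta_n:(i + 1) * cta_n].sum(axis=1) for i in range(num_cta_n)]
    return D, np.stack(parts, axis=1)

def row_broadcast(accum, c, vector, alpha, beta):
    T = accum + vector                         # broadcast a row vector to every row
    Z = np.maximum(alpha * T + beta * c, 0)    # relu, as in RowBroadcast_ below
    return Z, T

accum = np.random.rand(8, 8); c = np.zeros((8, 8))
D, partial_sums = row_reduction(accum, c, 1.0, 0.0, cta_n=4)
Z, T = row_broadcast(accum, c, np.ones((1, 8)), 1.0, 0.0)
```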
@@ -33,6 +33,7 @@ import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
+import torch.nn.functional as F

import argparse
@@ -127,6 +128,13 @@ parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stri
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
+parser.add_argument('-bias', '--bias', action='store_true', help="C is a bias vector")
+
+# Activation function
+parser.add_argument("-activ", "--activation_function", default="identity",
+    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
+parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
+    help="additional arguments for the activation function")

parser.add_argument('--print_cuda', action="store_true",
    help="print the underlying CUDA kernel")
@@ -138,6 +146,8 @@ except:

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

+np.random.seed(0)
+
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
@@ -152,7 +162,7 @@ math_inst = MathInstruction(

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
-    math_inst, args.compute_capability, args.compute_capability
+    math_inst
)

layout_a = getattr(cutlass, args.layout_a)
@@ -172,7 +182,16 @@ C = TensorDescription(
)

element_epilogue = getattr(cutlass, args.element_epilogue)
-epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
+if (args.activation_function == "identity"
+        or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
+    #
+    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
+        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
+else:
+    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
+        getattr(pycutlass, args.activation_function)(element_epilogue),
+        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
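The branch above is why parallel split-K keeps a plain epilogue in the main kernel: a nonlinear activation cannot be applied to each split-K partial sum, so it is deferred to the reduction operation built just below. A one-line numpy illustration:

```python
import numpy as np

relu = lambda x: np.maximum(x, 0)
partials = np.array([3.0, -5.0])   # two split-K partial accumulations
print(relu(partials).sum())        # 3.0 -- activation before reduction: wrong
print(relu(partials.sum()))        # 0.0 -- activation after reduction: right
```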
@@ -181,7 +200,7 @@ conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
operation = Conv2dOperation(
    conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
    arch=args.compute_capability, tile_description=tile_description,
-    A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support,
+    A=A, B=B, C=C, stride_support=stride_support,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
@@ -191,10 +210,18 @@ if args.print_cuda:

operations = [operation,]

if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
+    if (args.activation_function == "identity"):
+        epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
+            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
+    else:
+        epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
+            getattr(pycutlass, args.activation_function)(element_epilogue),
+            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        epilogue_functor=epilogue_functor_reduction,
        count=C.alignment
    )
    operations.append(reduction_operation)
@@ -219,9 +246,18 @@ tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
    conv_kind, problem_size
)
-tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
-    conv_kind, problem_size
-)
+if args.bias:
+    tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
+        conv_kind, problem_size
+    ).at(3)
+else:
+    tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
+        conv_kind, problem_size
+    )
+
+tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
+    conv_kind, problem_size
+)

if args.element_a != "int8":
    tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
@@ -238,12 +274,12 @@ if args.element_c != "int8":
else:
    tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)

-tensor_D = torch.ones_like(tensor_C)
+tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")

arguments = Conv2dArguments(
    operation=operation, problem_size=problem_size, A=tensor_A,
    B=tensor_B, C=tensor_C, D=tensor_D,
-    output_op = LinearCombinationFunctorArguments(args.alpha, args.beta),
+    output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
    split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
    split_k_slices=problem_size.split_k_slices
)
@@ -257,7 +293,8 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
        workspace=arguments.ptr_D,
        destination=tensor_D,
        source=tensor_C,
-        output_op = LinearCombinationFunctorArguments(args.alpha, args.beta)
+        output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
+        bias = arguments.bias
    )

operation.run(arguments)
@@ -270,8 +307,12 @@ else:

reference_model = Conv2dReferenceModule(A, B, C, conv_kind)

-tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta)
-
-assert torch.equal(tensor_D, tensor_D_ref)
+tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
+if (args.activation_function != "identity"):
+    tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))
+
+try:
+    assert torch.equal(tensor_D, tensor_D_ref)
+except:
+    assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
print("Passed.")
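The try/except pattern exists because the torch reference for an activation need not be bitwise identical to the device epilogue, so exact equality can fail even when both results are correct to within rounding. A self-contained illustration of the same idea:

```python
import torch

x = torch.tensor(0.1, dtype=torch.float64) + torch.tensor(0.2, dtype=torch.float64)
y = torch.tensor(0.3, dtype=torch.float64)
print(torch.equal(x, y))     # False: differs in the last bit
print(torch.allclose(x, y))  # True: agrees within tolerance
```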
@@ -99,9 +99,11 @@ parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
    help="This option describes the epilogue part of the kernel")
+parser.add_argument("-epv", "--epilogue_visitor", default=None,
+    type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
-    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
+    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
    help="This option describes how thread blocks are scheduled on GPU")
# Argument

@@ -113,17 +115,22 @@ parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
parser.add_argument("-beta", "--beta", default=0.0, type=float,
    help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
-    choices=["Gemm", "GemmSplitKParallel"],
+    choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
    help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
        GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
    type=int, help="Number of split-k partitions. (default 1)")
+parser.add_argument('-bias', '--bias', action='store_true', help="C is a bias vector")
+parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")
+
+# Activation function
+parser.add_argument("-activ", "--activation_function", default="identity",
+    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
+parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
+    help="additional arguments for the activation function")
parser.add_argument('--print_cuda', action="store_true",
    help="print the underlying CUDA kernel")

# parser.add_argument('-h', '--help', action="store_true",
#     help="print help information")
try:
    args = parser.parse_args()

@@ -131,6 +138,9 @@ except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
+pycutlass.compiler.nvcc()
+
+np.random.seed(0)

element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
@@ -146,7 +156,7 @@ math_inst = MathInstruction(

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
-    math_inst, args.compute_capability, args.compute_capability
+    math_inst
)

layout_a = getattr(cutlass, args.layout_a)
@@ -166,13 +176,83 @@ C = TensorDescription(
)

element_epilogue = getattr(cutlass, args.element_epilogue)
-epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
+if (args.activation_function == "identity"
+        or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
+    #
+    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
+        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
+else:
+    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
+        getattr(pycutlass, args.activation_function)(element_epilogue),
+        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

swizzling_functor = getattr(cutlass, args.swizzling_functor)

+visitor = args.epilogue_visitor is not None
+
+if args.epilogue_visitor == "ColumnReduction":
+    class ColumnReduction_(EpilogueVisitTree):
+        def __call__(
+                self, accum: 'tensor', c: 'tensor',
+                alpha: 'scalar', beta: 'scalar'):
+            #
+            D = alpha * accum + beta * c
+            reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
+            return D, reduction
+    epilogue_functor = ColumnReduction_(
+        epilogue_functor, tile_description, math_inst.element_accumulator,
+        C.alignment, element_epilogue, C.element)
+    epilogue_functor.initialize()
+elif args.epilogue_visitor == "RowReduction":
+    class RowReduction_(EpilogueVisitTree):
+        def __call__(
+                self, accum: 'tensor', c: 'tensor',
+                alpha: 'scalar', beta: 'scalar'):
+            #
+            D = alpha * accum + tanh.numpy(beta * c)
+            reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
+            return D, reduction
+    epilogue_functor = RowReduction_(
+        epilogue_functor, tile_description, math_inst.element_accumulator,
+        C.alignment, element_epilogue, C.element)
+    epilogue_functor.initialize()
+elif args.epilogue_visitor == "RowBroadcast":
+    class RowBroadcast_(EpilogueVisitTree):
+        def __call__(
+                self, accum: 'tensor', c: 'tensor',
+                vector: 'row', alpha: 'scalar', beta: 'scalar'):
+            #
+            T = accum + vector
+            scale_T = alpha * T
+            Z = relu.numpy(scale_T + beta * c)
+            return Z, T
+    epilogue_functor = RowBroadcast_(
+        epilogue_functor, tile_description, math_inst.element_accumulator,
+        C.alignment, element_epilogue, C.element)
+    epilogue_functor.initialize()
+elif args.epilogue_visitor == "ColumnBroadcast":
+    class ColumnBroadcast_(EpilogueVisitTree):
+        def __call__(
+                self, accum: 'tensor', c: 'tensor',
+                vector: 'column', alpha: 'scalar', beta: 'scalar'):
+            #
+            T = accum + vector
+            scale_T = leaky_relu.numpy(alpha * T, 0.2)
+            Z = scale_T + beta * c
+            return Z, T
+    epilogue_functor = ColumnBroadcast_(
+        epilogue_functor, tile_description, math_inst.element_accumulator,
+        C.alignment, element_epilogue, C.element)
+    epilogue_functor.initialize()
+else:
+    epilogue_functor = epilogue_functor

operation = GemmOperationUniversal(
    arch=args.compute_capability, tile_description=tile_description,
-    A=A, B=B, C=C, element_epilogue=element_epilogue,
-    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
+    A=A, B=B, C=C,
+    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
+    visitor=visitor
)

if args.print_cuda:
@@ -181,10 +261,19 @@ if args.print_cuda:

operations = [operation, ]

if args.gemm_mode == "GemmSplitKParallel":
+    if (args.activation_function == "identity"):
+        epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
+            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
+    else:
+        epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
+            getattr(pycutlass, args.activation_function)(element_epilogue),
+            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
+
    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        epilogue_functor=epilogue_functor_reduction,
        count=C.alignment
    )
    operations.append(reduction_operation)
@@ -196,47 +285,102 @@ pycutlass.compiler.add_module(operations)
problem_size = cutlass.gemm.GemmCoord(
    args.problem_size[0], args.problem_size[1], args.problem_size[2])

+tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
    if args.element_a == "bfloat16":
-        tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
-            * problem_size.k(),))).astype(bfloat16)
+        tensor_A = np.ceil(
+            np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
+        ).astype(bfloat16)
    else:
-        tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
-            * problem_size.k(),))).astype(getattr(np, args.element_a))
+        tensor_A = np.ceil(
+            np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
+        ).astype(getattr(np, args.element_a))
else:
-    tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
-        * problem_size.k(),)).astype(getattr(np, args.element_a))
+    tensor_A = np.random.uniform(
+        low=-2, high=2, size=(tensor_a_size,)
+    ).astype(getattr(np, args.element_a))

+tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
    if args.element_b == "bfloat16":
-        tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
-            * problem_size.n(),))).astype(bfloat16)
+        tensor_B = np.ceil(
+            np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
+        ).astype(bfloat16)
    else:
-        tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
-            * problem_size.n(),))).astype(getattr(np, args.element_b))
+        tensor_B = np.ceil(
+            np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
+        ).astype(getattr(np, args.element_b))
else:
-    tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
-        * problem_size.n(),)).astype(getattr(np, args.element_b))
+    tensor_B = np.random.uniform(
+        low=-2, high=2, size=(tensor_b_size,)
+    ).astype(getattr(np, args.element_b))

if args.element_c != "int8":
+    if args.bias:
+        if args.layout_c == "RowMajor":
+            tensor_c_size = args.batch * problem_size.n()
+        elif args.layout_c == "ColumnMajor":
+            tensor_c_size = args.batch * problem_size.m()
+        else:
+            raise ValueError(args.layout_c)
+    else:
+        tensor_c_size = args.batch * problem_size.m() * problem_size.n()
    if args.element_c == "bfloat16":
-        tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
-            * problem_size.n(),))).astype(bfloat16)
+        tensor_C = np.ceil(
+            np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
+        ).astype(bfloat16)
    else:
-        tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
-            * problem_size.n(),))).astype(getattr(np, args.element_c))
+        tensor_C = np.ceil(
+            np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
+        ).astype(getattr(np, args.element_c))
else:
-    tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
-        * problem_size.n(),)).astype(getattr(np, args.element_c))
+    tensor_C = np.random.uniform(
+        low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
+    ).astype(getattr(np, args.element_c))

-tensor_D = np.ones_like(tensor_C)
+tensor_D = np.zeros(
+    shape=(args.batch * problem_size.m() * problem_size.n(),)
+).astype(getattr(np, args.element_c))

+if args.epilogue_visitor == "RowReduction":
+    cta_n = args.threadblock_shape[1]
+    num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
+    reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
+    output_op = operation.epilogue_type(
+        D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
+    )
+elif args.epilogue_visitor == "ColumnReduction":
+    cta_m = args.threadblock_shape[0]
+    num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
+    reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
+    output_op = operation.epilogue_type(
+        D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
+    )
+elif args.epilogue_visitor == "RowBroadcast":
+    vector = np.ceil(
+        np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
+    ).astype(getattr(np, args.element_c))
+    tensor_t = np.empty_like(tensor_D)
+    output_op = operation.epilogue_type(
+        c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
+    )
+elif args.epilogue_visitor == "ColumnBroadcast":
+    vector = np.ceil(
+        np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
+    ).astype(getattr(np, args.element_c))
+    tensor_t = np.empty_like(tensor_D)
+    output_op = operation.epilogue_type(
+        c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
+    )
+else:
+    output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))

arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
-    output_op=LinearCombinationFunctorArguments(args.alpha, args.beta),
+    output_op=output_op,
    gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
-    split_k_slices=args.split_k_slices
+    split_k_slices=args.split_k_slices, batch=args.batch
)

if args.gemm_mode == "GemmSplitKParallel":
@@ -245,7 +389,8 @@ if args.gemm_mode == "GemmSplitKParallel":
        problem_size=[problem_size.m(), problem_size.n()],
        partitions=args.split_k_slices, workspace=arguments.ptr_D,
        destination=tensor_D, source=tensor_C,
-        output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
+        output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
+        bias = arguments.bias
    )

operation.run(arguments)
@@ -259,8 +404,42 @@ else:

# run the host reference module
reference = ReferenceModule(A, B, C)
-tensor_D_ref = reference.run(
-    tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)
-
-assert np.array_equal(tensor_D, tensor_D_ref)
+tensor_D_ref = reference.run(
+    tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
+if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
+    tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
+tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
+
+if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
+    output_op.sync()
+    accum_ref = reference.run(
+        tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
+    tensor_D_ref, reduction_ref = epilogue_functor(
+        accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
+        tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
+        args.alpha, args.beta
+    )
+    tensor_D_ref = tensor_D_ref.flatten()
+    reduction_ref = reduction_ref.flatten()
+    assert np.allclose(reduction_ref, reduction, atol=1e-2)
+
+elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
+    output_op.sync()
+    accum_ref = reference.run(
+        tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
+
+    tensor_D_ref, tensor_T_ref = epilogue_functor(
+        accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
+        tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
+        vector, args.alpha, args.beta)
+
+    tensor_D_ref = tensor_D_ref.flatten()
+    tensor_T_ref = tensor_T_ref.flatten()
+
+    assert np.array_equal(tensor_t, tensor_T_ref)
+
+try:
+    assert np.array_equal(tensor_D, tensor_D_ref)
+except:
+    assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)
print("Passed.")
@@ -99,7 +99,10 @@ parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
-    help="This option describes how thread blocks are scheduled on GPU")
+    help="This option describes how thread blocks are scheduled on GPU. \
+        NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
+        This parameter is passed in at present to match the APIs of other kernels. The parameter \
+        is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
    default="Device", type=str, choices=["Host", "Device"],
@@ -109,7 +112,13 @@ parser.add_argument("-p", "--problem_size_dir", type=str,
    help="path to the csv file that contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
+parser.add_argument('-bias', '--bias', action='store_true', help="C is a bias vector")
+
+# Activation function
+parser.add_argument("-activ", "--activation_function", default="identity",
+    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
+parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
+    help="additional arguments for the activation function")
parser.add_argument('--print_cuda', action="store_true",
    help="print the underlying CUDA kernel")
@@ -120,6 +129,8 @@ except:

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

+np.random.seed(0)
+
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
@@ -134,7 +145,7 @@ math_inst = MathInstruction(

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
-    math_inst, args.compute_capability, args.compute_capability
+    math_inst
)

layout_a = getattr(cutlass, args.layout_a)
@@ -154,13 +165,19 @@ C = TensorDescription(
)

element_epilogue = getattr(cutlass, args.element_epilogue)
-epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
+if args.activation_function == "identity":
+    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
+        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
+else:
+    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
+        getattr(pycutlass, args.activation_function)(element_epilogue),
+        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)

operation = GemmOperationGrouped(
    arch=args.compute_capability, tile_description=tile_description,
-    A=A, B=B, C=C, element_epilogue=element_epilogue,
+    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
    precompute_mode=precompute_mode
)
@@ -214,28 +231,45 @@ for problem_size in problem_sizes:
        * problem_size.n(),)).astype(getattr(np, args.element_b))

    if args.element_c != "int8":
+        if args.bias:
+            if args.layout_c == "RowMajor":
+                c_size = problem_size.n()
+            elif args.layout_c == "ColumnMajor":
+                c_size = problem_size.m()
+            else:
+                raise ValueError(args.layout_c)
+        else:
+            c_size = problem_size.m() * problem_size.n()
        if args.element_c == "bfloat16":
-            tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
-                * problem_size.n(),))).astype(bfloat16)
+            tensor_C = np.ceil(
+                np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
+            ).astype(bfloat16)
        else:
-            tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
-                * problem_size.n(),))).astype(getattr(np, args.element_c))
+            tensor_C = np.ceil(
+                np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
+            ).astype(getattr(np, args.element_c))
    else:
-        tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
-            * problem_size.n(),)).astype(getattr(np, args.element_c))
-    tensor_D = np.zeros_like(tensor_C)
+        tensor_C = np.random.uniform(
+            low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
+        ).astype(getattr(np, args.element_c))
+    tensor_D = np.zeros(
+        shape=(problem_size.m() * problem_size.n(),)
+    ).astype(getattr(np, args.element_c))

    tensor_As.append(tensor_A)
    tensor_Bs.append(tensor_B)
    tensor_Cs.append(tensor_C)
    tensor_Ds.append(tensor_D)
-    tensor_D_refs.append(reference_module.run(
-        tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta))
+    tensor_D_ref = reference_module.run(
+        tensor_A, tensor_B, tensor_C, problem_size,
+        args.alpha, args.beta, args.bias)
+    tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
+    tensor_D_refs.append(tensor_D_ref)
    problem_sizes_coord.append(problem_size)

arguments = GemmGroupedArguments(
    operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
-    output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
+    output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
)

operation.run(arguments)
@@ -243,6 +277,9 @@ operation.run(arguments)
arguments.sync()

for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
-    assert np.array_equal(tensor_d, tensor_d_ref)
+    try:
+        assert np.array_equal(tensor_d, tensor_d_ref)
+    except:
+        assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5)

print("Passed.")
@@ -69,7 +69,7 @@ struct global_load;
/////////////////////////////////////////////////////////////////////////////////////////////////

// The redundant mov PTX instruction is used to enforce the compiler to
-// initialize data to zero before ld.global
+// keep the initializing code before ld.global
template <typename AccessType>
struct global_load<AccessType,
@@ -59,6 +59,38 @@ struct Identity {
  }
};

+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct LinearCombinationGenericParams {
+  T alpha;                  ///< scales accumulators
+  T beta;                   ///< scales source tensor
+  T const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
+  T const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory
+
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  LinearCombinationGenericParams():
+    alpha(T(1)),
+    beta(T(0)),
+    alpha_ptr(nullptr),
+    beta_ptr(nullptr) { }
+
+  CUTLASS_HOST_DEVICE
+  LinearCombinationGenericParams(
+    T alpha,
+    T beta = T(0)
+  ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
+
+  CUTLASS_HOST_DEVICE
+  LinearCombinationGenericParams(
+    T const *alpha_ptr,
+    T const *beta_ptr = nullptr
+  ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
+};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// ReLu operator - propagates NaNs
@@ -79,6 +111,14 @@ struct ReLu {

    return mx(value, T(0));
  }

+  /// Host-constructable parameters structure
+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T value, Params const &params_) const {
+    return this->operator()(value);
+  }
};

template <typename T, int N>
@@ -96,20 +136,74 @@ struct ReLu<Array<T, N>> {
    maximum<Array<T, N> > mx;
    return mx(frag, T(0));
  }

+  /// Host-constructable parameters structure
+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &frag, Params const &params_) const {
+    return this->operator()(frag);
+  }
};

// Leaky Relu operator
template <typename T>
struct LeakyReLU {
+  struct Params: LinearCombinationGenericParams<T> {
+    T leaky_alpha;          ///< leaky_alpha
+
+    // Methods
+    using LinearCombinationGenericParams<T>::LinearCombinationGenericParams;
+
+    CUTLASS_HOST_DEVICE
+    Params():
+      LinearCombinationGenericParams<T>(),
+      leaky_alpha(T(1)) {}
+
+    CUTLASS_HOST_DEVICE
+    Params(
+      T alpha,
+      T beta,
+      T leaky_alpha = T(1)
+    ): LinearCombinationGenericParams<T>(alpha, beta), leaky_alpha(leaky_alpha) {}
+  };
+
  CUTLASS_HOST_DEVICE
  T operator()(T const &value, T const & alpha_recip) const {
    T res = value > T(0) ? value : value * alpha_recip;
    return res;
  }

+  CUTLASS_HOST_DEVICE
+  T operator()(T const &value, Params const &params_) const {
+    return this->operator()(value, params_.leaky_alpha);
+  }
};

template <typename T, int N>
struct LeakyReLU<Array<T, N> > {
+  struct Params: LinearCombinationGenericParams<T> {
+    T leaky_alpha;          ///< leaky_alpha
+    using LinearCombinationGenericParams<T>::LinearCombinationGenericParams;
+
+    // Methods
+
+    CUTLASS_HOST_DEVICE
+    Params():
+      LinearCombinationGenericParams<T>(),
+      leaky_alpha(T(1)) {}
+
+    CUTLASS_HOST_DEVICE
+    Params(
+      T alpha,
+      T beta,
+      T leaky_alpha = T(1)
+    ): LinearCombinationGenericParams<T>(alpha, beta), leaky_alpha(leaky_alpha) {}
+  };
+
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, T const & alpha_recip) const {
    Array<T, N> y;
@@ -122,6 +216,11 @@ struct LeakyReLU<Array<T, N> > {

    return y;
  }

+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+    return this->operator()(rhs, params_.leaky_alpha);
+  }
};
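A quick numpy check of the semantics encoded by both LeakyReLU specializations above (`value` when positive, `value * leaky_alpha` otherwise, with `leaky_alpha` carried in `Params`):

```python
import numpy as np

def leaky_relu(x, leaky_alpha):
    return np.where(x > 0, x, x * leaky_alpha)

print(leaky_relu(np.array([-2.0, 3.0]), 0.2))  # [-0.4  3. ]
```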
// Tanh operator

@@ -131,6 +230,13 @@ struct Tanh {
  T operator()(T const &scalar) const {
    return fast_tanh(scalar);
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &scalar, Params const &params_) const {
+    return this->operator()(scalar);
+  }
};

template <typename T, int N>

@@ -147,6 +253,13 @@ struct Tanh<Array<T, N> > {

    return y;
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+    return this->operator()(rhs);
+  }
};
template <int N>

@@ -159,6 +272,13 @@ struct Tanh<Array<half_t, N>> {
    return tanh(z);
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+    return this->operator()(rhs);
+  }
};
// Sigmoid operator

@@ -168,6 +288,13 @@ struct Sigmoid {
  T operator()(T const &scalar) const {
    return T(1) / (T(1) + fast_exp(-scalar));
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &scalar, Params const &params_) const {
+    return this->operator()(scalar);
+  }
};
template <typename T, int N>

@@ -184,6 +311,13 @@ struct Sigmoid<Array<T, N> > {

    return y;
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+    return this->operator()(rhs);
+  }
};
template <int N>

@@ -208,6 +342,12 @@ struct Sigmoid<Array<half_t, N>> {
        fast_exp(neg(z))));
#endif
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  Array<T, N> operator()(Array<T, N> const &z, Params const &params_) const {
+    return this->operator()(z);
+  }
};
// SiLu (swish) operator introduced by Elfwing et al. in the following paper

@@ -222,6 +362,13 @@ struct SiLu {
    Sigmoid<T> sigmoid;
    return scalar * sigmoid(scalar);
  }

+  using Params = LinearCombinationGenericParams<T>;
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &scalar, Params const &params_) const {
+    return this->operator()(scalar);
+  }
};
template <typename T, int N>
|
||||
@ -232,6 +379,13 @@ struct SiLu<Array<T, N>> {
|
||||
multiplies<Array<T, N>> mul;
|
||||
return mul(rhs, sigmoid_op(rhs));
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &rhs, Params const ¶ms_) const {
|
||||
return this->operator()(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
// Hardswish operator introduced by Howard et al. in the following paper
@@ -248,6 +402,13 @@ struct HardSwish {
    T relu6 = mn(mx(x + T(3), T(0)), T(6));
    return x * relu6 / T(6);
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  T operator()(T const &x, Params const &params_) const {
    return this->operator()(x);
  }
};

template <>
@@ -261,6 +422,13 @@ struct HardSwish<float> {
    T relu6 = mn(mx(x + T(3), T(0)), T(6));
    return x * relu6 * 0.16666667f;
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  T operator()(T const &x, Params const &params_) const {
    return this->operator()(x);
  }
};

template <typename T, int N>
@@ -277,6 +445,13 @@ struct HardSwish<Array<T, N> > {

    return y;
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &x, Params const &params_) const {
    return this->operator()(x);
  }
};

template <int N>
@@ -292,6 +467,13 @@ struct HardSwish<Array<half_t, N> > {

    return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &x, Params const &params_) const {
    return this->operator()(x);
  }
};
//
@@ -311,6 +493,13 @@ struct GELU {
    return T(cutlass::constants::half<T>() * scalar *
      (cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  T operator()(T const &scalar, Params const &params_) const {
    return this->operator()(scalar);
  }
};

template <>
@@ -320,6 +509,13 @@ struct GELU<float> {
    return cutlass::constants::half<float>() * scalar *
      (cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
  }

  using Params = LinearCombinationGenericParams<float>;

  CUTLASS_HOST_DEVICE
  float operator()(float const &scalar, Params const &params_) const {
    return this->operator()(scalar);
  }
};

template <>
@@ -329,6 +525,13 @@ struct GELU<double> {
    return cutlass::constants::half<double>() * scalar *
      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
  }

  using Params = LinearCombinationGenericParams<double>;

  CUTLASS_HOST_DEVICE
  double operator()(double const &scalar, Params const &params_) const {
    return this->operator()(scalar);
  }
};

template <typename T, int N>
@@ -345,6 +548,13 @@ struct GELU<Array<T, N> > {

    return y;
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  }
};

// GELU operator implemented using the Taylor series approximation
@@ -360,6 +570,9 @@ struct GELU_taylor {
    return T(cutlass::constants::half<T>() * z *
      (cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
  }

  using Params = LinearCombinationGenericParams<T>;

};
template <int N>
@@ -386,6 +599,8 @@ struct GELU_taylor<Array<half_t, N> > {

    return y;
  }

  using Params = LinearCombinationGenericParams<half_t>;
};

template <typename T, int N>
@@ -403,6 +618,8 @@ struct GELU_taylor<Array<T, N> > {

    return y;
  }

  using Params = LinearCombinationGenericParams<T>;
};

/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and
@@ -40,6 +40,7 @@
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/epilogue/thread/linear_combination_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -76,22 +77,23 @@ public:
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;

  using ParamsBase = LinearCombinationParams;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {
  struct Params : ParamsBase {
    ElementCompute alpha;                  ///< scales accumulators
    ElementCompute beta;                   ///< scales source tensor
    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      ParamsBase(
        ElementCompute(1),
        ElementCompute(0)
      ),
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
@@ -101,30 +103,43 @@ public:
    Params(
      ElementCompute alpha,
      ElementCompute beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
    ):
      ParamsBase(alpha, beta),
      alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha
    ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { }
    ):
      ParamsBase(alpha, ElementCompute(0)),
      alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
    ):
      ParamsBase(*alpha_ptr, *beta_ptr),
      alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { }
    ):
      ParamsBase(*alpha_ptr, ElementCompute(0)),
      alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ParamsBase const& base
    ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) {
      #if defined(__CUDA_ARCH__)
      alpha = reinterpret_cast<ElementCompute const&>(base.alpha_data);
      beta = reinterpret_cast<ElementCompute const&>(base.beta_data);
      #else
      memcpy( &alpha, base.alpha_data, sizeof(ElementCompute) );
      memcpy( &beta, base.beta_data, sizeof(ElementCompute) );
      #endif
    }
  };
@@ -142,7 +157,6 @@ public:
  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombination(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }
@@ -83,40 +83,7 @@ public:
  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                  ///< scales accumulators
    ElementCompute beta;                   ///< scales source tensor
    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta = ElementCompute(0)
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr = nullptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
  };
  using Params = typename ActivationFunctor<FragmentCompute>::Params;

private:

@@ -124,8 +91,7 @@ private:
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  Params params_;
  bool skip_elementwise_;

public:
@@ -133,9 +99,9 @@ public:
  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationGeneric(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    params_ = params;
    params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    params_.beta = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

@@ -148,14 +114,14 @@ public:

    if (Scale == ScaleType::Nothing) return false;

    return beta_ != ElementCompute(0);
    return params_.beta != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
      params_.beta = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
@@ -186,15 +152,15 @@ public:

    if (Scale == ScaleType::NoBetaScaling) {
      intermediate = converted_source;
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);          // D = alpha * Accum + X
      intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate);   // D = alpha * Accum + X
    } else if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_source(beta_, converted_source);                                   // X = beta * C + uniform
      intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);          // D = alpha * Accum + X
      intermediate = mul_add_source(params_.beta, converted_source);                            // X = beta * C + uniform
      intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate);   // D = alpha * Accum + X
    }

    intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
    intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
@@ -222,10 +188,10 @@ public:
    if (Scale == ScaleType::Nothing) {
      intermediate = converted_accumulator;
    } else {
      intermediate = mul_add_accumulator(alpha_, converted_accumulator);          // D = alpha * Accum
      intermediate = mul_add_accumulator(params_.alpha, converted_accumulator);   // D = alpha * Accum
    }

    intermediate = skip_elementwise_ ? intermediate : activation(intermediate);
    intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
75  include/cutlass/epilogue/thread/linear_combination_params.h  (new file)
@@ -0,0 +1,75 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

struct LinearCombinationParams {
  uint64_t alpha_data[2];
  uint64_t beta_data[2];

  CUTLASS_HOST_DEVICE
  LinearCombinationParams()
    : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
  { }

  template <typename ElementCompute>
  CUTLASS_HOST_DEVICE
  LinearCombinationParams(ElementCompute alpha, ElementCompute beta)
    : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu}
  {
#if defined(__CUDA_ARCH__)
    reinterpret_cast<ElementCompute&>(alpha_data) = alpha;
    reinterpret_cast<ElementCompute&>(beta_data) = beta;
#else
    memcpy( alpha_data, &alpha, sizeof(ElementCompute) );
    memcpy( beta_data, &beta, sizeof(ElementCompute) );
#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,156 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/fast_math.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  int Rank
>
struct PredicatedTileIteratorAffineLayoutRankNParams {
  using Layout = layout::AffineRankN<Rank>;
  using TensorCoord = typename Layout::TensorCoord;

  static bool const kBigEndian = false;

  //
  // Data members
  //

  Layout layout;

  /// Stride in units of bytes along M modes
  Coord<Layout::kRank/2, typename Layout::LongIndex> stride_m;

  /// Stride in units of bytes along N modes
  Coord<Layout::kRank/2, typename Layout::LongIndex> stride_n;

  /// Fast divmod objects divided by tensor extents
  FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)];

  /// Fast divmod objects divided by tensor extents
  FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)];

  int64_t rank2_inc_col;
  int64_t rank2_inc_row;

  //
  // Methods
  //
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorAffineLayoutRankNParams() { }

  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorAffineLayoutRankNParams(TensorCoord const &extent,
                                                Layout const &layout_,
                                                int64_t element_sizeof_bits)
    : layout(layout_)
  {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Layout::kRank / 2; ++i) {
      stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits);
      stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits);
    }

    if (kBigEndian) {
      // "Big Endian" scheme
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Layout::kRank / 2 - 1; ++i) {
        divmod_m[i] = FastDivmod(extent[i + 1]);
        divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]);
      }
    }
    else {
      // "Little Endian" scheme
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Layout::kRank / 2 - 1; ++i) {
        divmod_m[i] = FastDivmod(extent[i]);
        divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]);
      }
    }

#if 0
    //
    // Debug print statements to verify extents and strides are passed correctly.
    //
    printf("PredicatedTileIteratorAffine::Params() entered\n");

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Layout::kRank; ++i) {
      printf("  extent[%d]: %d\n", i, extent[i]);
    }
    for (int i = 0; i < Layout::kRank; ++i) {
      printf("  stride[%d]: %ld\n", i, layout_.stride()[i]);
    }
    printf("PredicatedTileIteratorAffine::Params() returning\n");
#endif
  }

  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorAffineLayoutRankNParams(Layout const &layout_,
                                                int32_t threadmap_delta_kColumn,
                                                int32_t threadmap_delta_kRow,
                                                int64_t element_sizeof_bits)
    : layout(layout_)
  {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Layout::kRank / 2; ++i) {
      stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits);
      stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits);
    }

    rank2_inc_col = threadmap_delta_kColumn * stride_n[0];
    rank2_inc_row = threadmap_delta_kRow * stride_m[0];
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@@ -561,6 +561,17 @@ CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) {
  }
}

CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index, int64_t element_sizeof_bits) {
  if (element_sizeof_bits >= 8) {
    return index * (element_sizeof_bits / 8);
  }
  else {
    int64_t const kElementsPerByte = ((8 / element_sizeof_bits) + ((element_sizeof_bits >= 8) ? 1 : 0));
    return index / kElementsPerByte;
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////
// Min/Max
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -88,12 +88,12 @@ public:
  using ElementA = typename MapArguments::ElementA;
  using LayoutA = typename MapArguments::LayoutA;
  static ComplexTransform const kTransformA = MapArguments::kTransformA;
  static int const kAlignmentA = GemmKernel::kAlignmentA;
  static int const kAlignmentA = MapArguments::kAlignmentA;

  using ElementB = typename MapArguments::ElementB;
  using LayoutB = typename MapArguments::LayoutB;
  static ComplexTransform const kTransformB = MapArguments::kTransformB;
  static int const kAlignmentB = GemmKernel::kAlignmentB;
  static int const kAlignmentB = MapArguments::kAlignmentB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename MapArguments::LayoutC;
@@ -40,7 +40,7 @@

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/thread/unaryOp.h"
#include "cutlass/transform/thread/unary_op.h"

#include "cutlass/array.h"
#include "cutlass/half.h"
@@ -1273,6 +1273,40 @@ struct NumericArrayConverter<uint8_t, int, N, Round> {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Array<int8_t> <= Array<float>
/// Conversion is performed with saturation regardless of setting of
/// the `Round` template parameter.
template <
  int N,
  FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, float, N, Round> {

  using result_type = Array<int8_t, N>;
  using source_type = Array<float, N>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const & source) {
    // Convert float to int
    Array<int32_t, N> temporary;

    NumericArrayConverter<int, float, N, Round> compute_converter;
    temporary = compute_converter(source);

    // Convert int to int8_t
    NumericArrayConverter<int8_t, int32_t, N, Round> destination_converter;
    return destination_converter(temporary);
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const &s) {
    return convert(s);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \
  ((__CUDACC_VER_MAJOR__ > 10) || \
  ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
@@ -189,3 +189,35 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(NumericConversion, f32x8_to_s8x8_rn) {

  int const kN = 8;
  using Source = float;
  using Destination = int8_t;

  dim3 grid(1, 1);
  dim3 block(1, 1);

  cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
  cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});

  for (int i = 0; i < kN; ++i) {
    source.host_data()[i] = float(i);
  }

  source.sync_device();

  test::core::kernel::convert<Destination, Source, kN><<< grid, block >>>(
    reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
    reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
  );

  destination.sync_host();

  for (int i = 0; i < kN; ++i) {
    EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]);
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -16,7 +16,7 @@ math_inst = MathInstruction(

tile_description = TileDescription(
    [128, 128, 8], 4, [2, 4, 1],
    math_inst, 80, 80
    math_inst
)

A = TensorDescription(
@@ -31,10 +31,12 @@ C = TensorDescription(
    cutlass.float32, cutlass.RowMajor, 1
)

epilogue_functor = LinearCombination(cutlass.float32, 1, cutlass.float32, cutlass.float32)

operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=cutlass.float32,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1
)

@@ -54,7 +56,7 @@ beta = 0.0
arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=LinearCombinationFunctorArguments(alpha, beta),
    output_op=operation.epilogue_type(alpha, beta),
    gemm_mode=cutlass.gemm.Mode.Gemm, split_k_slices=1
)

@@ -68,6 +70,14 @@ assert torch.equal(tensor_D, tensor_D_ref)
```
PyCUTLASS also provides infrastructure for profiling, compiled artifact management, and a pooled device-memory manager.

## Supported Features
PyCUTLASS currently supports the following operations:
* GEMM with mode {Serial, Parallel Split K, Batched GEMM, Array GEMM}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor, Row/ColumnMajorInterleaved<32> for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, swizzling functions {IdentitySwizzle<1,2,4,8>, HorizontalSwizzle, BatchedIdentitySwizzle}, and epilogue {LinearCombination, LinearCombinationClamp}
* GEMM grouped with op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, scheduling mode {Host, Device}, and epilogue {LinearCombination, LinearCombinationClamp}
* Conv2d with {Fprop, Dgrad, Wgrad}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {Tensor NHWC, TensorNC32HW32 and TensorC32RSK for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, split-k mode {Parallel, Serial}, and epilogue {LinearCombination, LinearCombinationClamp}

The tiling sizes of the above operations can also be customized.

## Installation
### Using Docker
@@ -94,12 +104,19 @@ cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```

## Examples
Examples can be found in `$CUTLASS_PATH/examples/40_cutlass_py`
Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutlass_py)

## Test
The test cases are listed in `$CUTLASS_PATH/tools/library/scripts/pycutlass/test`. The unit tests can be run with
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && run_all_example.sh
```

## Build documentation
Run
```shell
bash build_doc.sh
```
@@ -1,2 +1,4 @@
python setup.py develop
pip install enum-tools
pip install sphinx-toolbox
pip install m2r2
sphinx-build -b html docs/source/ docs/build/html
@@ -50,7 +50,7 @@
# -- Project information -----------------------------------------------------

project = 'PyCutlass'
copyright = '2022, Andrew Kerr; Zhaodong Chen; Haicheng Wu; Szymon Migacz; Graham Markall'
copyright = '2022, Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall'
author = 'Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall'

@@ -65,9 +65,12 @@ extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.intersphinx',
    'enum_tools.autoenum',
    'sphinx.ext.autosummary'
    'sphinx.ext.autosummary',
    'm2r2'
]

source_suffix = [".rst", ".md"]

autosummary_generate = True
autosummary_imported_members = True

@@ -85,7 +88,7 @@ exclude_patterns = []
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'classic'
html_theme = 'bizstyle'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
@@ -1,2 +1,100 @@
cutlass
=======

.. rubric:: Operator Classification

.. autoclass:: cutlass.OpClass
   :members:

.. rubric:: GEMM Layout

.. autoclass:: cutlass.RowMajor
   :members:

.. autoclass:: cutlass.ColumnMajor
   :members:

.. autoclass:: cutlass.RowMajorInterleaved32
   :members:

.. autoclass:: cutlass.ColumnMajorInterleaved32
   :members:

.. rubric:: Conv Layout

.. autoclass:: cutlass.TensorNHWC
   :members:

.. autoclass:: cutlass.TensorNC32HW32
   :members:

.. autoclass:: cutlass.TensorC32RSK32
   :members:

.. rubric:: Threadblock Swizzle

.. autoclass:: cutlass.dim3
   :special-members:
   :members:

.. autoclass:: cutlass.IdentitySwizzle1
   :special-members:
   :members:

.. autoclass:: cutlass.IdentitySwizzle2
   :special-members:
   :members:

.. autoclass:: cutlass.IdentitySwizzle4
   :special-members:
   :members:

.. autoclass:: cutlass.IdentitySwizzle8
   :special-members:
   :members:

.. autoclass:: cutlass.HorizontalSwizzle
   :special-members:
   :members:

.. autoclass:: cutlass.BatchedIdentitySwizzle
   :special-members:
   :members:

.. autoclass:: cutlass.StridedDgradIdentitySwizzle1
   :special-members:
   :members:

.. autoclass:: cutlass.StridedDgradIdentitySwizzle4
   :special-members:
   :members:

.. autoclass:: cutlass.StridedDgradHorizontalSwizzle
   :special-members:
   :members:

.. rubric:: Coordinates

.. autoclass:: cutlass.Tensor4DCoord
   :special-members:
   :members:

.. autoclass:: cutlass.Tensor3DCoord
   :special-members:
   :members:

.. autoclass:: cutlass.MatrixCoord
   :special-members:
   :members:

.. rubric:: Convolution

.. autoclass:: cutlass.conv.Operator
   :members:

.. autoclass:: cutlass.conv.IteratorAlgorithm
   :members:

.. autoclass:: cutlass.conv.StrideSupport
   :members:
@@ -1,6 +0,0 @@
Descriptions
==============

.. autoclass:: pycutlass.TileDescription
   :special-members:
   :members:
@@ -1,5 +0,0 @@
Frontend
==============

.. autoclass:: pycutlass.NumpyFrontend
   :members:
@@ -3,27 +3,29 @@
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to PyCutlass's documentation!
CUTLASS Python Project Documentation
=====================================

.. mdinclude:: ../../README.md

.. toctree::
   :maxdepth: 2
   :caption: Contents:


Indices and tables
.. Indices and tables
.. ==================

.. * :ref:`genindex`
.. * :ref:`modindex`
.. * :ref:`search`


Indices
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`


.. toctree::
   types
   cutlass
   descriptor
   frontend
   user_guide
   visitor_tree
   gemm_op
   conv2d_op
   cutlass
@@ -0,0 +1,225 @@
# Epilogue Visitor Tree
The Epilogue Visitor Tree is an experimental feature that generates epilogues directly from user-provided Python functions.
## Usage

The Epilogue Visitor Tree supports many different operations.

### Unary functions
The Epilogue Visitor Tree supports unary functions such as activation functions. For example,
```python
class UnaryEpilogue_(EpilogueVisitTree):
    def __call__(
        self, accum: 'tensor', c: 'tensor',
        alpha: 'scalar', beta: 'scalar'):
        #
        T = leaky_relu.numpy(accum, 0.2)
        Z = alpha * T + beta * c
        return Z
epilogue_functor = UnaryEpilogue_(
    epilogue_functor, tile_description, math_inst.element_accumulator,
    C.alignment, element_epilogue, C.element)
```
### Broadcast Operation
The Epilogue Visitor Tree supports broadcasting row and column vectors to the whole output matrix. To use broadcast, you only need to specify whether the source vector is a `row` vector or a `column` vector. Here is an example.
```python
class ColumnBroadcast_(EpilogueVisitTree):
    def __call__(
        self, accum: 'tensor', c: 'tensor',
        vector: 'column', alpha: 'scalar', beta: 'scalar'):
        #
        T = accum + vector
        scale_T = leaky_relu.numpy(alpha * T, 0.2)
        Z = scale_T + beta * c
        return Z, T
epilogue_functor = ColumnBroadcast_(
    epilogue_functor, tile_description, math_inst.element_accumulator,
    C.alignment, element_epilogue, C.element)
```
### Reduction Operation

The Epilogue Visitor Tree also supports row- and column-wise reduction within each threadblock tile. The syntax for reduction is
```python
{reduction_output} = reduction_op({input_tensor}, {row|column}, {Add}, {threadblock_shape.n|threadblock_shape.m})
```
The `{row|column}` argument indicates whether the `row` vectors or the `column` vectors are reduced. The `{Add}` specifies the reduction operation. The `{threadblock_shape.n|threadblock_shape.m}` are the reduction lengths.

**Constraints**
* The `{input_tensor}` can only be the name of a source or intermediate result. `reduction_op(A + B, ...)` will not work; use `C = A + B` followed by `reduction_op(C, ...)` instead.
* The `{reduction_output}` cannot be used elsewhere in the epilogue. It is written directly to global memory after the reduction is done.
```python
class RowReduction_(EpilogueVisitTree):
    def __call__(
        self, accum: 'tensor', c: 'tensor',
        alpha: 'scalar', beta: 'scalar'):
        #
        D = alpha * accum + tanh.numpy(beta * c)
        reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
        return D, reduction
epilogue_functor = RowReduction_(
    epilogue_functor, tile_description, math_inst.element_accumulator,
    C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
```
## Get output_op

As shown in the user guide, an `output_op` is required by the argument wrapper. We take the `RowReduction_` example to show how to get the `output_op`.
```python
class RowReduction_(EpilogueVisitTree):
    def __call__(
        self, accum: 'tensor', c: 'tensor',
        alpha: 'scalar', beta: 'scalar'):
        #
        D = alpha * accum + tanh.numpy(beta * c)
        reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
        return D, reduction
epilogue_functor = RowReduction_(
    epilogue_functor, tile_description, math_inst.element_accumulator,
    C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()

cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, element_c))
# get output op
output_op = operation.epilogue_type(
    D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C,
    reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
```
Like other epilogue functors such as `LinearCombination`, the output op for the Epilogue Visitor Tree is also created with `operation.epilogue_type(*)`. However, there are two differences:
* The arguments need to be passed as keyword arguments. The keywords are the argument names in `def __call__`.
* An additional `problem_size=[problem_size.m(), problem_size.n()]` is required.
## Add new Unary Operation (e.g. Activation Function)
To add an additional unary operation to the epilogue visitor tree, a new unary op
should be created for `VisitorOpUnary`. We take `tanh` as an example.

### Step 1: define TanhVisitor

The visitor defines the parameters and computation required by the unary operation.
The unary operations are registered in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h), but you can define one in any header file and include that header in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h).


* Two template arguments are required:
    * `T`: data type used to compute the unary operation
    * `N`: compute fragment length
* We also need to provide the `Arguments` and `Params` structures. The `Arguments` will be assembled by [ctypes](https://docs.python.org/3/library/ctypes.html); the `Params` will be generated from `Arguments` automatically. If the unary function takes no argument, an integer like `int tmp` can be provided to ensure the correctness of ctypes.
* The constructor can only take `params` as its single argument.
* The operation is defined in `Array<T, N> operator()(Array<T, N> const &frag) const`. One common way to do this is to first define a scalar computation and then apply it to the fragment with an unrolled for-loop.
* A guard function is required. If it returns `false`, all child nodes of the unary node are disabled and zeros are returned to the parent node. This is very helpful for multiplication by a scalar when the scalar is `0`. For general cases, you can just return `true`.
```c++
// T: data type used to compute the unary operation
// N: compute fragment length
template <typename T, int N>
struct TanhVisitor {
  /// Argument
  struct Arguments {
    // a placeholder argument to ensure correctness of ctypes
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments(): tmp(0) { };

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { };
  };

  /// Param
  struct Params {
    CUTLASS_HOST_DEVICE
    Params() { };
    Params(Arguments const &args) { }
  };

  /// Constructor
  CUTLASS_HOST_DEVICE
  TanhVisitor(Params const &params) { }

  // scalar operator
  CUTLASS_HOST_DEVICE
  T tanh_op(T const &scalar) const {
    return fast_tanh(scalar);
  }

  /// vector operator
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    Array<T, N> y;

    CUTLASS_PRAGMA_UNROLL
    for (int i=0; i < N; ++i) {
      y[i] = tanh_op(frag[i]);
    }

    return y;
  }

  // Guard
  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }
};
```
### Step 2: register Tanh function
After defining the functor in C++, we need to register it in Python. The class below gives an example.
* The init function takes the data type `element_compute`, which will be the `T` in the C++ template.
  In the init function, we also generate the `_Arguments` class as a `ctypes.Structure`. It includes all the data members of `TanhVisitor::Arguments`.
* The `_Arguments` class needs to be registered as `self.argument_type` of the `tanh` class.
* An `emit` function is required to emit the namespace and typename of `TanhVisitor`.
* A static method serving as the NumPy reference implementation is required so that the Python code can be parsed.

The built-in functions are defined in [pycutlass/src/pycutlass/epilogue.py](tools/library/scripts/pycutlass/src/pycutlass/epilogue.py). You can define yours in any file as long as it can be found by [/pycutlass/src/pycutlass/parser.py](tools/library/scripts/pycutlass/src/pycutlass/parser.py).
```python
class tanh(ActivationFunctor):
    def __init__(self, element_compute) -> None:
        super().__init__()
        class _Arguments(ctypes.Structure):
            _fields_ = [
                ("tmp", ctypes.c_int)
            ]
            def __init__(self, *args) -> None:
                self.tmp = 0
        self.argument_type = _Arguments

    def emit(self):
        return "cutlass::TanhVisitor"

    @staticmethod
    def numpy(x: np.ndarray):
        return np.tanh(x)
```
### Step 3: Run the function
The new unary op is now ready to use. An epilogue visitor tree can be built with
```python
class RowReduction_(EpilogueVisitTree):
    def __call__(
        self, accum: NDArray['tensor', 'float32'], c: NDArray['tensor', 'float32'],
        alpha: 'float32', beta: 'float32'):
        #
        D = alpha * accum + tanh.numpy(beta * c)
        reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
        return D, reduction
epilogue_functor = RowReduction_(
    epilogue_functor, tile_description, math_inst.element_accumulator,
    C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
```
## Limitations and Future work

Although the Epilogue Visitor Tree brings great flexibility to epilogue construction, because the epilogue is formulated as a single tree, there are several limitations.
* [Future Work] Serial and parallel split-K GEMM are not supported yet.
    * To support serial split-K, an additional tree transformation pass is required to inject a `binaryOpNode(Add)` + `TensorInputNode` before each `TensorOutputNode` to fetch the partial sum back. The `semaphore` also needs to be passed into the epilogue.
    * To support parallel split-K, a Reduction-with-visitor kernel is required.
* [Future Work] Convolution and GEMM Grouped are not supported yet.
    * To support Conv2d and GEMM Grouped, corresponding *_with_visitor kernels are required.

* [Limitation] If the same node is used by two operations (unless one of them is a reduction), the node and all of its descendants will be executed twice.
* [Limitation] The result of a reduction can only be used as a return value.
283  tools/library/scripts/pycutlass/docs/source/md/basic_idea.md  (new file)
@@ -0,0 +1,283 @@
# Basics of PyCUTLASS

PyCUTLASS handles the following when launching CUTLASS kernels:
* Memory management
* Operation description
* Code emission and compilation
* Argument preprocessing
* Kernel launching
* Result synchronization
## Memory management

PyCUTLASS uses [RMM](https://github.com/rapidsai/rmm) to manage device memory. At the beginning of the program, call
```python
pycutlass.get_memory_pool({init_pool_size_in_bytes}, {max_pool_size_in_bytes})
```
We also provide a function to query the allocated size.
```python
bytes = get_allocated_size()
```
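For instance, a minimal sketch with concrete (illustrative, not required) sizes: a 1 GiB initial pool that can grow to 4 GiB, assuming both helpers are exposed at the package top level as in the snippets above.
```python
import pycutlass
from pycutlass import get_allocated_size  # assumed top-level export, as above

# Illustrative sizes: 1 GiB initial pool, growable up to 4 GiB.
pycutlass.get_memory_pool(2**30, 2**32)

print(get_allocated_size())  # bytes currently allocated from the pool
```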
## Operation Description
PyCUTLASS provides operation descriptions for GEMM, GEMM Grouped, and Conv2d operations. These operation descriptions are assembled from four fundamental concepts:
* Math Instruction: the math instruction executed in the GPU cores
* Tile Description: tiling sizes and pipeline stages
* Operand Description: data type, layout, memory alignment
* Epilogue Functor: the epilogue function
### Math Instruction

The math instruction is defined as follows:
```python
math_inst = MathInstruction(
    {instruction_shape}, {element_a}, {element_b},
    {element_acc}, {opclass}, {math_operation}
)
```
The `{instruction_shape}` and `{opclass}` define the instruction size and type. The table below lists valid combinations. `{element_a}` and `{element_b}` define the source operand data types for each instruction, and `{element_acc}` defines the accumulator type. The `{math_operation}` defines the math operation applied.

|Opclass | element_a/element_b | element_acc | instruction_shape | math_operation |
| -- | -- | -- | -- | -- |
| cutlass.OpClass.TensorOp | cutlass.float64 | cutlass.float64 | [8, 8, 4] | MathOperation.multiply_add|
| | cutlass.float32, cutlass.tfloat32, cutlass.float16, cutlass.bfloat16 | cutlass.float32 | [16, 8, 8] | MathOperation.multiply_add, MathOperation.multiply_add_fast_f32, MathOperation.multiply_add_fast_f16, MathOperation.multiply_add_fast_bf16 |
| | cutlass.float16 | cutlass.float16/cutlass.float32 | [16, 8, 16] | MathOperation.multiply_add |
| | cutlass.bfloat16 | cutlass.float32 | [16, 8, 16] | MathOperation.multiply_add |
| | cutlass.int8 | cutlass.int32 | [16, 8, 32] | MathOperation.multiply_add_saturate|
|cutlass.OpClass.Simt | cutlass.float64 | cutlass.float64 | [1, 1, 1] | MathOperation.multiply_add |
| | cutlass.float32 | cutlass.float32 | [1, 1, 1] | MathOperation.multiply_add |

`cutlass.OpClass.TensorOp` indicates that Tensor Cores are used, while `cutlass.OpClass.Simt` uses the SIMT cores.

The `multiply_add_fast_f32` operation emulates a fast, accurate SGEMM kernel accelerated with Ampere Tensor Cores. More details can be found in [examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm](examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm).
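As a concrete sketch, the fp16-input / fp32-accumulator Tensor Core row of the table above would be instantiated like this (the names reuse the placeholders of this guide):
```python
# fp16 inputs, fp32 accumulation, 16x8x16 Tensor Core instruction (see table).
math_inst = MathInstruction(
    [16, 8, 16], cutlass.float16, cutlass.float16,
    cutlass.float32, cutlass.OpClass.TensorOp,
    MathOperation.multiply_add
)
```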
### Tile Description
The tile description describes the threadblock and warp tiling sizes, as well as the number of pipeline stages.
```python
tile_description = TileDescription(
    {threadblock_shape}, {stages}, {warp_count},
    math_inst
)
```
The `{threadblock_shape}` is a list of 3 integers `[Tile_M, Tile_N, Tile_K]` that defines the threadblock tiling size. `{stages}` defines the number of software pipeline stages ([detail](https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/)). `{warp_count}` defines the number of warps along the `M`, `N`, and `K` dimensions. I.e., with `{threadblock_shape}=[Tile_M, Tile_N, Tile_K]` and `{warp_count}=[W_M, W_N, W_K]`, the warp tile size would be `[Tile_M / W_M, Tile_N / W_N, Tile_K / W_K]`.
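For example, assuming the `math_inst` defined above, a 128x128x32 threadblock tile with 4 pipeline stages and a 2x2x1 warp arrangement yields a 64x64x32 warp tile (the tile sizes here are illustrative):
```python
tile_description = TileDescription(
    [128, 128, 32],  # threadblock tile [Tile_M, Tile_N, Tile_K]
    4,               # pipeline stages
    [2, 2, 1],       # warp count [W_M, W_N, W_K]
    math_inst
)
# Warp tile: [128/2, 128/2, 32/1] = [64, 64, 32]
```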
### Operand Description
The operand description defines the data type, layout, and memory alignment of the input tensors A, B, and C. The output D shares the same attributes as C. The description is as follows:
```python
A = TensorDescription(
    {element_a}, {layout_a}, {alignment_a}
)

B = TensorDescription(
    {element_b}, {layout_b}, {alignment_b}
)

C = TensorDescription(
    {element_c}, {layout_c}, {alignment_c}
)
```
The table below lists the supported layouts and data types for each operation.
| Operation | data type | layout |
| -- | -- | -- |
| GEMM, GEMM Grouped | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.RowMajor, cutlass.ColumnMajor |
| | cutlass.int8 | cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32|
| Conv2d Fprop, Dgrad, Wgrad | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.TensorNHWC |
| Conv2d Fprop | cutlass.int8 | cutlass.TensorNHWC, cutlass.TensorNC32HW32, cutlass.TensorC32RSK32|
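As an illustrative sketch, an fp16 row-major operand accessed with 128-bit vectors (8 fp16 elements per access) would be described as:
```python
# Alignment is given in elements: 8 x fp16 = 128 bits per access.
A = TensorDescription(cutlass.float16, cutlass.RowMajor, 8)
```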
### Epilogue Functor
The epilogue functor defines the epilogue executed after the mainloop.
We expose the following epilogue functors.
| Epilogue Functor | Remark |
| -- | -- |
| LinearCombination | $D=\alpha \times Accum + \beta \times C$ |
| LinearCombinationClamp | $D=\alpha \times Accum + \beta \times C$; the output is clamped to the representable range of the output data type |
| FastLinearCombinationClamp | $D=\alpha \times Accum + \beta \times C$; only used for problem sizes $K\le 256$ with cutlass.int8, accumulator data type `cutlass.int32`, and epilogue compute data type `cutlass.float32` |
| LinearCombinationGeneric | $D = activation(\alpha \times Accum + \beta \times C)$; available activations include `relu`, `leaky_relu`, `tanh`, `sigmoid`, `silu`, `hardswish`, and `gelu` |

The epilogue functors can be created as follows.
```python
# LinearCombination
epilogue_functor = LinearCombination(
    element_C, alignment_c, element_acc, element_epilogue_compute
)

# LinearCombinationClamp
epilogue_functor = LinearCombinationClamp(
    element_C, alignment_c, element_acc, element_epilogue_compute
)

# FastLinearCombinationClamp
epilogue_functor = FastLinearCombinationClamp(
    element_C, alignment_c
)

# LinearCombinationGeneric
epilogue_functor = LinearCombinationGeneric(
    relu(element_epilogue_compute), element_C, alignment_c,
    element_acc, element_epilogue_compute
)
```

We also provide an experimental feature, the "Epilogue Visitor Tree", for the GEMM operation. The details can be found in [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md).
### GEMM Operation

The GEMM operation description can be created with
```python
operation = GemmOperationUniversal(
    {compute_capability}, tile_description,
    A, B, C, epilogue_functor,
    {swizzling_functor}, {visitor}
)
```
* `{compute_capability}` is an integer indicating the compute capability of the GPU. For A100, it is 80.
* `{swizzling_functor}` describes how threadblocks are scheduled on the GPU. This is used to improve L2 locality ([detail](https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/)). Currently we support `cutlass.{IdentitySwizzle1|IdentitySwizzle2|IdentitySwizzle4|IdentitySwizzle8|BatchedIdentitySwizzle}`. The last one is used for batched or array GEMM.
* `{visitor}`: a bool indicating whether the epilogue visitor tree is used. A filled-in example follows.
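For concreteness, this is the SM80 instantiation used by the README example, written in keyword form with the visitor flag left at its default (treat the exact keyword names as assumptions carried over from that example):
```python
operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1
)
```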
### GEMM Grouped Operation
The GEMM grouped operation description can be created with
```python
operation = GemmOperationGrouped(
    compute_capability, tile_description,
    A, B, C, epilogue_functor,
    swizzling_functor, {precompute_mode}
)
```
* `{precompute_mode}`: it can be `SchedulerMode.Host` or `SchedulerMode.Device`. See [examples/24_gemm_grouped](examples/24_gemm_grouped) for more details.
### Conv2d Operation
The Conv2d operation description can be created with
```python
operation = Conv2dOperation(
    {conv_kind}, {iterator_algorithm},
    compute_capability, tile_description,
    A, B, C, {stride_support},
    epilogue_functor, swizzling_functor
)
```
* `{conv_kind}` defines which convolution is executed. Available options include `fprop`, `dgrad`, and `wgrad`.
* `{iterator_algorithm}` specifies the iterator algorithm used by the implicit GEMM in convolution. The options are as follows:
    * `analytic`: functionally correct in all cases but lower performance
    * `optimized`: optimized for R <= 32, S <= 32, and unity-stride dgrad
    * `fixed_channels`: analytic algorithm optimized for a fixed channel count (C == AccessSize)
    * `few_channels`: analytic algorithm optimized for few channels (C divisible by AccessSize)
* `{stride_support}`: distinguishes among partial specializations that accelerate certain problems where the convolution stride is unit.
    * `strided`: arbitrary convolution stride
    * `unity`: unit convolution stride

***
## Code Emission and Compilation
After implementing the operation description, the related host and device code can be compiled with
```python
import pycutlass

pycutlass.compiler.add_module([operation,])
```
Several operations can be compiled together. The `nvcc` at `$CUDA_INSTALL_PATH/bin` is used as the compiler backend by default, but you can also switch to [CUDA Python](https://nvidia.github.io/cuda-python/overview.html)'s `nvrtc` with
```python
pycutlass.compiler.nvrtc()
```
We also have an internal compiled-artifact manager that caches compiled kernels in both memory and disk. The `compiled_cache.db` in your workspace is the database that contains the binary files; delete it if you want to recompile the kernels.
***

## Argument Processing
We provide argument wrappers that convert Python tensors into kernel parameters. Currently [torch.Tensor](https://pytorch.org/), [numpy.ndarray](https://numpy.org/), and [cupy.ndarray](https://cupy.dev/) are supported.

### GEMM Arguments
The GEMM arguments can be created with
```python
arguments = GemmArguments(
    operation=operation, problem_size={problem_size},
    A={tensor_A}, B={tensor_B}, C={tensor_C}, D={tensor_D},
    output_op={output_op},
    gemm_mode={gemm_mode},
    split_k_slices={split_k_slices}, batch={batch}
)
```
* `problem_size` is a `cutlass.gemm.GemmCoord(M, N, K)` object that defines an $M \times N \times K$ matrix multiplication.
* `tensor_X`: user-provided tensors.
* `output_op`: the params for the epilogue functor.
* `gemm_mode`, `split_k_slices`, and `batch`:

|gemm_mode| split_k_slices | batch | remark|
|--|--|--|--|
|cutlass.gemm.Mode.Gemm | number of split-K slices | - | ordinary GEMM, or GEMM with serial split-K|
|cutlass.gemm.Mode.GemmSplitKParallel | number of split-K slices | - | GEMM with parallel split-K|
|cutlass.gemm.Mode.Batched | - | batch size | batched GEMM |
|cutlass.gemm.Mode.Array | - | batch size | array GEMM |

A worked sketch follows.
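The following sketch launches an ordinary (non-split-K) float32 GEMM with numpy tensors; shapes and scalars are illustrative, and `operation` is assumed to be the `GemmOperationUniversal` built above:
```python
import numpy as np

M, N, K = 512, 256, 128
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)
D = np.zeros_like(C)   # holds the result

arguments = GemmArguments(
    operation=operation, problem_size=cutlass.gemm.GemmCoord(M, N, K),
    A=A, B=B, C=C, D=D,
    output_op=operation.epilogue_type(1.0, 0.0),   # alpha=1, beta=0
    gemm_mode=cutlass.gemm.Mode.Gemm,
    split_k_slices=1                               # batch omitted for Mode.Gemm
)
```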
### GEMM Grouped Arguments
The GEMM grouped arguments can be created with
```python
arguments = GemmGroupedArguments(
    operation, {problem_sizes_coord}, {tensor_As}, {tensor_Bs}, {tensor_Cs}, {tensor_Ds},
    output_op=output_op
)
```
* `problem_sizes_coord` is a list of `cutlass.gemm.GemmCoord(M, N, K)`, one per problem size.
* `tensor_Xs` is a list of user-provided tensors.
* `output_op`: the params of the epilogue functor.

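For example, a sketch with three independent problems of different shapes; the tensor lists are assumed to hold user-provided numpy/torch/cupy tensors matching those shapes:
```python
# Illustrative sketch: three problems, different shapes per group.
problem_sizes_coord = [
    cutlass.gemm.GemmCoord(128, 128, 64),
    cutlass.gemm.GemmCoord(256, 64, 32),
    cutlass.gemm.GemmCoord(64, 512, 128),
]
arguments = GemmGroupedArguments(
    operation, problem_sizes_coord,
    tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
    output_op=operation.epilogue_type(1.0, 0.0)
)
```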
### Conv2d Arguments

The Conv2d arguments can be created with
```python
arguments = Conv2dArguments(
    operation, {problem_size}, {tensor_A},
    {tensor_B}, {tensor_C}, {tensor_D},
    {output_op},
    {split_k_mode},
    {split_k_slices}
)
```
* `problem_size` can be constructed with
    ```python
    problem_size = cutlass.conv.Conv2dProblemSize(
        cutlass.Tensor4DCoord(N, H, W, C),
        cutlass.Tensor4DCoord(K, R, S, C),
        cutlass.Tensor4DCoord(pad[0], pad[1], pad[2], pad[3]),
        cutlass.MatrixCoord(stride[0], stride[1]),
        cutlass.MatrixCoord(dilation[0], dilation[1]),
        cutlass.conv.Mode.cross_correlation,
        split_k_slices, 1
    )
    ```
* `tensor_X`: user-provided tensors.
* `output_op`: the params of the epilogue functor.
* `split_k_mode`: we currently support `cutlass.conv.SplitKMode.Serial` and `cutlass.conv.SplitKMode.Parallel`.
* `split_k_slices`: the number of split-K slices.

For an ordinary conv2d, just use `cutlass.conv.SplitKMode.Serial` with `split_k_slices=1`, as in the sketch below.
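A hedged sketch of a serial (non-split-K) fprop launch; `operation` and `problem_size` are assumed from the snippets above, and `tensor_A` through `tensor_D` are assumed user-provided NHWC tensors:
```python
# Illustrative sketch: ordinary conv2d fprop, no split-K.
arguments = Conv2dArguments(
    operation, problem_size, tensor_A,
    tensor_B, tensor_C, tensor_D,
    output_op=operation.epilogue_type(1.0, 0.0),   # alpha=1, beta=0
    split_k_mode=cutlass.conv.SplitKMode.Serial,
    split_k_slices=1
)
```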
### Getting output_op

The `output_op` can be created with
```python
output_op = operation.epilogue_type(*([alpha, beta] + args.activation_args))
```
It is a list of arguments that starts with the scaling factors `alpha` and `beta`, followed by any activation arguments.
The `output_op` of an EpilogueVisitorTree is slightly different. Please check [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md) for details.

## Kernel Launching

With the arguments and the operation, the kernel can be launched simply with
```python
operation.run(arguments)
```

## Sync results

We also provide a function to synchronize the kernel execution; if you use `numpy`, it also copies the result back to the host. To do so, run
```python
arguments.sync()
```
If you use an EpilogueVisitorTree, please call
```python
output_op.sync()
```
instead.

## Reduction Kernel behind Parallel Split-K

If you use parallel split-K in GEMM or Conv2d, an additional reduction kernel is required. Please check [examples/40_cutlass_py](examples/40_cutlass_py) for details.
@@ -1,6 +0,0 @@
Types
========


.. autoenum:: pycutlass.OperationKind
    :members:
@@ -0,0 +1,4 @@
User Guide
=====================================

.. mdinclude:: ./md/basic_idea.md
@@ -0,0 +1,4 @@
User Guide
=====================================

.. mdinclude:: ./md/EpilogueVisitorTree.md
@@ -32,6 +32,7 @@

from pycutlass import *
import pycutlass
from pycutlass.epilogue import LinearCombination
from pycutlass.test.conv2d_testbed import Conv2dLauncher


@@ -62,15 +63,16 @@ if __name__ == "__main__":
    tile_description = TileDescription(
        threadblock_shape=[128, 128, 64], stages=4,
        warp_count=[2, 2, 1],
        math_instruction=math_inst,
        min_compute=80, max_compute=80
        math_instruction=math_inst
    )

    epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)

    operation = Conv2dOperation(
        conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
        arch=80, tile_description=tile_description, A=A, B=B, C=C,
        element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
        epilogue_functor=EpilogueFunctor.LinearCombination,
        epilogue_functor=epilogue_functor,
        swizzling_functor=cutlass.IdentitySwizzle1
    )

@@ -49,7 +49,7 @@ if __name__ == '__main__':
    tile_description = TileDescription(
        threadblock_shape=[256, 128, 32],
        stages=3, warp_count=[4, 2, 1],
        math_instruction=math_inst, min_compute=80, max_compute=80
        math_instruction=math_inst
    )

    A = TensorDescription(
@@ -67,7 +67,7 @@ if __name__ == '__main__':

    element_epilogue = cutlass.float32

    epilogue_functor = EpilogueFunctor.LinearCombination
    epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32)

    swizzling_functor = cutlass.IdentitySwizzle1

@@ -43,7 +43,7 @@ try:
    Pybind11Extension("cutlass",
        ["src/cpp/cutlass.cpp"],
        include_dirs=include_dirs,
        extra_compile_args=["-fpermissive"])
        extra_compile_args=["-fpermissive", "-w"])
    ]
except ImportError:
    pass
@@ -69,7 +69,8 @@ setup(
        'typeguard',
        'bfloat16',
        'typing',
        'scikit-build'
        'scikit-build',
        'treelib'
    ],
    cmdclass={
        'rmm': BuildRMM
@@ -0,0 +1,225 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the generic epilogue visitor.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"

#include "epilogue_visitor_op/visitor_op_linear_combination.h"
#include "epilogue_visitor_op/visitor_op_tensor_input.h"
#include "epilogue_visitor_op/visitor_op_accumulator.h"
#include "epilogue_visitor_op/visitor_op_row_broadcast.h"
#include "epilogue_visitor_op/visitor_op_tensor_output.h"
#include "epilogue_visitor_op/visitor_op_column_reduction.h"
#include "epilogue_visitor_op/visitor_op_row_reduction.h"
#include "epilogue_visitor_op/visitor_op_column_broadcast.h"
#include "epilogue_visitor_op/visitor_op_unary.h"
#include "epilogue_visitor_op/visitor_op_binary.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Generic Epilogue Visitor
template <
  typename OutputOp_
>
class EpilogueVisitorGeneric {
public:

  using OutputOp = OutputOp_;
  using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType;
  static int const kElementsPerAccess = OutputOp::kElementsPerAccess;
  using ElementOutput = typename OutputOp::ElementOutput;
  using OutputTileIterator = typename OutputOp::OutputTileIterator;

  static int const kIterations = OutputTileIterator::kIterations;

  ///
  /// End Epilogue Tree
  ///

  /// No additional SMEM buffer is required in the broadcast epilogue visitor
  struct SharedStorage {

    typename OutputOp::SharedStorage output_smem;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

public:

  /// Argument structure
  struct Arguments {
    typename OutputOp::Arguments output_op_args;

    //
    // Methods
    //
    Arguments() { }

    Arguments(
      typename OutputOp::Arguments output_op_args
    ):
      output_op_args(output_op_args)
    { }
  };

  struct Params {
    typename OutputOp::Params output_op_params;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      output_op_params(args.output_op_args)
    { }
  };

private:

  OutputOp output_op;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueVisitorGeneric(
    Params const &params,             ///< Parameters routed to the epilogue
    SharedStorage &shared_storage,    ///< Shared storage needed by the functors here
    MatrixCoord threadblock_offset,
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    MatrixCoord problem_size
  ):
    output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size)
  { }

  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
    int split_k_index,                ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices) {             ///< Total number of split-K slices

  }

  /// Called to set the batch index
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    output_op.set_batch_index(batch_idx);
  }

  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() {
    output_op.begin_epilogue();
  }

  /// Called at the start of one step before starting accumulator exchange
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    output_op.begin_step(step_idx);
  }

  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    output_op.begin_row(row_idx);
  }

  /// Called after accumulators have been exchanged for each accumulator vector
  CUTLASS_DEVICE
  void visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum) {
    output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
  }

  /// Called at the end of a row
  CUTLASS_DEVICE
  void end_row(int row_idx) {
    output_op.end_row(row_idx);
  }

  /// Called after all accumulator elements have been visited
  CUTLASS_DEVICE
  void end_step(int step_idx) {
    output_op.end_step(step_idx);
  }

  /// Called after all steps have been completed
  CUTLASS_DEVICE
  void end_epilogue() {
    output_op.end_epilogue();
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,84 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the binary ops
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////


/// Vector addition
template <typename T, int N>
struct VectorAdd {

  struct Arguments {
    // a placeholder argument to ensure correctness of ctypes
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments(): tmp(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { }
  };

  struct Params {

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args) { }
  };

  CUTLASS_HOST_DEVICE
  VectorAdd(
    Params const &params
  ) { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    cutlass::plus<Array<T, N>> add_op;
    return add_op(lhs, rhs);
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,233 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the unary ops
*/

#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////


/// Scalar multiplication
template <typename T, int N>
struct Mult {

  struct Arguments {
    T alpha;

    CUTLASS_HOST_DEVICE
    Arguments(): alpha(T(1.0)) { }

    CUTLASS_HOST_DEVICE
    Arguments(T alpha): alpha(alpha) { }
  };

  struct Params {
    T alpha;    ///< scales accumulators

    CUTLASS_HOST_DEVICE
    Params(): alpha(T(1.0)) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args): alpha(args.alpha) { }
  };

  T alpha_;

  CUTLASS_HOST_DEVICE
  Mult(
    Params const &params
  ):
    alpha_(params.alpha)
  { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &source) const {
    cutlass::multiplies<Array<T, N>> multiply_op;
    return multiply_op(source, alpha_);
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return alpha_ != T(0);
  }

};


/// ReLU
template <typename T, int N>
struct ReLUVisitor {
  struct Arguments {
    T threshold;

    CUTLASS_HOST_DEVICE
    Arguments(): threshold(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Arguments(T threshold): threshold(threshold) { }
  };

  struct Params {
    T threshold;

    CUTLASS_HOST_DEVICE
    Params(): threshold(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args): threshold(args.threshold) { }
  };

  T threshold_;

  CUTLASS_HOST_DEVICE
  ReLUVisitor(Params const &params):
    threshold_(params.threshold) { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    maximum<Array<T, N>> mx;
    return mx(frag, threshold_);
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }
};

/// LeakyReLU
template <typename T, int N>
struct LeakyReLUVisitor {
  struct Arguments {
    T leaky_alpha;

    CUTLASS_HOST_DEVICE
    Arguments(): leaky_alpha(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { }
  };

  struct Params {
    T leaky_alpha;

    CUTLASS_HOST_DEVICE
    Params(): leaky_alpha(T(0.0)) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { }
  };

  T leaky_alpha_;

  CUTLASS_HOST_DEVICE
  LeakyReLUVisitor(Params const &params):
    leaky_alpha_(params.leaky_alpha) { }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    cutlass::epilogue::thread::LeakyReLU<Array<T, N>> leaky_op;
    return leaky_op(frag, leaky_alpha_);
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }

};

/// Tanh
template <typename T, int N>
struct TanhVisitor {
  /// Argument
  struct Arguments {
    // a placeholder argument to ensure correctness of ctypes
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments(): tmp(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { }
  };

  /// Param
  struct Params {
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args) { }
  };

  /// Constructor
  CUTLASS_HOST_DEVICE
  TanhVisitor(Params const &params) { }

  /// scalar operator
  CUTLASS_HOST_DEVICE
  T tanh_op(T const &scalar) const {
    return fast_tanh(scalar);
  }

  /// vector operator
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &frag) const {
    Array<T, N> y;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = tanh_op(frag[i]);
    }

    return y;
  }

  CUTLASS_HOST_DEVICE
  bool guard() {
    return true;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,148 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor op for the accumulator
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementAccumulator accum;
///  return accum;
///
/// It can only be a leaf node of the epilogue tree

template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  int kElementsPerAccess_        ///< Number of elements computed per operation
>
class VisitorOpAccumulator {
public:
  using ElementAccumulator = ElementAccumulator_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  /// Fragment type for Accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type returned by this visitor
  using VisitAccessType = AccumulatorAccessType;

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    // Note: ctypes has issues with an empty structure, so a placeholder argument is kept
    int tmp;

    CUTLASS_HOST_DEVICE
    Arguments() { }

    CUTLASS_HOST_DEVICE
    Arguments(int tmp): tmp(tmp) { }
  };

  /// Parameter structure
  struct Params {

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args) { }
  };

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpAccumulator(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) { }

  CUTLASS_DEVICE
  void begin_epilogue() { }

  CUTLASS_DEVICE
  void begin_step(int step_idx) { }

  CUTLASS_DEVICE
  void begin_row(int row_idx) { }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    return accum;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) { }

  CUTLASS_DEVICE
  void end_epilogue() { }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,246 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor op with a binary op
*/

#pragma once
#include "cutlass/cutlass.h"
#include "binary_ops.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementCompute alpha;
///  ElementCompute beta;
///  ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B));
///  return C;
///
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementCompute_,      ///< Data type used to compute linear combination
  int kElementsPerAccess_,       ///< Number of elements computed per operation
  typename VisitorA_,            ///< Child node A
  typename VisitorB_,            ///< Child node B
  template<typename T, int N> typename BinaryOp_
>
class VisitorOpBinary {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  using VisitorA = VisitorA_;
  using VisitorB = VisitorB_;

  /// Fragment type returned from VisitorA.visit
  using VisitAccessTypeA = typename VisitorA::VisitAccessType;
  using ElementA = typename VisitAccessTypeA::Element;

  /// Fragment type returned from VisitorB.visit
  using VisitAccessTypeB = typename VisitorB::VisitAccessType;
  using ElementB = typename VisitAccessTypeB::Element;

  /// Fragment type returned by this visitor
  using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Combination Op TODO: generalize this
  using BinaryOp = BinaryOp_<ElementCompute, kElementsPerAccess>;

  static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
  static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    typename VisitorA::SharedStorage storage_a;
    typename VisitorB::SharedStorage storage_b;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };


  /// Host-constructable Arguments structure
  struct Arguments {
    typename BinaryOp::Arguments binary_arg;
    typename VisitorA::Arguments visitor_a_arg;    ///< Arguments for visitor_a
    typename VisitorB::Arguments visitor_b_arg;    ///< Arguments for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Arguments(): binary_arg() { }

    CUTLASS_HOST_DEVICE
    Arguments(
      typename BinaryOp::Arguments binary_arg,
      typename VisitorA::Arguments visitor_a_arg,
      typename VisitorB::Arguments visitor_b_arg
    ):
      binary_arg(binary_arg),
      visitor_a_arg(visitor_a_arg),
      visitor_b_arg(visitor_b_arg)
    { }
  };

  /// Parameter structure
  struct Params {
    typename BinaryOp::Params binary_param;
    typename VisitorA::Params visitor_a_param;    ///< Params for visitor_a
    typename VisitorB::Params visitor_b_param;    ///< Params for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      binary_param(args.binary_arg),
      visitor_a_param(args.visitor_a_arg),
      visitor_b_param(args.visitor_b_arg)
    { }
  };

private:
  //
  // Data members
  //

  BinaryOp binary_op;

  VisitorA visitor_a_op;
  VisitorB visitor_b_op;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpBinary(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    binary_op(params.binary_param),
    visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
    visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
  { }


  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_a_op.begin_epilogue();
    visitor_b_op.begin_epilogue();
  }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    visitor_a_op.set_batch_index(batch_idx);
    visitor_b_op.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    visitor_a_op.begin_step(step_idx);
    visitor_b_op.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_a_op.begin_row(row_idx);
    visitor_b_op.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    /// Get result from visitor A and visitor B
    VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    /// Type conversion
    NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
    NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;

    return binary_op(
      source_converter_A(result_A),
      source_converter_B(result_B)
    );
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_a_op.end_row(row_idx);
    visitor_b_op.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_a_op.end_step(step_idx);
    visitor_b_op.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_a_op.end_epilogue();
    visitor_b_op.end_epilogue();
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,250 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor op that broadcasts a vector to all columns
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
///  ElementVector T[i][j] <- device-memory Td[i]
///
/// It can only be a leaf node in the epilogue tree
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementFragment_,     ///< Data type used to cache the vector in register
  typename InputTileIterator_    ///< Tile iterator type to read the broadcasted tensor
>
class VisitorOpColumnBroadcast {
public:
  using InputTileIterator = InputTileIterator_;

  static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
  using ElementAccumulator = ElementAccumulator_;
  using ElementVector = typename InputTileIterator::Element;
  using ElementFragment = ElementFragment_;

  using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;

  /// Thread map used by input tile iterators
  using ThreadMap = typename InputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<
    ElementFragment, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Used for the broadcast
  struct BroadcastDetail {
    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    // /// Number of iterations (accesses) the threadblock takes to reduce a row
    // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
  };

  // using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;

  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementVector *broadcast_ptr;    ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments():
      broadcast_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementVector *broadcast_ptr,
      int64_t batch_stride
    ):
      broadcast_ptr(broadcast_ptr),
      batch_stride(batch_stride) { }
  };

  /// Param structure
  struct Params {
    ElementVector *broadcast_ptr;    ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Method
    CUTLASS_HOST_DEVICE
    Params():
      broadcast_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      broadcast_ptr(args.broadcast_ptr),
      batch_stride(args.batch_stride) { }
  };

private:
  ElementVector *broadcast_ptr;
  BroadcastFragment broadcast_fragment;    ///< Array that holds the loaded broadcast fragment
  MatrixCoord threadblock_offset_;
  int thread_idx_;
  MatrixCoord problem_size;

  int thread_start_row_;
  int state_[3];
  int thread_offset_row_;

  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpColumnBroadcast(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    broadcast_ptr(params.broadcast_ptr),
    threadblock_offset_(threadblock_offset),
    thread_idx_(thread_idx),
    problem_size(problem_size),
    thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
    batch_stride_(params.batch_stride)
  {
    state_[0] = state_[1] = state_[2] = 0;
  }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    broadcast_ptr += batch_idx * batch_stride_;
  }

  CUTLASS_DEVICE
  void begin_epilogue() { }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {}

  CUTLASS_DEVICE
  void begin_row(int row_idx) {}

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // locate the scalar for this thread's row and replicate it across the fragment
    thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();

    ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_));

    broadcast_fragment.fill(broadcast_data);

    return broadcast_fragment;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    // advance the row/group/cluster state machine
    ++state_[0];

    thread_start_row_ += ThreadMap::Shape::kRow;
    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
        ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        thread_start_row_ += ThreadMap::Count::kGroup *
          ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
        }
      }
    }
  }

  CUTLASS_DEVICE
  void end_epilogue() { }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,342 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file

  \brief This file contains the epilogue visitor op with reduction over columns in the CTA
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue Visitor operator for the following computation:
///
///  ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j])
///  device memory <- ElementReduction(R[j])
///
template <
  typename ThreadblockShape_,              ///< Threadblock shape
  typename ElementAccumulator_,            ///< Data type of the Accumulator
  typename ElementReduction_,              ///< Data type of the output reduction in device memory
  typename ElementReductionAccumulator_ ,  ///< Data type to accumulate reduction in smem and register
  typename OutputTileIterator_,            ///< Tile Iterator type
  typename Visitor_                        ///< Preceding visitor op
>
class VisitorOpColumnReduction {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementReductionAccumulator = ElementReductionAccumulator_;
  using ElementReduction = ElementReduction_;
  using OutputTileIterator = OutputTileIterator_;
  using ThreadblockShape = ThreadblockShape_;
  using Visitor = Visitor_;

  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  // TODO: generalize the reduction op
  using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
  using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
  using ElementOutput = typename OutputTileIterator::Element;

  /// Fragment type returned from Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of reduction
  using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;

  /// Used for the reduction
  struct ReductionDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    /// Number of iterations (accesses) the threadblock takes to reduce a row
    static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount);

    using StorageShape = MatrixShape<
      kThreadRows,
      ThreadblockShape::kN
    >;
  };

  using ReductionFragment = Array<ElementReductionAccumulator, ReductionDetail::kColumnsPerThread>;

  /// Shared storage
  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;
    AlignedArray<ElementReductionAccumulator, ReductionDetail::StorageShape::kCount, 16> reduction;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementReduction *reduction_ptr;    ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Arguments visitor_arg;    ///< Argument type of visitor

    /// Method
    CUTLASS_HOST_DEVICE
    Arguments(): reduction_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementReduction *reduction_ptr,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      reduction_ptr(reduction_ptr),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Param structure
  struct Params {
    ElementReduction *reduction_ptr;    ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Params visitor_param;    ///< Param type of visitor

    /// Method
    CUTLASS_HOST_DEVICE
    Params(): reduction_ptr(nullptr), batch_stride(0) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      reduction_ptr(args.reduction_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  ElementReduction *reduction_output_ptr_;             ///< Pointer to the reduction tensor in device memory
  ElementReductionAccumulator *reduction_smem_ptr_;    ///< Pointer to the partial reductions in shared memory
  ReductionFragment reduction_fragment;                ///< Register fragments that hold the partial reduction
  Visitor visitor_;                                    ///< Preceding visitor
  int thread_idx_;
  MatrixCoord threadblock_offset;
  MatrixCoord problem_size_;
  int64_t batch_stride_;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpColumnReduction(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    visitor_(params.visitor_param, shared_storage.storage_visitor,
      thread_idx, threadblock_offset, problem_size),
    reduction_smem_ptr_(shared_storage.reduction.data()),
    reduction_output_ptr_(params.reduction_ptr),
    thread_idx_(thread_idx),
    threadblock_offset(threadblock_offset),
    problem_size_(problem_size),
    batch_stride_(params.batch_stride)
  { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    reduction_output_ptr_ += batch_idx * batch_stride_;
    visitor_.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_.begin_epilogue();

    // clear the reduction fragment
    reduction_fragment.clear();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    visitor_.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    /// Get result from visitor
    VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    // accumulate this fragment into the per-column partial reduction held in registers
    NumericArrayConverter<ElementReductionAccumulator, ElementVisitor, kElementsPerAccess> reduction_converter;
    ReductionOp reduction_op;
    ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast<ReductionAccumulatorAccessType*>(&reduction_fragment);
    reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result));

    return result;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_.end_epilogue();
    //
    // Store the partially reduced value to SMEM
    //
|
||||
// Guard against uses of the existing SMEM tile
|
||||
__syncthreads();
|
||||
|
||||
using AccessType = AlignedArray<ElementReductionAccumulator, ThreadMap::kElementsPerAccess>;
|
||||
|
||||
//
|
||||
// Determine a compact thread arrangement to store to SMEM
|
||||
//
|
||||
|
||||
MatrixCoord thread_offset(
|
||||
thread_idx_ / ReductionDetail::kThreadsPerRow,
|
||||
(thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess
|
||||
);
|
||||
|
||||
//
|
||||
// Each thread store its fragment to a SMEM
|
||||
//
|
||||
AccessType *aligned_reduction_ptr = reinterpret_cast<AccessType *>(
|
||||
&reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()]
|
||||
);
|
||||
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(
|
||||
&reduction_fragment
|
||||
);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
||||
int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess;
|
||||
|
||||
aligned_reduction_ptr[col_idx] = frag_ptr[column];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Now, threads are assigned several columns of the output. The fetch over all rows from
|
||||
// the compacted SMEM tile and perform a reduction.
|
||||
//
|
||||
|
||||
NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) {
|
||||
int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount;
|
||||
|
||||
ReductionOpScalar reduction_op;
|
||||
ElementReductionAccumulator reduction_element = ElementReductionAccumulator();
|
||||
|
||||
int output_column_idx = threadblock_offset.column() + column_idx;
|
||||
|
||||
if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ReductionDetail::kThreadRows; ++row) {
|
||||
if (row) {
|
||||
auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx];
|
||||
reduction_element = reduction_op(reduction_element, frag);
|
||||
}
|
||||
else {
|
||||
|
||||
reduction_element = reduction_smem_ptr_[column_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Store
|
||||
reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
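
// A minimal host-side reference model of what VisitorOpColumnReduction computes
// per output tile: R[j] = sum_i T[i][j]. This sketch is illustrative only (the
// function name and float types are placeholders, not part of this commit);
// the device code above performs the same reduction via register fragments, a
// compacted SMEM tile, and a final per-column pass.
inline void reference_column_reduction(float const *T, float *R, int M, int N) {
  for (int j = 0; j < N; ++j) {
    float sum = 0.0f;            // one accumulator per column
    for (int i = 0; i < M; ++i) {
      sum += T[i * N + j];       // walk down column j of the row-major tile
    }
    R[j] = sum;
  }
}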

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,266 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief A file that contains the epilogue visitor Op with Linear Combination
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementCompute alpha;
///  ElementCompute beta;
///  ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B))
///  Return C;
///
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementCompute_,      ///< Data type used to compute linear combination
  int kElementsPerAccess_,       ///< Number of elements computed per operation
  typename VisitorA_,            ///< Child node A
  typename VisitorB_             ///< Child node B
>
class VisitorOpLinearCombination {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = kElementsPerAccess_;

  using VisitorA = VisitorA_;
  using VisitorB = VisitorB_;

  /// Fragment type returned from VisitorA.visit
  using VisitAccessTypeA = typename VisitorA::VisitAccessType;
  using ElementA = typename VisitAccessTypeA::Element;

  /// Fragment type returned from VisitorB.visit
  using VisitAccessTypeB = typename VisitorB::VisitAccessType;
  using ElementB = typename VisitAccessTypeB::Element;

  /// Fragment type returned by this visitor
  using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Combination Op TODO: generalize this
  using CombinationOp = cutlass::plus<VisitAccessType>;

  static_assert(kElementsPerAccess == VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A");
  static_assert(kElementsPerAccess == VisitAccessTypeB::kElements, "kElementsPerAccess mismatches with Visitor B");

  /// SMEM buffer class required in the epilogue visitor
  struct SharedStorage {
    typename VisitorA::SharedStorage storage_a;
    typename VisitorB::SharedStorage storage_b;

    CUTLASS_HOST_DEVICE
    SharedStorage() {}
  };

  /// Host-constructable Arguments structure
  struct Arguments {
    ElementCompute alpha;                        ///< scales accumulators
    ElementCompute beta;                         ///< scales source tensor
    typename VisitorA::Arguments visitor_a_arg;  ///< Argument type for visitor_a
    typename VisitorB::Arguments visitor_b_arg;  ///< Argument type for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Arguments():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0))
    { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementCompute alpha,
      ElementCompute beta,
      typename VisitorA::Arguments visitor_a_arg,
      typename VisitorB::Arguments visitor_b_arg
    ):
      alpha(alpha),
      beta(beta),
      visitor_a_arg(visitor_a_arg),
      visitor_b_arg(visitor_b_arg)
    { }
  };

  /// Parameter structure
  struct Params {
    ElementCompute alpha;                       ///< scales accumulators
    ElementCompute beta;                        ///< scales source tensor
    typename VisitorA::Params visitor_a_param;  ///< Param type for visitor_a
    typename VisitorB::Params visitor_b_param;  ///< Param type for visitor_b

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      alpha(args.alpha),
      beta(args.beta),
      visitor_a_param(args.visitor_a_arg),
      visitor_b_param(args.visitor_b_arg)
    { }
  };

private:
  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

  VisitorA visitor_a_op;
  VisitorB visitor_b_op;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpLinearCombination(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    alpha_(params.alpha),
    beta_(params.beta),
    visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size),
    visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size)
  { }

  CUTLASS_DEVICE
  void begin_epilogue() {
    if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue();
    if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get results from visitor A and visitor B; a child whose scale factor is
    // zero is skipped entirely and contributes a zero fragment instead
    VisitAccessTypeA result_A;
    VisitAccessTypeB result_B;

    if (alpha_ != ElementCompute(0)) {
      result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    } else {
      // Fill result A with zeros
      result_A.clear();
    }

    if (beta_ != ElementCompute(0)) {
      result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
    } else {
      // Fill result B with zeros
      result_B.clear();
    }

    // Type conversion
    NumericArrayConverter<ElementCompute, ElementA, kElementsPerAccess> source_converter_A;
    NumericArrayConverter<ElementCompute, ElementB, kElementsPerAccess> source_converter_B;

    CombinationOp combination_op;

    cutlass::multiplies<VisitAccessType> multiply_op;

    return combination_op(
      multiply_op(alpha_, source_converter_A(result_A)),
      multiply_op(beta_, source_converter_B(result_B))
    );
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx);
    if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx);
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue();
    if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue();
  }
};
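
// A host-side sketch of this visitor's per-access computation, including the
// alpha == 0 / beta == 0 short-circuit used above. Illustrative only; the
// function name and float types are placeholders rather than CUTLASS API.
inline float reference_linear_combination(float alpha, float a, float beta, float b) {
  float result_a = (alpha != 0.0f) ? a : 0.0f;  // child A is skipped when alpha == 0
  float result_b = (beta != 0.0f) ? b : 0.0f;   // child B is skipped when beta == 0
  return alpha * result_a + beta * result_b;    // CombinationOp is cutlass::plus here
}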

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,258 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief A file that contains the epilogue visitor Op that broadcasts a vector to all rows
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementVector T[i][j] <- device-memory Td[j]
///
/// It can only be a leaf node in the epilogue tree
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename ElementFragment_,     ///< Data type used to cache the vector in register
  typename InputTileIterator_    ///< Tile iterator type to read the broadcast tensor
>
class VisitorOpRowBroadcast {
public:
  using InputTileIterator = InputTileIterator_;

  static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
  using ElementAccumulator = ElementAccumulator_;
  using ElementVector = typename InputTileIterator::Element;
  using ElementFragment = ElementFragment_;

  using VisitAccessType = Array<ElementFragment, kElementsPerAccess>;

  /// Thread map used by input tile iterators
  using ThreadMap = typename InputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<
    ElementFragment,
    ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Used for the broadcast
  struct BroadcastDetail {
    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    // /// Number of iterations (accesses) the threadblock takes to reduce a row
    // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
  };

  // using ComputeFragmentType = Array<ElementVector, BroadcastDetail::kElementsPerAccess>;

  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementVector *broadcast_ptr;  ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments():
      broadcast_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementVector *broadcast_ptr,
      int64_t batch_stride
    ):
      broadcast_ptr(broadcast_ptr),
      batch_stride(batch_stride) { }
  };

  /// Param structure
  struct Params {
    ElementVector *broadcast_ptr;  ///< Pointer to the additional tensor operand
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      broadcast_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      broadcast_ptr(args.broadcast_ptr),
      batch_stride(args.batch_stride) { }
  };

private:
  ElementVector *broadcast_ptr;
  BroadcastFragment broadcast_fragment;  ///< Array that holds the loaded broadcast fragment
  MatrixCoord threadblock_offset_;
  int thread_idx_;
  MatrixCoord problem_size;
  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpRowBroadcast(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()),
    threadblock_offset_(threadblock_offset),
    thread_idx_(thread_idx),
    problem_size(problem_size),
    batch_stride_(params.batch_stride) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    broadcast_ptr += batch_idx * batch_stride_;
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    // Load the broadcast fragment once per epilogue; visit() then serves
    // every row from this cached copy
    load_broadcast_fragment_();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {}

  CUTLASS_DEVICE
  void begin_row(int row_idx) {}

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    VisitAccessType* broadcast_fragment_ = reinterpret_cast<VisitAccessType*>(&broadcast_fragment);
    return broadcast_fragment_[column_idx];
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) { }

  CUTLASS_DEVICE
  void end_epilogue() { }

private:

  CUTLASS_DEVICE
  void load_broadcast_fragment_() {

    broadcast_fragment.clear();

    // If no pointer is supplied, set with all zeros and avoid memory accesses
    if (!broadcast_ptr) {
      return;
    }

    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset_.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementFragment, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using AccessFragmentType = Array<ElementFragment, BroadcastDetail::kElementsPerAccess>;

    AccessFragmentType *frag_ptr = reinterpret_cast<AccessFragmentType *>(&broadcast_fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      AccessFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }
};
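
// A host-side sketch of the row broadcast this visitor implements: every row
// of the tile reads the same vector, T[i][j] = v[j], and a null vector pointer
// yields zeros (mirroring load_broadcast_fragment_ above). Illustrative only;
// the name and float types are placeholders, not CUTLASS API.
inline void reference_row_broadcast(float const *v, float *T, int M, int N) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      T[i * N + j] = v ? v[j] : 0.0f;  // same vector element for every row i
    }
  }
}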

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,320 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief A file that contains the epilogue visitor Op with reduction over rows in a CTA
*/

#pragma once
#include "cutlass/cutlass.h"
#include "stdio.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementReductionAccumulator R[i] = \sum_j ElementReductionAccumulator(T[i][j])
///  device memory <- ElementReduction(R[i])
///
template <
  typename ThreadblockShape_,             ///< Threadblock shape
  typename ElementAccumulator_,           ///< Data type of the Accumulator
  typename ElementReduction_,             ///< Data type of the output reduction in device memory
  typename ElementReductionAccumulator_,  ///< Data type to accumulate reduction in smem and register
  typename OutputTileIterator_,           ///< Tile Iterator type
  typename Visitor_                       ///< Preceding visitor op
>
class VisitorOpRowReduction {
public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementReductionAccumulator = ElementReductionAccumulator_;
  using ElementReduction = ElementReduction_;
  using OutputTileIterator = OutputTileIterator_;
  using ThreadblockShape = ThreadblockShape_;
  using Visitor = Visitor_;

  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  // TODO: generalize the reduction op
  using ReductionOp = cutlass::plus<Array<ElementReductionAccumulator, kElementsPerAccess>>;
  using ReductionOpScalar = cutlass::plus<ElementReductionAccumulator>;
  using ElementOutput = typename OutputTileIterator::Element;

  /// Fragment type returned from Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of reduction
  using ReductionAccumulatorAccessType = Array<ElementReductionAccumulator, kElementsPerAccess>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;

  /// Used for the reduction
  struct ReductionDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = ThreadMap::kThreads;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread;

    /// Half the number of threads per row, used for the cross-thread butterfly reduction
    static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock
    static int const kThreadRows = kThreadCount / kThreadsPerRow;
  };

  /// Shared storage
  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementReduction *reduction_ptr;          ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Arguments visitor_arg;  ///< Argument type of visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): reduction_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementReduction *reduction_ptr,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      reduction_ptr(reduction_ptr),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Param structure
  struct Params {
    ElementReduction *reduction_ptr;         ///< Pointer to the reduction tensor in device memory
    int64_t batch_stride;
    typename Visitor::Params visitor_param;  ///< Param type of visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Params(): reduction_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      reduction_ptr(args.reduction_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  ElementReduction *reduction_output_ptr_;  ///< Pointer to the reduction tensor in device memory
  ElementReductionAccumulator reduction_accum;
  Visitor visitor_;                         ///< Visitor
  int thread_idx_;
  MatrixCoord threadblock_offset;
  MatrixCoord problem_size_;

  int thread_start_row_;   ///< Starting row index of this thread within the output tile
  int state_[3];           ///< Used to track the row iterator state across steps
  int thread_offset_row_;  ///< Row index of the access most recently visited
  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpRowReduction(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    visitor_(params.visitor_param, shared_storage.storage_visitor,
             thread_idx, threadblock_offset, problem_size),
    reduction_output_ptr_(params.reduction_ptr),
    thread_idx_(thread_idx),
    threadblock_offset(threadblock_offset),
    problem_size_(problem_size),
    thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()),
    batch_stride_(params.batch_stride)
  {
    state_[0] = state_[1] = state_[2] = 0;
  }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    reduction_output_ptr_ += batch_idx * batch_stride_;
    visitor_.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    visitor_.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_.begin_row(row_idx);

    reduction_accum = ElementReductionAccumulator(0);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get result from visitor
    VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row();

    ReductionOpScalar reduction_op;

    // Reduce the fragment held by this thread to a single scalar
    ElementReductionAccumulator reduction_accum_ = reduction(result);

    // After performing the in-thread reduction, perform the cross-thread / in-warp
    // reduction: a butterfly exchange over the threads covering this row halves
    // the exchange distance each step, so every participating lane ends up with
    // the full row sum
    CUTLASS_PRAGMA_UNROLL
    for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) {
      reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i));
    }
    reduction_accum = reduction_op(reduction_accum, reduction_accum_);

    return result;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_.end_row(row_idx);
    NumericConverter<ElementReduction, ElementReductionAccumulator> output_converter;

    // Only the first thread of each row group writes, and only for rows inside the problem size
    bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0);
    int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row();

    ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset;

    arch::global_store<ElementReduction, sizeof(ElementReduction)>(
      output_converter(reduction_accum),
      (void *)curr_ptr_reduction,
      is_write_thread);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_.end_step(step_idx);

    // Advance the row iterator state (mirrors the output tile iterator's operator++)
    ++state_[0];

    thread_start_row_ += ThreadMap::Shape::kRow;
    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
        ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        thread_start_row_ += ThreadMap::Count::kGroup *
          ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
        }
      }
    }
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_.end_epilogue();
  }

private:

  CUTLASS_DEVICE
  ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) {
    ElementReductionAccumulator sum_ = ElementReductionAccumulator(0);

    ReductionOpScalar reduction_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) {
      sum_ = reduction_op(sum_, result[i]);
    }

    return sum_;
  }
};
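
// A host-side sketch of the butterfly (XOR-shuffle) reduction used in visit()
// above, modeled with an array standing in for the lanes of one row group.
// Illustrative only; the name and float types are placeholders, not CUTLASS
// API, and threads_per_row is assumed to be a power of two, as in the kernel.
inline float reference_butterfly_row_reduction(float *lane, int threads_per_row) {
  float tmp[32];  // assumes threads_per_row <= 32 (one warp), as the shuffle does
  for (int i = threads_per_row / 2; i > 0; i >>= 1) {
    for (int t = 0; t < threads_per_row; ++t) {
      tmp[t] = lane[t];  // snapshot: lanes exchange pre-step values in lockstep
    }
    for (int t = 0; t < threads_per_row; ++t) {
      lane[t] = tmp[t] + tmp[t ^ i];  // __shfl_xor_sync(mask, v, i) on lane t reads lane t ^ i
    }
  }
  return lane[0];  // every lane now holds the full row sum
}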

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,188 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief A file that contains the epilogue visitor Op with Tensor Input
*/

#pragma once
#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementInput C <- device memory
///
/// It can only be a leaf node in the epilogue tree
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename InputTileIterator_   ///< Tile iterator type to read the tensor
>
class VisitorOpTensorInput {
public:
  using ElementAccumulator = ElementAccumulator_;
  using InputTileIterator = InputTileIterator_;

  static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess;
  using ElementInput = typename InputTileIterator::Element;

  using VisitAccessType = Array<ElementInput, kElementsPerAccess>;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  struct SharedStorage {
    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementInput *input_ptr;  ///< Pointer to the input tensor in device memory
    int ldt;                  ///< Leading dimension of the input tensor operand
    int64_t batch_stride;     ///< Batch stride for batched GEMM

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): input_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementInput *input_ptr,
      int ldt, int64_t batch_stride
    ):
      input_ptr(input_ptr),
      ldt(ldt),
      batch_stride(batch_stride)
    { }
  };

  /// Param structure
  struct Params {
    typename InputTileIterator::Params params_input;
    ElementInput *input_ptr;
    int64_t batch_stride;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      input_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      params_input(args.ldt),
      input_ptr(args.input_ptr),
      batch_stride(args.batch_stride)
    { }
  };

private:
  InputTileIterator iterator_T_;
  typename InputTileIterator::Fragment fragment_T_;
  MatrixCoord problem_size;
  int64_t batch_stride_;

public:
  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpTensorInput(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    iterator_T_(
      InputTileIterator(
        params.params_input,
        params.input_ptr,
        problem_size,
        thread_idx,
        threadblock_offset
      )
    ),
    problem_size(problem_size),
    batch_stride_(params.batch_stride) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
  }

  CUTLASS_DEVICE
  void begin_epilogue() { }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    // Load the tile for this step up front; visit() then indexes into it
    fragment_T_.clear();
    iterator_T_.load(fragment_T_);
    ++iterator_T_;
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) { }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    VisitAccessType source = reinterpret_cast<VisitAccessType *>(&fragment_T_)[frag_idx];
    return source;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) { }

  CUTLASS_DEVICE
  void end_step(int step_idx) { }

  CUTLASS_DEVICE
  void end_epilogue() { }
};
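
// A host-side sketch of this leaf's contract: begin_step() loads one tile's
// worth of data, and visit() serves each access out of that cached fragment.
// Illustrative only; the struct name, float types, and the flat 8-element
// fragment layout are placeholders rather than CUTLASS API.
struct ReferenceTensorInput {
  float const *tile_ptr;  // start of the current tile in memory
  float fragment[8];      // models InputTileIterator::Fragment (size is arbitrary)

  void begin_step() {
    for (int i = 0; i < 8; ++i) {
      fragment[i] = tile_ptr[i];  // load once per step, like iterator_T_.load()
    }
    tile_ptr += 8;                // advance, like ++iterator_T_
  }

  float visit(int frag_idx) {
    return fragment[frag_idx];    // the leaf just returns the cached source value
  }
};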

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,240 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief A file that contains the epilogue visitor Op with Tensor Output
*/

#pragma once
#include "cutlass/cutlass.h"
#include "stdio.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue Visitor operator for the following computation:
///
///  ElementOutput T = ElementOutput(Visitor)
///  T -> device memory
///
template <
  typename ElementAccumulator_,  ///< Data type of the Accumulator
  typename OutputTileIterator_,  ///< Tile iterator type to write the tensor
  typename Visitor_              ///< Child visitor that produces the output tensor
>
class VisitorOpTensorOutput {
public:
  using ElementAccumulator = ElementAccumulator_;
  using OutputTileIterator = OutputTileIterator_;

  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
  using ElementOutput = typename OutputTileIterator::Element;

  using Visitor = Visitor_;

  /// Fragment type returned from Visitor
  using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
  using ElementVisitor = typename VisitAccessTypeVisitor::Element;

  using VisitAccessType = VisitAccessTypeVisitor;

  /// Fragment type of accumulator
  using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;

  /// Fragment type of output
  using OutputAccessType = Array<ElementOutput, kElementsPerAccess>;

  static_assert(kElementsPerAccess == VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");

  struct SharedStorage {
    typename Visitor::SharedStorage storage_visitor;

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

  /// Host-constructable Argument structure
  struct Arguments {
    ElementOutput *output_ptr;                ///< Pointer to the output tensor in device memory
    int ldt;                                  ///< Leading dimension of the output tensor operand
    int64_t batch_stride;                     ///< Batch stride
    typename Visitor::Arguments visitor_arg;  ///< Argument type of visitor

    /// Methods
    CUTLASS_HOST_DEVICE
    Arguments(): output_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ElementOutput *output_ptr,
      int ldt,
      int64_t batch_stride,
      typename Visitor::Arguments visitor_arg
    ):
      output_ptr(output_ptr),
      ldt(ldt),
      batch_stride(batch_stride),
      visitor_arg(visitor_arg)
    { }
  };

  /// Param structure
  struct Params {
    typename OutputTileIterator::Params params_output;
    ElementOutput *output_ptr;
    int64_t batch_stride;
    typename Visitor::Params visitor_param;

    /// Methods
    CUTLASS_HOST_DEVICE
    Params():
      output_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      params_output(args.ldt),
      output_ptr(args.output_ptr),
      batch_stride(args.batch_stride),
      visitor_param(args.visitor_arg)
    { }
  };

private:
  OutputTileIterator iterator_T_;
  typename OutputTileIterator::Fragment fragment_T_;
  MatrixCoord problem_size;
  Visitor visitor_;
  int64_t batch_stride_;

public:

  /// Constructs the function object
  CUTLASS_HOST_DEVICE
  VisitorOpTensorOutput(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    MatrixCoord threadblock_offset,
    MatrixCoord problem_size
  ):
    visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size),
    iterator_T_(
      OutputTileIterator(
        params.params_output,
        params.output_ptr,
        problem_size,
        thread_idx,
        threadblock_offset
      )
    ),
    problem_size(problem_size),
    batch_stride_(params.batch_stride) { }

  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_T_.add_pointer_offset(batch_idx * batch_stride_);
    visitor_.set_batch_index(batch_idx);
  }

  CUTLASS_DEVICE
  void begin_epilogue() {
    visitor_.begin_epilogue();
  }

  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_T_.clear();
    visitor_.begin_step(step_idx);
  }

  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    visitor_.begin_row(row_idx);
  }

  CUTLASS_DEVICE
  VisitAccessType visit(
    int iter_idx,
    int row_idx,
    int column_idx,
    int frag_idx,
    AccumulatorAccessType const &accum
  ) {
    // Get result from visitor
    VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum);

    // Column guard: only accesses inside the problem size are converted and staged
    MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
    bool column_guard = (thread_offset_.column() < problem_size.column());

    if (column_guard) {
      NumericArrayConverter<ElementOutput, ElementVisitor, kElementsPerAccess> output_converter;
      OutputAccessType &output = reinterpret_cast<OutputAccessType *>(&fragment_T_)[frag_idx];
      output = output_converter(result);
    }

    return result;
  }

  CUTLASS_DEVICE
  void end_row(int row_idx) {
    visitor_.end_row(row_idx);
  }

  CUTLASS_DEVICE
  void end_step(int step_idx) {
    visitor_.end_step(step_idx);
    // The fragment staged during visit() is written back once per step
    iterator_T_.store(fragment_T_);
    ++iterator_T_;
  }

  CUTLASS_DEVICE
  void end_epilogue() {
    visitor_.end_epilogue();
  }
};
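
// A host-side sketch of the traversal order the epilogue applies to any
// visitor in this family, with VisitorOpTensorOutput typically at the root.
// Illustrative only: `VisitorLike` and the loop bounds are placeholders, and
// the real epilogue iterates vectorized accesses rather than scalars.
template <typename VisitorLike>
inline void reference_epilogue_traversal(VisitorLike &visitor, int steps, int rows, int columns) {
  visitor.begin_epilogue();
  for (int s = 0; s < steps; ++s) {  // one step per output tile fragment
    visitor.begin_step(s);
    for (int r = 0; r < rows; ++r) {
      visitor.begin_row(r);
      for (int c = 0; c < columns; ++c) {
        // visit() would transform one access here; the root stages the
        // converted result for the store that happens in end_step():
        // visitor.visit(iter_idx, r, c, frag_idx, accum);
      }
      visitor.end_row(r);
    }
    visitor.end_step(s);  // the root writes its staged fragment here
  }
  visitor.end_epilogue();
}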

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,226 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief A file that contains the epilogue visitor Op with a Unary operation
*/

#pragma once
#include "cutlass/cutlass.h"
#include "unary_ops.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Epilogue Visitor operator for the following computation:
|
||||
///
|
||||
/// ElementCompute alpha;
|
||||
/// ElementCompute beta;
|
||||
/// ElementCompute C = UnaryOp(ElementCompute(Visitor))
|
||||
/// Return C;
|
||||
///
|
||||
template <
|
||||
typename ElementAccumulator_, ///< Data type of the Accumulator
|
||||
typename ElementCompute_, ///< Data type used to compute linear combination
|
||||
int kElementsPerAccess_, ///< Number of elements computed per operation
|
||||
typename Visitor_, ///< Child node
|
||||
template<typename T, int N> typename UnaryOp_
|
||||
>
|
||||
class VisitorOpUnary{
|
||||
public:
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static int const kElementsPerAccess = kElementsPerAccess_;
|
||||
|
||||
using Visitor = Visitor_;
|
||||
|
||||
/// Fragment type returned from Visitor.visit
|
||||
using VisitAccessTypeVisitor = typename Visitor::VisitAccessType;
|
||||
using ElementVisit = typename VisitAccessTypeVisitor::Element;
|
||||
|
||||
/// Fragment type returned by this visitor
|
||||
using VisitAccessType = Array<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
/// Fragment type of accumulator
|
||||
using AccumulatorAccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
|
||||
/// Unary Op applied to the child visitor's output. TODO: generalize this
|
||||
using UnaryOp = UnaryOp_<ElementCompute, kElementsPerAccess>;
|
||||
|
||||
static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor");
|
||||
|
||||
/// SMEM buffer class required in the epilogue visitor
|
||||
struct SharedStorage {
|
||||
typename Visitor::SharedStorage storage_visitor;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() {}
|
||||
};
|
||||
|
||||
|
||||
/// Host-constructable Arguments structure
|
||||
struct Arguments {
|
||||
typename UnaryOp::Arguments unary_arg;
|
||||
typename Visitor::Arguments visitor_arg; ///< Argument type for visitor
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():unary_arg() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
typename UnaryOp::Arguments unary_arg,
|
||||
typename Visitor::Arguments visitor_arg
|
||||
):
|
||||
unary_arg(unary_arg),
|
||||
visitor_arg(visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
/// Parameter structure
|
||||
struct Params {
|
||||
typename UnaryOp::Params unary_param;
|
||||
typename Visitor::Params visitor_param; ///< Argument type for visitor
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():unary_param() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
unary_param(args.unary_arg),
|
||||
visitor_param(args.visitor_arg)
|
||||
{ }
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
UnaryOp unary_op;
|
||||
|
||||
Visitor visitor_op;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorOpUnary(
|
||||
Params const &params,
|
||||
SharedStorage &shared_storage,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset,
|
||||
MatrixCoord problem_size
|
||||
):
|
||||
unary_op(params.unary_param),
|
||||
visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size)
|
||||
{ }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
visitor_op.set_batch_index(batch_idx);
|
||||
}
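// Note: every callback below is gated on unary_op.guard(); when the guard is
// inactive the child visitor is skipped entirely and visit() emits a cleared
// fragment instead.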
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (unary_op.guard()) visitor_op.begin_epilogue();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
if (unary_op.guard()) visitor_op.begin_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
if (unary_op.guard()) visitor_op.begin_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
VisitAccessType visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorAccessType const &accum
|
||||
) {
|
||||
/// Get the result from the child visitor
|
||||
VisitAccessTypeVisitor result;
|
||||
|
||||
if (unary_op.guard()){
|
||||
result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum);
|
||||
} else {
|
||||
result.clear();
|
||||
}
|
||||
|
||||
/// Type conversion
|
||||
NumericArrayConverter<ElementCompute, ElementVisit, kElementsPerAccess> source_converter;
|
||||
|
||||
|
||||
|
||||
return unary_op(source_converter(result));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
if (unary_op.guard()) visitor_op.end_row(row_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
if (unary_op.guard()) visitor_op.end_step(step_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
if (unary_op.guard()) visitor_op.end_epilogue();
|
||||
}
|
||||
};
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,481 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A file contains all functioning classes needed by GemmLayernorm.
|
||||
|
||||
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
|
||||
+ lightweight full reduction kernel (ApplyFinalReduction)
|
||||
+ GEMM1 with elementwise operations fused in the mainloop (GemmLayernormMainloopFusion)
|
||||
|
||||
*/
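
// Numerical note: Shift-K relies on the shifted-data identity
//
//   Var(x) = E[(x - K)^2] - (E[x - K])^2
//
// which holds for any constant shift K. Choosing K near the mean keeps the
// squared terms small and avoids the catastrophic cancellation that the plain
// E[x^2] - (E[x])^2 formulation can suffer in low-precision accumulation.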
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ThreadblockShape_,
|
||||
int ThreadCount,
|
||||
typename OutputTileIterator_,
|
||||
typename AccumulatorTile_,
|
||||
typename ElementAccumulator_,
|
||||
typename ElementVariance_,
|
||||
typename ElementMean_,
|
||||
typename ElementLayernormCompute_,
|
||||
typename ElementwiseFunctor_,
|
||||
bool IsShiftedVariance_ = false
|
||||
>
|
||||
class EpilogueVisitorLayerNorm {
|
||||
public:
|
||||
|
||||
using ElementVariance = ElementVariance_;
|
||||
using ElementMean = ElementMean_;
|
||||
using ElementLayernormCompute = ElementLayernormCompute_;
|
||||
|
||||
using AccumulatorTile = AccumulatorTile_;
|
||||
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow;
|
||||
|
||||
static int const kThreads = OutputTileIterator::ThreadMap::kThreads;
|
||||
|
||||
static bool const kIsShiftedVariance = IsShiftedVariance_;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow;
|
||||
|
||||
/// Array type used in Shift-K Layernorm
|
||||
static int const kRowAccessCount = kIterations * kRowIterations;
|
||||
|
||||
using ConvertedShiftFragment = Array<ElementLayernormCompute, kRowAccessCount>;
|
||||
|
||||
// Column-major output is handled by manually transposing the problem externally (already supported)
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using LayernormFragment = Array<ElementLayernormCompute, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static int const kThreadsInColumn = kThreads / kThreadsPerRow;
|
||||
static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
ElementVariance *ptr_Variance;
|
||||
ElementMean *ptr_Mean;
|
||||
ElementOutput *ptr_Shifted_K;
|
||||
MatrixCoord extent;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments():
|
||||
ptr_Variance(nullptr),
|
||||
ptr_Mean(nullptr),
|
||||
ptr_Shifted_K(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
Arguments(
|
||||
typename ElementwiseFunctor::Params elementwise_,
|
||||
ElementVariance *ptr_Variance,
|
||||
ElementMean *ptr_Mean_,
|
||||
ElementOutput *ptr_Shifted_K_ = nullptr,
|
||||
MatrixCoord extent = MatrixCoord(0, 0)
|
||||
):
|
||||
elementwise(elementwise_),
|
||||
ptr_Variance(ptr_Variance),
|
||||
ptr_Mean(ptr_Mean_),
|
||||
ptr_Shifted_K(ptr_Shifted_K_),
|
||||
extent(extent)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
struct Params {
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
ElementVariance *ptr_Variance;
|
||||
ElementMean *ptr_Mean;
|
||||
ElementOutput *ptr_Shifted_K;
|
||||
MatrixCoord extent;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
ptr_Variance(nullptr),
|
||||
ptr_Mean(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args):
|
||||
elementwise(args.elementwise),
|
||||
ptr_Variance(args.ptr_Variance),
|
||||
ptr_Mean(args.ptr_Mean),
|
||||
ptr_Shifted_K(args.ptr_Shifted_K),
|
||||
extent(args.extent)
|
||||
{
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const & params_;
|
||||
SharedStorage & shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator alpha_;
|
||||
ElementAccumulator beta_;
|
||||
ConvertedShiftFragment shift_k_frag_;
|
||||
|
||||
ElementLayernormCompute accum_sum_square_;
|
||||
ElementLayernormCompute accum_sum_element_;
|
||||
int thread_idx_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
gemm::GemmCoord threadblock_tile_offset_;
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorLayerNorm(
|
||||
Params const &params,                 ///< Parameters routed to the epilogue
|
||||
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
|
||||
MatrixCoord threadblock_offset,
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
OutputTileIterator source_iterator           ///< Tile iterator for the source (C) operand
|
||||
):
|
||||
params_(params),
|
||||
shared_storage_(shared_storage),
|
||||
elementwise_(params.elementwise),
|
||||
extent_(params.extent),
|
||||
iterator_C_(source_iterator),
|
||||
iterator_D_(destination_iterator),
|
||||
threadblock_tile_offset_(threadblock_tile_offset),
|
||||
thread_idx_(thread_idx)
|
||||
{
|
||||
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
|
||||
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
|
||||
|
||||
if (beta_ == ElementAccumulator()) {
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
|
||||
// If shift-K feature is enabled, we load shift-k fragment
|
||||
// at the very beginning of an epilogue
|
||||
if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) {
|
||||
shift_k_frag_.clear();
|
||||
int thread_offset_row_base = iterator_D_.thread_start_row();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) {
|
||||
int step_offset = iter_idx * OutputTileIterator::Shape::kRow;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int rid = 0; rid < kRowIterations; ++rid) {
|
||||
int row_step_offset = rid * kDeltaRow;
|
||||
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
|
||||
bool is_load = (row_offset < extent_.row());
|
||||
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_D_.clear();
|
||||
|
||||
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
||||
fragment_C_.clear();
|
||||
iterator_C_.load(fragment_C_);
|
||||
++iterator_C_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
/// set the accumulator to 0
|
||||
accum_sum_element_ = ElementLayernormCompute(0);
|
||||
accum_sum_square_ = ElementLayernormCompute(0);
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(
|
||||
int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorFragment const &accum) {
|
||||
|
||||
using Mul = cutlass::multiplies<ElementLayernormCompute>;
|
||||
using Minus = cutlass::minus<ElementLayernormCompute>;
|
||||
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
|
||||
|
||||
Minus minus;
|
||||
Mul mul;
|
||||
Exp exponential;
|
||||
|
||||
LayernormFragment result;
|
||||
|
||||
thread_offset_ =
|
||||
iterator_D_.thread_start() +
|
||||
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
|
||||
|
||||
NumericArrayConverter<ElementLayernormCompute, ElementOutput, kElementsPerAccess> source_converter;
|
||||
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
|
||||
|
||||
bool column_guard = (thread_offset_.column() < extent_.column());
|
||||
|
||||
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
||||
result = source_converter(elementwise_(accum));
|
||||
} else {
|
||||
result = source_converter(elementwise_(accum, source_vector));
|
||||
}
|
||||
|
||||
|
||||
ElementLayernormCompute inv_scalar = cutlass::constants::one<ElementLayernormCompute>() / ElementLayernormCompute(extent_.column());
|
||||
|
||||
// Fragment is cleared for non-reachable columns so no need to check against column guard
|
||||
ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result);
|
||||
|
||||
// The square sum is different: out-of-range columns must be excluded via the
// column guard, otherwise under shift-K each would contribute an extra
// (0 - K)^2 = K^2 term to the square sum.
|
||||
ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0);
|
||||
|
||||
if (column_guard) {
|
||||
accum_sum_square_tmp = (kIsShiftedVariance) ? \
|
||||
square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \
|
||||
square_sum_accumulator_(result);
|
||||
}
|
||||
|
||||
accum_sum_element_tmp *= inv_scalar;
|
||||
accum_sum_square_tmp *= inv_scalar;
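// Pre-scaling by 1 / N means the per-row values stored in end_row() are
// already partial mean and mean-of-squares, so the follow-up full-reduction
// kernel only has to combine them across threadblocks.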
|
||||
|
||||
// After performing the in-thread reduction, we then perform cross-thread / in-warp reduction
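// This is a butterfly reduction: XOR-ing lane indices with decreasing powers
// of two exchanges partial sums between lanes covering the same logical row,
// so after log2(kThreadsPerRow) steps each participating lane holds the
// complete partial sum for its row.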
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) {
|
||||
accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i);
|
||||
accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i);
|
||||
}
|
||||
accum_sum_element_ += accum_sum_element_tmp;
|
||||
accum_sum_square_ += accum_sum_square_tmp;
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementLayernormCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {
|
||||
|
||||
using ConvertVarianceOutput = cutlass::NumericConverter<ElementVariance, ElementLayernormCompute>;
|
||||
using ConvertMeanOutput = cutlass::NumericConverter<ElementMean, ElementLayernormCompute>;
|
||||
|
||||
ConvertVarianceOutput convert_variance_output;
|
||||
ConvertMeanOutput convert_mean_output;
|
||||
|
||||
bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0);
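// (Only the first lane of each thread row writes the per-row statistics, and
// only for rows inside the valid extent.)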
|
||||
int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row();
|
||||
|
||||
ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset;
|
||||
ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset;
|
||||
|
||||
arch::global_store<ElementVariance, sizeof(ElementVariance)>(
|
||||
convert_variance_output(accum_sum_square_),
|
||||
(void *)curr_ptr_sum_square,
|
||||
is_write_thread);
|
||||
|
||||
arch::global_store<ElementMean, sizeof(ElementMean)>(
|
||||
convert_mean_output(accum_sum_element_),
|
||||
(void *)curr_ptr_element_sum,
|
||||
is_write_thread);
|
||||
}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
|
||||
using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
|
||||
ConvertShiftK convert_shift_k;
|
||||
ElementOutput shift_k_val;
|
||||
|
||||
// Computes the address to load shift_k element
|
||||
ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset;
|
||||
// Conditionally loads from global memory
|
||||
arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
|
||||
// Converts data type to return
|
||||
ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);
|
||||
|
||||
return converted_shift_k_val;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) {
|
||||
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < LayernormFragment::kElements; ++i) {
|
||||
auto accum_ = accum[i];
|
||||
sum_ += accum_ * accum_;
|
||||
}
|
||||
|
||||
return sum_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) {
|
||||
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < LayernormFragment::kElements; ++i) {
|
||||
auto accum_ = accum[i] - shift_k_val;
|
||||
sum_ += accum_ * accum_;
|
||||
}
|
||||
|
||||
return sum_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) {
|
||||
ElementLayernormCompute sum_ = ElementLayernormCompute(0);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < LayernormFragment::kElements; ++i) {
|
||||
sum_ += accum[i];
|
||||
}
|
||||
|
||||
return sum_;
|
||||
}
|
||||
|
||||
};
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,692 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Universal GEMM kernel whose epilogue is customized through an epilogue visitor.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmUniversalwithEpilogueVisitor {
|
||||
public:
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment = const_max(
|
||||
128 / sizeof_bits<ElementA>::value,
|
||||
128 / sizeof_bits<ElementB>::value
|
||||
);
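// e.g. for half_t operands (16-bit), 128 / 16 = 8 elements, so split-K
// boundaries stay 128b-aligned for both A and B.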
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
void const * ptr_A;
|
||||
void const * ptr_B;
|
||||
void const * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename LayoutA::Stride stride_a;
|
||||
typename LayoutB::Stride stride_b;
|
||||
typename LayoutC::Stride stride_c;
|
||||
typename LayoutC::Stride stride_d;
|
||||
|
||||
typename LayoutA::Stride::LongIndex lda;
|
||||
typename LayoutB::Stride::LongIndex ldb;
|
||||
typename LayoutC::Stride::LongIndex ldc;
|
||||
typename LayoutC::Stride::LongIndex ldd;
|
||||
|
||||
int const * ptr_gather_A_indices;
|
||||
int const * ptr_gather_B_indices;
|
||||
int const * ptr_scatter_D_indices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr) {}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
typename LayoutA::Stride stride_a,
|
||||
typename LayoutB::Stride stride_b,
|
||||
typename LayoutC::Stride stride_c,
|
||||
typename LayoutC::Stride stride_d,
|
||||
int const *ptr_gather_A_indices = nullptr,
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
lda = 0;
|
||||
ldb = 0;
|
||||
ldc = 0;
|
||||
ldd = 0;
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
void const * ptr_C,
|
||||
void * ptr_D,
|
||||
int64_t batch_stride_A,
|
||||
int64_t batch_stride_B,
|
||||
int64_t batch_stride_C,
|
||||
int64_t batch_stride_D,
|
||||
typename LayoutA::Stride::LongIndex lda,
|
||||
typename LayoutB::Stride::LongIndex ldb,
|
||||
typename LayoutC::Stride::LongIndex ldc,
|
||||
typename LayoutC::Stride::LongIndex ldd,
|
||||
int const *ptr_gather_A_indices = nullptr,
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
stride_a = make_Coord(lda);
|
||||
stride_b = make_Coord(ldb);
|
||||
stride_c = make_Coord(ldc);
|
||||
stride_d = make_Coord(ldd);
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// Returns arguments for the transposed problem
|
||||
Arguments transposed_problem() const {
|
||||
Arguments args(*this);
|
||||
|
||||
std::swap(args.problem_size.m(), args.problem_size.n());
|
||||
std::swap(args.ptr_A, args.ptr_B);
|
||||
std::swap(args.lda, args.ldb);
|
||||
std::swap(args.stride_a, args.stride_b);
|
||||
std::swap(args.batch_stride_A, args.batch_stride_B);
|
||||
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
|
||||
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
int * ptr_gather_A_indices;
|
||||
int * ptr_gather_B_indices;
|
||||
int * ptr_scatter_D_indices;
|
||||
|
||||
int *semaphore;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_D(0),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr),
|
||||
semaphore(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
int gemm_k_size,
|
||||
void *workspace = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
|
||||
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
|
||||
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
|
||||
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
|
||||
epilogue_visitor(args.epilogue_visitor),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(gemm_k_size),
|
||||
ptr_A(const_cast<void *>(args.ptr_A)),
|
||||
ptr_B(const_cast<void *>(args.ptr_B)),
|
||||
ptr_C(const_cast<void *>(args.ptr_C)),
|
||||
ptr_D(args.ptr_D),
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
|
||||
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
|
||||
semaphore(static_cast<int *>(workspace)) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr) {
|
||||
|
||||
ptr_A = const_cast<void *>(args.ptr_A);
|
||||
ptr_B = const_cast<void *>(args.ptr_B);
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
epilogue_visitor = args.epilogue_visitor;
|
||||
|
||||
semaphore = static_cast<int *>(workspace);
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
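/// (The mainloop, epilogue, and visitor stages never use their shared-memory
/// staging areas at the same time, so a union lets the allocations alias.)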
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemmUniversalwithEpilogueVisitor() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutA,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutB,
|
||||
layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<LayoutC,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned) {
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const &args) {
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const &params, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
||||
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
|
||||
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm ||
|
||||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{
|
||||
offset_k,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A,
|
||||
ptr_A,
|
||||
{params.problem_size.m(), problem_size_k},
|
||||
thread_idx,
|
||||
tb_offset_A,
|
||||
params.ptr_gather_A_indices);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B,
|
||||
ptr_B,
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B,
|
||||
params.ptr_gather_B_indices);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute the number of mainloop iterations along K (ceiling division covers a partial final tile)
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(
|
||||
gemm_k_iterations,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
// EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
|
||||
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
||||
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// if (params.mode == GemmUniversalMode::kGemm) {
|
||||
|
||||
// // TODO: fix this order
|
||||
// // If performing a reduction via split-K, fetch the initial synchronization
|
||||
// if (params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// // Fetch the synchronization lock initially but do not block.
|
||||
// semaphore.fetch();
|
||||
|
||||
// // Indicate which position in a serial reduction the output operator is currently updating
|
||||
// output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
// }
|
||||
// }
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
|
||||
EpilogueVisitor epilogue_visitor(
|
||||
params.epilogue_visitor,
|
||||
shared_storage.visitor,
|
||||
threadblock_offset,
|
||||
threadblock_tile_offset,
|
||||
thread_idx,
|
||||
params.problem_size.mn()
|
||||
);
|
||||
|
||||
// if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
// ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
|
||||
// }
|
||||
if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
// TODO: ???
|
||||
// if (threadblock_tile_offset.k()) {
|
||||
// iterator_C = iterator_D;
|
||||
// }
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -50,7 +50,13 @@ void bind_tensor_coord(py::module &m) {
|
||||
R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc")
|
||||
.def(py::init<int, int, int, int>(),
|
||||
py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"),
|
||||
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc");
|
||||
R"pbdoc(Helper to construct from N, H, W, and C)pbdoc")
|
||||
.def("at", py::overload_cast<int>(&cutlass::Tensor4DCoord::at),
|
||||
py::arg("dim"),
|
||||
R"pbdoc(Gets the index of a given Coord element)pbdoc")
|
||||
.def("size", [](const cutlass::Tensor4DCoord & coord) {
|
||||
return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);},
|
||||
R"pbdoc(The size of the tensor coord)pbdoc");
|
||||
|
||||
py::class_<cutlass::Coord<3>>(m, "Tensor3DCoord",
|
||||
R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc")
|
||||
|
||||
@@ -1,7 +1,24 @@
|
||||
from pycutlass.type import *
|
||||
import re
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
|
||||
text = template
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for key, value in values.items():
|
||||
regex = "\\$\\{%s\\}" % key
|
||||
newtext = re.sub(regex, value, text)
|
||||
if newtext != text:
|
||||
changed = True
|
||||
text = newtext
|
||||
return text
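# Example (hypothetical values):
#   SubstituteTemplate('#include "${include}"', {"include": "cutlass/cutlass.h"})
#   -> '#include "cutlass/cutlass.h"'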
|
||||
|
||||
from pycutlass.type_hint import *
|
||||
from pycutlass.tensor_ref import *
|
||||
from pycutlass.operation import *
|
||||
from pycutlass.epilogue import *
|
||||
from pycutlass.parser import *
|
||||
from pycutlass.compiler import ArtifactManager
|
||||
from pycutlass.memory_manager import *
|
||||
from pycutlass.arguments import *
|
||||
|
||||
@@ -60,6 +60,13 @@ class ArgumentBase:
|
||||
C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
|
||||
D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]',
|
||||
**kwargs) -> None:
|
||||
|
||||
# tensor_C is interpreted as a bias vector when bias=True is passed as a keyword argument
|
||||
if "bias" in kwargs.keys():
|
||||
self.bias = kwargs["bias"]
|
||||
else:
|
||||
# by default, tensor_C is not bias
|
||||
self.bias = False
|
||||
|
||||
# preprocessing input tensors
|
||||
if isinstance(A, np.ndarray):
|
||||
@@ -72,21 +79,28 @@ class ArgumentBase:
|
||||
self.ptr_B = self.buffer_B.ptr
|
||||
self.ptr_C = self.buffer_C.ptr
|
||||
self.ptr_D = self.buffer_D.ptr
|
||||
# number of elements in C
|
||||
self.tensor_c_numel = C.size
|
||||
elif torch_available and isinstance(A, torch.Tensor):
|
||||
self.ptr_A = TorchFrontend.argument(A)
|
||||
self.ptr_B = TorchFrontend.argument(B)
|
||||
self.ptr_C = TorchFrontend.argument(C)
|
||||
self.ptr_D = TorchFrontend.argument(D)
|
||||
# number of elements in C
|
||||
self.tensor_c_numel = C.numel()
|
||||
elif isinstance(A, cuda.CUdeviceptr):
|
||||
self.ptr_A = A
|
||||
self.ptr_B = B
|
||||
self.ptr_C = C
|
||||
self.ptr_D = D
|
||||
|
||||
elif cupy_available and isinstance(A, cp.ndarray):
|
||||
self.ptr_A = CupyFrontend.argument(A)
|
||||
self.ptr_B = CupyFrontend.argument(B)
|
||||
self.ptr_C = CupyFrontend.argument(C)
|
||||
self.ptr_D = CupyFrontend.argument(D)
|
||||
# number of elements in C
|
||||
self.tensor_c_numel = C.size
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported Frontend. Only support numpy and torch")
|
||||
|
||||
@@ -63,22 +63,9 @@ dtype2ctype = {
|
||||
}
|
||||
|
||||
|
||||
def get_epilogue_output_op(element_compute_):
|
||||
element_compute = dtype2ctype[element_compute_]
|
||||
def get_gemm_arguments(epilogue_functor):
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", element_compute),
|
||||
("beta", element_compute),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p)
|
||||
]
|
||||
return _EpilogueOutputOpParams
|
||||
|
||||
|
||||
def get_gemm_arguments(element_compute_):
|
||||
|
||||
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _GemmArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
@@ -116,8 +103,8 @@ def get_gemm_arguments(element_compute_):
|
||||
|
||||
# include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
|
||||
def get_gemm_grouped_arguments(element_compute_):
|
||||
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
|
||||
def get_gemm_grouped_arguments(epilogue_functor):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _GEMMGroupedArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
@@ -214,8 +201,8 @@ class TensorRef2D_(ctypes.Structure):
|
||||
# include/cutlass/conv/kernel/implicit_gemm_convolution.h
|
||||
# split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4
|
||||
|
||||
def get_conv2d_arguments(element_compute_):
|
||||
_EpilogueOutputOpParams = get_epilogue_output_op(element_compute_)
|
||||
def get_conv2d_arguments(epilogue_functor):
|
||||
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _Conv2dArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
@@ -236,8 +223,8 @@ def get_conv2d_arguments(element_compute_):
|
||||
############################################################################################
|
||||
|
||||
|
||||
def get_reduction_params(element_compute_):
|
||||
_EpilogueOutputParams = get_epilogue_output_op(element_compute_)
|
||||
def get_reduction_params(epilogue_functor):
|
||||
_EpilogueOutputParams = epilogue_functor.epilogue_type
|
||||
|
||||
class _ReductionParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
|
||||
@@ -1,366 +0,0 @@
|
||||
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from pycutlass import *
from pycutlass.library import SubstituteTemplate
import cutlass
from cuda import cuda
from cuda import nvrtc
import tempfile
import os
import ctypes

#
import json
import sqlite3


IncludeTemplate = r'''#include "${include}"
'''


#
class CompilationOptions:
    '''
    Compilation options.
    '''

    #
    def __init__(self, architectures=[80], include_paths=[]):
        self.includes = []
        self.include_paths = include_paths
        self.flags = ['-std=c++11', '-default-device']
        self.architectures = architectures

    #
    def get(self):
        options = []

        for flag in self.flags:
            options.append(bytes(str.encode(flag)))

        for incl in self.include_paths:
            options.append(bytes(str.encode('--include-path=%s' % incl)))

        arch_list = "-arch="
        for idx, arch in enumerate(self.architectures):
            if idx:
                arch_list += ","
            arch_list += "sm_%d" % arch

        options.append(bytes(str.encode(arch_list)))

        return options
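
For reference, a minimal sketch of what get() produces: every option is encoded to bytes because nvrtcCompileProgram expects C strings. The include path below is a hypothetical placeholder.

opts = CompilationOptions(architectures=[70, 80],
                          include_paths=['/opt/cutlass/include'])  # hypothetical path
print(opts.get())
# [b'-std=c++11', b'-default-device',
#  b'--include-path=/opt/cutlass/include', b'-arch=sm_70,sm_80']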

def convertToBinaryData(filename):
    with open(filename, 'rb') as file:
        blobData = file.read()
    return blobData


def CDLLBin(host_binary):
    tempfile.tempdir = "./"
    temp_so = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True)
    with open(temp_so.name, 'wb') as file:
        file.write(host_binary)
    host_lib = ctypes.CDLL(temp_so.name)
    return host_lib


class ArtifactManager:
    """
    Artifact manager: compiles operations and caches the resulting device
    cubins and host libraries, both in memory and in a sqlite database on disk.
    """
    def __init__(self) -> None:
        try:
            connection = sqlite3.connect("./compiled_cache.db")
            cursor = connection.cursor()
            sqlite_create_table_query = """CREATE TABLE compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)"""
            cursor.execute(sqlite_create_table_query)
            connection.commit()
            cursor.close()
        except:
            # the table already exists from a previous run; nothing to do
            pass

        self.compiled_cache_device = cutlass.CompileCache()
        self.compiled_cache_host = cutlass.CompileCache()

    def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
        connection = sqlite3.connect("./compiled_cache.db")
        cursor = connection.cursor()
        sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)"""

        hostbin = convertToBinaryData(hostfile)

        data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs))

        cursor.execute(sqlite_insert_blob_query, data_tuple)
        connection.commit()
        cursor.close()

    def load_operation(self, op_key):
        connection = sqlite3.connect("./compiled_cache.db")
        cursor = connection.cursor()
        sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?"""
        cursor.execute(sqlite_fetch_blob_query, (op_key, ))
        record = cursor.fetchall()
        if len(record) == 0:
            return False
        for row in record:
            key, cubin_image, host_binary, operation_name, op_attr = row
            op_attr = json.loads(op_attr)
            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name)))
            self.compiled_cache_device.insert(key, kernel)

            compiled_host_fns = {}
            host_lib = CDLLBin(host_binary)

            func_name = operation_name + '_get_params'
            func = getattr(host_lib, func_name)
            func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0])
            compiled_host_fns['get_args'] = func

            func_name = operation_name + '_shared_memory_size'
            func = getattr(host_lib, func_name)
            compiled_host_fns['shared_memory_capacity'] = func()

            for attr in op_attr:
                if isinstance(attr, str):
                    func_name = operation_name + '_' + attr
                    func = getattr(host_lib, func_name)
                    compiled_host_fns[attr] = func

            self.compiled_cache_host.insert(key, compiled_host_fns)
        return True
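
The lookup order used by add_module below mirrors this method: consult the in-memory cache first, then fall back to the sqlite database, which repopulates the in-memory caches on a hit. A minimal sketch of that pattern, using only the names defined in this file:

def lookup(manager, key):
    # in-memory first
    kernel = manager.compiled_cache_device.at(key)
    if kernel is None and manager.load_operation(key):
        # disk hit: load_operation repopulated the in-memory caches
        kernel = manager.compiled_cache_device.at(key)
    return kernel  # None means a fresh compile is needed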

    def emit_compile_(self, operation_list, compilation_options):
        """
        Compile a list of kernels; returns the device cubin and host library
        that the caller stores into the database
        """
        source_buffer_device = ""
        source_buffer_host = ""
        # 1. include
        includes = []
        for operation in operation_list:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        includes_host = [
            "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
        for incl in includes:
            source_buffer_device += SubstituteTemplate(IncludeTemplate, {'include': incl})

        for incl in includes_host:
            if "/device/" not in incl:
                source_buffer_host += SubstituteTemplate(IncludeTemplate, {'include': incl})

        # 2. Operations
        for operation in operation_list:
            source_buffer_device += operation.emit()
            source_buffer_host += operation.emit()
            values = {
                'operation_name': operation.name(),
                'operation_suffix': operation.emitter.operation_suffix
            }
            source_buffer_device += SubstituteTemplate(operation.KernelTemplate, values)
            source_buffer_host += SubstituteTemplate(operation.HostTemplate, values)

        # 3. compile
        err, program = nvrtc.nvrtcCreateProgram(
            str.encode(source_buffer_device),
            bytes(str.encode("module.cu")),
            0, [], [])

        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        # Compile program
        options = compilation_options.get()

        err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:

            error_string = 'NVRTC Error: {}\n'.format(err)

            # Get log from compilation
            err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            log = b' ' * logSize
            err, = nvrtc.nvrtcGetProgramLog(program, log)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            raise RuntimeError(error_string + log.decode() + source_buffer_device)

        # Get data from compilation
        err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        cubin_image = b' ' * dataSize
        err, = nvrtc.nvrtcGetCUBIN(program, cubin_image)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        # compile the host code
        options = compilation_options.get()
        cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host
        for opt in options:
            opt = opt.decode("utf-8")
            if opt not in ['-default-device', '-std=c++11', '-arch=sm_80']:
                if '--include-path=' in opt:
                    cmd += " " + opt.replace('--include-path=', '-I')
                else:
                    cmd += " " + opt

        tempfile.tempdir = "./"
        temp = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True)

        cmd += ' - -shared -o %s' % temp.name
        os.system(cmd)
        host_lib = ctypes.CDLL(temp.name)

        return cubin_image, host_lib, temp
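
The host path above pipes the generated C++ through g++ on stdin: the lone '-' in the command tells g++ to read the piped source, and '-shared -fPIC' yields a shared object that ctypes can load. A self-contained demonstration of the same trick, with a tiny hypothetical function standing in for the generated CUTLASS host code (requires g++ on PATH):

import ctypes, os, tempfile

src = 'extern "C" int answer() { return 42; }'
tempfile.tempdir = "./"
so = tempfile.NamedTemporaryFile(prefix='demo', suffix='.so', delete=False)
# '-' reads the piped source from stdin; '-shared -fPIC' emits a .so
os.system("echo '%s' | g++ -x c++ -fPIC - -shared -o %s" % (src, so.name))
lib = ctypes.CDLL(so.name)
assert lib.answer() == 42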

    def add_module(self, operations, compile_options=None):
        """
        Compile and insert a new device module, reusing cached artifacts
        whenever possible
        """
        if compile_options is None:
            cutlass_path = os.getenv('CUTLASS_PATH')
            assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined."
            cuda_install_path = os.getenv('CUDA_INSTALL_PATH')
            assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined."
            architectures = []
            for operation in operations:
                if hasattr(operation, "tile_description"):
                    cc = operation.tile_description.minimum_compute_capability
                    if cc not in architectures:
                        architectures.append(cc)
            include_paths = [
                cuda_install_path + '/include',
                cutlass_path + '/include',
                cutlass_path + '/tools/util/include',
            ]
            compile_options = CompilationOptions(architectures, include_paths)
        # save the cubin
        operation_key = []
        operation_list = []
        for operation in operations:
            # step 1: get kernel string as key
            key = operation.rt_module.emit() + operation.procedural_name()
            # step 2: check if the operation is in cache
            compiled_kernel = self.compiled_cache_device.at(key)

            if compiled_kernel is None:
                hit = self.load_operation(key)
                if hit:
                    compiled_kernel = self.compiled_cache_device.at(key)
                    assert compiled_kernel is not None
            if compiled_kernel is not None:
                operation.rt_module.kernel = compiled_kernel
                compiled_host_fns = self.compiled_cache_host.at(key)
                assert compiled_host_fns is not None
                for key in compiled_host_fns.keys():
                    setattr(operation.rt_module, key, compiled_host_fns[key])
                operation.rt_module.initialize()
            else:
                operation_list.append(operation.rt_module)
                operation_key.append(key)
        if len(operation_list) > 0:
            cubin_image, host_lib, host_file = self.emit_compile_(operation_list, compile_options)

            err, module = cuda.cuModuleLoadData(cubin_image)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            operation_name = []
            operation_attr = []
            for operation, key in zip(operation_list, operation_key):
                # get device kernels
                err, operation.kernel = cuda.cuModuleGetFunction(
                    module,
                    bytes(str.encode(operation.name()))
                )
                operation_name.append(operation.name())
                self.compiled_cache_device.insert(key, operation.kernel)
                # get host functions
                compiled_host_fns = {}
                op_attr = []

                # get param size
                func_name = operation.name() + '_get_param_size'
                func = getattr(host_lib, func_name)
                param_size = func()

                func_name = operation.name() + '_get_params'
                func = getattr(host_lib, func_name)
                func.argtype = operation.argtype
                func.restype = ctypes.POINTER(ctypes.c_char * param_size)
                setattr(operation, 'get_args', func)
                compiled_host_fns['get_args'] = func

                # set shared memory size
                func_name = operation.name() + '_shared_memory_size'
                func = getattr(host_lib, func_name)
                setattr(operation, 'shared_memory_capacity', func())
                compiled_host_fns['shared_memory_capacity'] = func()
                # set the maximum dynamic shared size
                operation.initialize()

                # get extra functions
                op_attr.append(param_size)

                if hasattr(operation, "extra_funcs"):
                    for suffix in operation.extra_funcs:
                        func_name = operation.name() + '_' + suffix
                        func = getattr(host_lib, func_name)
                        setattr(operation, suffix, func)
                        compiled_host_fns[suffix] = func
                        op_attr.append(suffix)

                operation_attr.append(op_attr)
                self.compiled_cache_host.insert(key, compiled_host_fns)

            for key, operation_name, operation_attr in zip(operation_key, operation_name, operation_attr):
                self.insert_operation(key, cubin_image, host_file.name, operation_name, operation_attr)


artifact_manager = ArtifactManager()
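
For context, a hedged sketch of how this module-level singleton is driven from application code, assuming this file is importable as pycutlass.compiler and that 'operation' is any pycutlass operation object exposing rt_module and procedural_name() (e.g. a GemmOperationUniversal):

from pycutlass.compiler import artifact_manager  # assumed import path

artifact_manager.add_module([operation])   # compiles once; later calls hit the cache
# operation.rt_module.kernel now holds the loaded CUfunction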
@ -30,7 +30,6 @@
#
#################################################################################################
from pycutlass import *
from pycutlass.library import SubstituteTemplate
import cutlass
from cuda import cuda
from cuda import nvrtc
@ -132,13 +131,15 @@ class ArtifactManager:
        except:
            pass

        self.nvcc()
        self.compiled_cache_device = cutlass.CompileCache()
        self.compiled_cache_host = cutlass.CompileCache()

    def nvrtc(self):
        self.backend = "nvrtc"
        self.default_compile_options = [
            '-std=c++11', '-default-device',
        ]
        self.compiled_cache_device = cutlass.CompileCache()
        self.compiled_cache_host = cutlass.CompileCache()

    def nvcc(self):
        self.backend = "nvcc"
        self.default_compile_options = [
@ -335,13 +336,14 @@ class ArtifactManager:
            architectures = []
            for operation in operations:
                if hasattr(operation, "tile_description"):
                    cc = operation.tile_description.minimum_compute_capability
                    cc = operation.arch
                    if cc not in architectures:
                        architectures.append(cc)
            include_paths = [
                cuda_install_path + '/include',
                cutlass_path + '/include',
                cutlass_path + '/tools/util/include',
                cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include'
            ]
            compile_options = CompilationOptions(
                self.default_compile_options, architectures, include_paths)

@ -48,6 +48,9 @@ class Conv2dArguments(ArgumentBase):
    :param operation: the Conv2d operation to take the argument
    :type operation: :class:`pycutlass.Conv2dOperation`

    :param problem_size: the Conv2d problem size
    :type problem_size: :class:`cutlass.conv.Conv2dProblemSize`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
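
A hedged construction sketch matching the documented parameters; the operation object and tensor names are placeholders, and numpy arrays are one of the accepted tensor types listed above:

# sketch: 'operation' is a pycutlass.Conv2dOperation; tensors have matching layouts
arguments = Conv2dArguments(
    operation=operation,
    problem_size=problem_size,                    # cutlass.conv.Conv2dProblemSize
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=operation.epilogue_type(1.0, 0.0),  # alpha=1, beta=0
    split_k_mode=cutlass.conv.SplitKMode.Serial)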
@ -78,6 +81,7 @@ class Conv2dArguments(ArgumentBase):
                 split_k_mode: 'cutlass.conv.SplitKMode'
                 = cutlass.conv.SplitKMode.Serial, **kwargs) -> None:

        self.operation = operation
        #: convolution kind
        self.conv_kind: cutlass.conv.Operator = operation.conv_kind
        self.layout_A: cutlass.layout = operation.A.layout
@ -93,15 +97,12 @@ class Conv2dArguments(ArgumentBase):

        super().__init__(A, B, C, D, **kwargs)
        # preprocessing output ops
        if "output_op" in kwargs.keys() and \

        if 'output_op' in kwargs.keys() and \
                split_k_mode != cutlass.conv.SplitKMode.Parallel:
            self.alpha = kwargs["output_op"].alpha
            self.beta = kwargs["output_op"].beta
            self.output_op = kwargs['output_op']
        else:
            self.alpha = 1.0
            self.beta = 0.0

        self.element_compute = operation.element_epilogue
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        if "split_k_slices" in kwargs.keys():
            self.split_k_mode = split_k_mode
@ -114,7 +115,12 @@ class Conv2dArguments(ArgumentBase):
        self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size
        self.problem_size.split_k_slices = self.split_k_slices

        self.operation = operation
        if hasattr(self, "tensor_c_numel"):
            c_coord = cutlass.conv.implicit_gemm_tensor_c_extent(
                self.conv_kind, problem_size)
            if (self.tensor_c_numel == c_coord.at(3) and
                    self.tensor_c_numel < c_coord.size()):
                self.bias = True

        #
        # initialize the argument
@ -159,6 +165,9 @@ class Conv2dArguments(ArgumentBase):
                self.conv_kind, problem_size)
        else:
            raise ValueError("unknown operand: " + operand)
        # Zero stride trick
        if operand == "c" and self.bias:
            tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0)

        layout = tensor_layout.packed(tensor_coord)

@ -174,24 +183,9 @@ class Conv2dArguments(ArgumentBase):
        ref_D = TensorRef_(self.get_tensor_ref(
            self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d"))

        if self.element_compute == cutlass.float16:
            alpha = cutlass.float16(self.alpha).storage
            beta = cutlass.float16(self.beta).storage
        elif self.element_compute == cutlass.int32:
            alpha = int(self.alpha)
            beta = int(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta

        argument_type, epilogue_type = get_conv2d_arguments(
            self.operation.element_epilogue)

        output_op = epilogue_type(alpha, beta, 0, 0)

        self.c_arguments = argument_type(
        self.c_arguments = self.operation.argument_type(
            Conv2DProblemSize(self.problem_size),
            ref_A, ref_B, ref_C, ref_D, output_op, self.split_k_mode
            ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode
        )

        self.semaphore = semaphore
@ -296,9 +290,8 @@ extern "C" {

    def __init__(self, operation: 'Conv2dOperation'):
        super().__init__(operation)

        self.argtype = [ctypes.POINTER(get_conv2d_arguments(
            operation.element_epilogue)[0]), ctypes.c_void_p]
        self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p]
        self.conv_kind = operation.conv_kind

        self.operation: Conv2dOperation = operation
@ -410,9 +403,7 @@ class Conv2dOperation:
                 iterator_algorithm: cutlass.conv.IteratorAlgorithm,
                 arch: int, tile_description: TileDescription,
                 A: TensorDescription, B: TensorDescription, C: TensorDescription,
                 element_epilogue: Union[cutlass.int8, cutlass.int32, cutlass.float16,
                                         cutlass.bfloat16, cutlass.float32, cutlass.float64],
                 stride_support, epilogue_functor=EpilogueFunctor.LinearCombination,
                 stride_support, epilogue_functor,
                 swizzling_functor=cutlass.IdentitySwizzle1):

        self.operation_kind: OperationKind = OperationKind.Conv2d
@ -422,13 +413,14 @@ class Conv2dOperation:
        self.A: TensorDescription = A
        self.B: TensorDescription = B
        self.C: TensorDescription = C
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.iterator_algorithm = iterator_algorithm
        self.stride_support = stride_support
        self.swizzling_functor = swizzling_functor()

        self.rt_module: Conv2dRT = Conv2dRT(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def run(self, arguments: Conv2dArguments) -> cuda.CUresult:
        """
@ -577,12 +569,7 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  ${epilogue_functor}<
    ${element_c},
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
  >,
  ${epilogue_functor},
  ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
  ${stages},
  ${math_operator},
@ -629,8 +616,7 @@ struct ${operation_name}${operation_suffix}:
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': operation.epilogue_functor.emit(),
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(operation.tile_description.stages),
            'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],

File diff suppressed because it is too large
@ -116,6 +116,12 @@ class GemmArguments(ArgumentBase):
        else:
            self.problem_size = cutlass.gemm.GemmCoord(
                problem_size.m(), problem_size.n(), problem_size.k())

        # if the number of elements in C equals problem_size.n,
        # C is treated as the bias
        if hasattr(self, "tensor_c_numel"):
            if (self.tensor_c_numel == self.problem_size.n() and
                    self.problem_size.m() != 1): self.bias = True

        # get the leading dimension
        self.lda = operation.A.layout.packed(self.problem_size.mk()).stride()
@ -123,27 +129,69 @@ class GemmArguments(ArgumentBase):
        self.ldc = operation.C.layout.packed(self.problem_size.mn()).stride()
        self.ldd = self.ldc

        # stride 0 trick
        if self.bias:
            self.ldc = 0

        if 'output_op' in kwargs.keys() and \
                gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel:
            self.alpha = kwargs['output_op'].alpha
            self.beta = kwargs['output_op'].beta
            self.output_op = kwargs['output_op']
        else:
            self.alpha = 1.0
            self.beta = 0.0
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        # get number of slices on k dimension
        self.gemm_mode = gemm_mode
        if 'split_k_slices' in kwargs.keys():
            self.split_k_slices = kwargs['split_k_slices']
        else:
            self.split_k_slices = 1
        if gemm_mode in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]:
            if 'split_k_slices' in kwargs.keys():
                self.batch_count = kwargs['split_k_slices']
            else:
                self.batch_count = 1
            self.split_k_slices = self.batch_count

        self.batch_count = self.split_k_slices
        if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]:
            if 'batch' in kwargs.keys():
                self.batch_count = kwargs['batch']
            else:
                self.batch_count = 1

        self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
        self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
        self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
        self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
        if self.bias:
            self.batched_stride_C = self.problem_size.n()

        # support GEMM Mode Array
        if gemm_mode == cutlass.gemm.Mode.Array:
            self.ptr_A_array = []
            self.ptr_B_array = []
            self.ptr_C_array = []
            self.ptr_D_array = []

            ptr_A_addr = int(self.ptr_A)
            ptr_B_addr = int(self.ptr_B)
            ptr_C_addr = int(self.ptr_C)
            ptr_D_addr = int(self.ptr_D)

            stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
            stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
            stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
            stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
            for _ in range(self.batch_count):
                self.ptr_A_array.append(ptr_A_addr)
                self.ptr_B_array.append(ptr_B_addr)
                self.ptr_C_array.append(ptr_C_addr)
                self.ptr_D_array.append(ptr_D_addr)

                ptr_A_addr += stride_A
                ptr_B_addr += stride_B
                ptr_C_addr += stride_C
                ptr_D_addr += stride_D

            self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
            self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
            self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
            self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)

        if isinstance(self.operation, GemmOperationUniversal):
            self.initialize()
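
To make the Array-mode pointer arithmetic concrete, a hedged worked example: strides are counted in elements, so consecutive batch pointers differ by batched_stride * DataTypeSize // 8 bytes.

# illustrative only: fp16 A operand (DataTypeSize = 16), m=128, k=64, batch_count=3
batched_stride_A = 128 * 64                  # 8192 elements
stride_bytes = batched_stride_A * 16 // 8    # 16384 bytes
base = 0x700000000                           # hypothetical device address
ptrs = [base + i * stride_bytes for i in range(3)]
# [0x700000000, 0x700004000, 0x700008000]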
@ -195,28 +243,28 @@ class GemmArguments(ArgumentBase):
                self.grid_tiled_shape.z
            )
        )

        argument_type, epilogue_type = get_gemm_arguments(
            self.operation.element_epilogue)

        if self.operation.element_epilogue == cutlass.float16:
            self.alpha = cutlass.float16(self.alpha).storage
            self.beta = cutlass.float16(self.beta).storage
        elif self.operation.element_epilogue == cutlass.int32:
            self.alpha = int(self.alpha)
            self.beta = int(self.beta)

        output_op = epilogue_type(self.alpha, self.beta, 0, 0)

        arguments = argument_type(
            self.gemm_mode, problem_size_, self.batch_count, output_op,
            int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
            self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
            self.batched_stride_D,
            self.lda, self.ldb, self.ldc, self.ldd,
            self.lda, self.ldb, self.ldc, self.ldd,
            0, 0, 0
        )
        if self.gemm_mode == cutlass.gemm.Mode.Array:
            arguments = self.operation.argument_type(
                self.gemm_mode, problem_size_, self.batch_count, self.output_op,
                int(self.ptr_A_array_buffer.ptr),
                int(self.ptr_B_array_buffer.ptr),
                int(self.ptr_C_array_buffer.ptr),
                int(self.ptr_D_array_buffer.ptr),
                0, 0, 0, 0,
                self.lda, self.ldb, self.ldc, self.ldd,
                self.lda, self.ldb, self.ldc, self.ldd,
                0, 0, 0
            )
        else:
            arguments = self.operation.argument_type(
                self.gemm_mode, problem_size_, self.batch_count, self.output_op,
                int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
                self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
                self.batched_stride_D,
                self.lda, self.ldb, self.ldc, self.ldd,
                self.lda, self.ldb, self.ldc, self.ldd,
                0, 0, 0
            )

        self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size

@ -381,13 +429,12 @@ class GemmGroupedArguments:
        else:
            self.alpha = 1.0
            self.beta = 0.0

        if 'output_op' in kwargs.keys():
            self.output_op = kwargs['output_op']
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        if self.operation.element_epilogue == cutlass.float16:
            self.alpha = cutlass.float16(self.alpha).storage
            self.beta = cutlass.float16(self.beta).storage
        elif self.operation.element_epilogue == cutlass.int32:
            self.alpha = int(self.alpha)
            self.beta = int(self.beta)

        # get host problem size
        self.host_problem_size_ptr = np.array(
@ -398,12 +445,7 @@ class GemmGroupedArguments:
        self.initialize()

    def get_arguments(self):

        argument_type, epilogue_type = get_gemm_grouped_arguments(
            self.operation.element_epilogue)
        self.output_op = epilogue_type(self.alpha, self.beta, 0, 0)

        return argument_type(
        return self.operation.argument_type(
            self.problem_size_buffer.ptr, self.problem_count, self.total_tiles,
            self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr,
            self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr,
@ -492,16 +534,6 @@ ${operation_name}(${operation_name}${operation_suffix}::Params params) {
        #: number of threads per threadblock
        self.threads: int = operation.tile_description.num_threads

        if (operation.epilogue_functor in
                [
                    EpilogueFunctor.LinearCombination,
                    EpilogueFunctor.FastLinearCombinationClamp,
                    EpilogueFunctor.LinearCombinationClamp
                ]):
            self.output_op = LinearCombinationFunctor()
        else:
            raise ValueError("unknown epilogue functor")

    #
    def emit(self):
        return self.emitter.emit(self.operation)
@ -568,9 +600,11 @@ extern "C" {
    def __init__(self, operation: 'GemmOperation'):
        super(GemmRTUniversal, self).__init__(operation)
        self.emitter = EmitGemmUniversalInstance(
            '_type', operation.direct_store)
            '_type', operation.direct_store, operation.visitor)

        self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
        self.argtype = [
            ctypes.POINTER(get_gemm_arguments(operation.element_epilogue)[0]),
            ctypes.POINTER(self.argument_type),
            ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
        ]

@ -673,8 +707,8 @@ class GemmRTGrouped(GemmRTbase):
        self.extra_funcs = ['precompute']

        self.emitter = EmitGemmGroupedInstance('_type')
        self.argtype = [ctypes.POINTER(get_gemm_grouped_arguments(
            operation.element_epilogue)[0]), ctypes.c_int, ctypes.c_void_p]
        self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]

    def host_precompute(self, arguments, workspace_bytes):
        self.precompute.argtype = [
@ -717,7 +751,7 @@ class GemmOperationBase:
    def __init__(
            self, gemm_kind, arch, tile_description: TileDescription,
            A: TensorDescription, B: TensorDescription, C: TensorDescription,
            element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination,
            epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):

        #: operation kind
@ -749,7 +783,7 @@ class GemmOperationBase:
        #: Operand C
        self.C: TensorDescription = copy.deepcopy(C)
        self.switched = False
        self.element_epilogue = element_epilogue

        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor()

@ -757,6 +791,11 @@ class GemmOperationBase:
            self.direct_store = kwargs["direct_store"]
        else:
            self.direct_store = False

        if "visitor" in kwargs:
            self.visitor = kwargs["visitor"]
        else:
            self.visitor = False

    def run(self, arguments: GemmArguments) -> cuda.CUresult:
        """
@ -895,22 +934,26 @@ class GemmOperationBase:


class GemmOperationUniversal(GemmOperationBase):
    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue,
                 epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
                 epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
        super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description,
                                                     A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs)
                                                     A, B, C, epilogue_functor, swizzling_functor, **kwargs)
        self.rt_module = GemmRTUniversal(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type


class GemmOperationGrouped(GemmOperationBase):
    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue,
                 epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
                 epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
        super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
                                                   A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs)
                                                   A, B, C, epilogue_functor, swizzling_functor, **kwargs)
        assert "precompute_mode" in kwargs.keys(
        ), "missing keyword argument 'precompute_mode'."
        self.precompute_mode = kwargs["precompute_mode"]
        self.rt_module = GemmRTGrouped(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type
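
Taken together, these hunks change the construction contract: callers now pass an epilogue functor object, which exposes .emit() for C++ emission and .epilogue_type for the ctypes argument struct, instead of an element_epilogue plus an EpilogueFunctor enum. A hedged sketch of the new call shape; the LinearCombination constructor arguments here are assumed, not taken from this diff:

# sketch: LinearCombination is pycutlass's Python-side epilogue functor
epilogue_functor = LinearCombination(
    C.element, C.alignment,          # output type and vector length (assumed)
    math_inst.element_accumulator,   # accumulator type (assumed)
    cutlass.float32)                 # element_epilogue now lives here
operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C, epilogue_functor=epilogue_functor)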

###################################################################################################
#
@ -918,228 +961,14 @@ class GemmOperationGrouped(GemmOperationBase):
#
###################################################################################################

#


class EmitGemmInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self, operation_suffix=''):
        self.operation_suffix = operation_suffix
        self.includes = []
        self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  ${element_a}, ${layout_a},
  ${element_b}, ${layout_b},
  ${element_c}, ${layout_c},
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  ${epilogue_functor}<
    ${element_c},
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
  >,
  ${swizzling_functor},
  ${stages},
  ${align_a},
  ${align_b},
  false,
  ${math_operation}
  ${residual}
>;
"""
        self.gemm_complex_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  ${element_a}, ${layout_a},
  ${element_b}, ${layout_b},
  ${element_c}, ${layout_c},
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  ${epilogue_functor}<
    ${element_c},
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
  >,
  ${swizzling_functor},
  ${stages},
  ${transform_a},
  ${transform_b},
  ${math_operation}
  ${residual}
>;
"""

    #
    def instance_template(self):
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):

        warp_shape = [operation.tile_description.threadblock_shape[idx] //
                      operation.tile_description.warp_count[idx] for idx in range(3)]

        epilogue_vector_length = int(min(
            operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])

        residual = ''

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(operation.tile_description.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'residual': residual
        }

        template = self.gemm_complex_template if operation.is_complex() else self.gemm_template

        return SubstituteTemplate(template, values)

###################################################################################################


class EmitSparseGemmInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self, operation_suffix=''):
        self.operation_suffix = operation_suffix
        self.includes = []
        self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  ${element_a}, ${layout_a},
  ${element_b}, ${layout_b},
  ${element_c}, ${layout_c},
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  ${epilogue_functor}<
    ${element_c},
    ${epilogue_vector_length},
    ${element_accumulator},
    ${element_epilogue}
  >,
  ${swizzling_functor},
  ${stages},
  ${align_a},
  ${align_b},
  false,
  ${math_operation}
  ${residual}
>;
"""

    #
    def instance_template(self):
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):

        warp_shape = [operation.tile_description.threadblock_shape[idx] //
                      operation.tile_description.warp_count[idx] for idx in range(3)]

        epilogue_vector_length = int(min(
            operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])

        residual = ''

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(operation.tile_description.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'residual': residual
        }

        template = self.gemm_template

        return SubstituteTemplate(template, values)

###################################################################################################


#
class EmitGemmUniversalInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self, operation_suffix='', direct_store=False):
    def __init__(self, operation_suffix='', direct_store=False, visitor=False):
        self.operation_suffix = operation_suffix
        self.direct_store = direct_store
        self.visitor = visitor
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
@ -1150,46 +979,15 @@ class EmitGemmUniversalInstance:
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
        ]
        if self.visitor:
            self.includes += [
                "gemm/gemm_universal_with_visitor.h",
                "epilogue/epilogue_visitor_with_layernorm.h",
                "epilogue/epilogue_visitor_generic.h"
            ]
        if self.direct_store:
            self.includes.append(
                "cutlass/epilogue/threadblock/default_epilogue_direct_store.h")
        self.builtin_epilogue_functor_template = """
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >
"""
        self.builtin_epilogue_functor_template_clamp = """
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length}
    >
"""
        self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""
        self.gemm_template_interleaved = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
@ -1241,6 +1039,42 @@ using ${operation_name}_base =
    ${operation_name}_default::ThreadblockSwizzle
>;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""
        self.gemm_template_visitor = """
// Gemm operator ${operation_name}
using ${operation_name}_default =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${elementwise_epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;

${epilogue_visitor}

using ${operation_name}_Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
    ${operation_name}_EpilogueVisitor,
    typename ${operation_name}_default::Epilogue>::Epilogue;

using ${operation_name}_base =
  cutlass::gemm::kernel::GemmUniversalwithEpilogueVisitor<
    ${operation_name}_default::Mma,
    ${operation_name}_Epilogue,
    ${operation_name}_default::ThreadblockSwizzle
>;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
@ -1284,32 +1118,12 @@ ${compile_guard_end}
            (operation.A.layout, operation.B.layout, operation.C.layout)
        if self.direct_store:
            gemm_template = self.gemm_template_direct_store
        elif self.visitor:
            gemm_template = self.gemm_template_visitor
        else:
            gemm_template = self.gemm_template_interleaved
        #

        # Support built-in epilogue functors or user-defined functions
        if isinstance(operation.epilogue_functor, enum.Enum):

            epilogue_vector_length = \
                min(operation.C.alignment *
                    DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]

            values = {
                'epilogue_vector_length': str(epilogue_vector_length),
                'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
                'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            }
            if operation.epilogue_functor == EpilogueFunctor.FastLinearCombinationClamp:
                epilogue_functor = SubstituteTemplate(
                    self.builtin_epilogue_functor_template_clamp, values)
            else:
                epilogue_functor = SubstituteTemplate(
                    self.builtin_epilogue_functor_template, values)
        else:
            epilogue_functor = self.epilogue_functor.emit_declaration()
        #

        values = {
            'operation_name': operation.procedural_name(),
            'operation_suffix': self.operation_suffix,
@ -1331,7 +1145,6 @@ ${compile_guard_end}
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_functor': epilogue_functor,
            'swizzling_functor': operation.swizzling_functor.tag(),
            'stages': str(operation.tile_description.stages),
            'align_a': str(operation.A.alignment),
@ -1341,6 +1154,12 @@ ${compile_guard_end}
            'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
        }

        if self.visitor:
            values['epilogue_visitor'] = operation.epilogue_functor.emit(operation)
            values['elementwise_epilogue_functor'] = operation.epilogue_functor.elementwise_functor.emit()
        else:
            values['epilogue_functor'] = operation.epilogue_functor.emit()

        return SubstituteTemplate(gemm_template, values)

###################################################################################################
@ -1348,185 +1167,6 @@ ${compile_guard_end}
#


class EmitGemmPlanarComplexInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self, operation_suffix=''):
        self.operation_suffix = operation_suffix
        self.includes = []
        self.template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  ${element_c}, cutlass::layout::RowMajor,
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  cutlass::epilogue::thread::LinearCombinationPlanarComplex<
    ${element_c},
    ${alignment_c},
    ${element_accumulator},
    ${element_epilogue}
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  ${stages},
  ${math_operator}
>::GemmKernel;

struct ${operation_name} :
  public Operation_${operation_name} { };
"""

    #
    def instance_template(self):
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):

        warp_shape = [operation.tile_description.threadblock_shape[idx] //
                      operation.tile_description.warp_count[idx] for idx in range(3)]

        # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
        transposed_layout_A = TransposedLayout[operation.A.layout]
        transposed_layout_B = TransposedLayout[operation.B.layout]

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.B.element],
            'layout_a': LayoutTag[transposed_layout_B],
            'transform_a': ComplexTransformTag[operation.B.complex_transform],
            'alignment_a': str(operation.B.alignment),
            'element_b': DataTypeTag[operation.A.element],
            'layout_b': LayoutTag[transposed_layout_A],
            'transform_b': ComplexTransformTag[operation.A.complex_transform],
            'alignment_b': str(operation.A.alignment),
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'alignment_c': str(operation.C.alignment),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'stages': str(operation.tile_description.stages),
            'math_operator': 'cutlass::arch::OpMultiplyAdd'
        }

        return SubstituteTemplate(self.template, values)

###################################################################################################

#


class EmitGemmPlanarComplexArrayInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self, operation_suffix=''):
        self.operation_suffix = operation_suffix
        self.includes = []
        self.template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  ${element_c}, cutlass::layout::RowMajor,
  ${element_accumulator},
  ${opcode_class},
  ${arch},
  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  cutlass::epilogue::thread::LinearCombinationPlanarComplex<
    ${element_c},
    ${alignment_c},
    ${element_accumulator},
    ${element_epilogue}
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  ${stages},
  ${math_operator}
>::GemmArrayKernel;

struct ${operation_name} : public Operation_${operation_name} { };
"""

    #
    def instance_template(self):
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

    #
    def emit(self, operation):

        warp_shape = [operation.tile_description.threadblock_shape[idx] //
                      operation.tile_description.warp_count[idx] for idx in range(3)]

        # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
        transposed_layout_A = TransposedLayout[operation.A.layout]
        transposed_layout_B = TransposedLayout[operation.B.layout]

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.B.element],
            'layout_a': LayoutTag[transposed_layout_B],
            'transform_a': ComplexTransformTag[operation.B.complex_transform],
            'alignment_a': str(operation.B.alignment),
            'element_b': DataTypeTag[operation.A.element],
            'layout_b': LayoutTag[transposed_layout_A],
            'transform_b': ComplexTransformTag[operation.A.complex_transform],
            'alignment_b': str(operation.A.alignment),
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'alignment_c': str(operation.C.alignment),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'stages': str(operation.tile_description.stages),
            'math_operator': 'cutlass::arch::OpMultiplyAdd'
        }

        return SubstituteTemplate(self.template, values)

###################################################################################################

#


class EmitGemmGroupedInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

@ -1541,14 +1181,6 @@ class EmitGemmGroupedInstance:
            "cutlass/gemm/kernel/gemm_grouped.h",
            "cutlass/gemm/kernel/default_gemm_grouped.h"
        ]
        self.builtin_epilogue_functor_template = """
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >
"""
        self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
@ -1598,23 +1230,8 @@ ${compile_guard_end}
        #

        # Support built-in epilogue functors or user-defined functions
        if isinstance(operation.epilogue_functor, enum.Enum):

            epilogue_vector_length = \
                min(operation.C.alignment *
                    DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]

            values = {
                'epilogue_vector_length': str(epilogue_vector_length),
                'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
                'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            }
            epilogue_functor = SubstituteTemplate(
                self.builtin_epilogue_functor_template, values)
        else:
            epilogue_functor = self.epilogue_functor.emit_declaration()
        #

        epilogue_functor = operation.epilogue_functor.emit()

        values = {
            'operation_name': operation.procedural_name(),
            'operation_suffix': self.operation_suffix,
@ -478,27 +478,6 @@ SharedMemPerCC = {

###################################################################################################

#

def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
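As an aside, the helper removed above loops until a fixed point, so values that themselves contain placeholders are resolved on a later pass. A small usage sketch (illustrative, assuming the SubstituteTemplate above is in scope):

# Nested placeholders resolve because the while-loop re-runs re.sub
# until no substitution changes the text.
template = "using Epilogue = ${functor};"
values = {
    "functor": "LinearCombination<${element}>",  # contains another placeholder
    "element": "float",
}
print(SubstituteTemplate(template, values))
# -> using Epilogue = LinearCombination<float>;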

###################################################################################################

#

class GemmKind(enum.Enum):
    Gemm = enum_auto()
    Sparse = enum_auto()
@ -557,22 +536,6 @@ SymmKindNames = {
#

class EpilogueFunctor(enum.Enum):
    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()
    FastLinearCombinationClamp = enum_auto()

#
EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
    EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
    EpilogueFunctor.FastLinearCombinationClamp: 'cutlass::epilogue::thread::FastLinearCombinationClamp'
}

#

class SwizzlingFunctor(enum.Enum):
    Identity1 = enum_auto()
    Identity2 = enum_auto()
@ -700,7 +663,7 @@ class MathInstruction:

class TileDescription:

    def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):
    def __init__(self, threadblock_shape, stages, warp_count, math_instruction):
        self.threadblock_shape = threadblock_shape

        #: number of pipeline stages
@ -710,11 +673,6 @@ class TileDescription:
        self.warp_count: list[int] = warp_count
        self.math_instruction = math_instruction

        #: minimum compute capability
        self.minimum_compute_capability: int = min_compute
        #: maximum compute capability
        self.maximum_compute_capability: int = max_compute

        #: number threads per threadblock
        self.num_threads: int = 32
        for cnt in self.warp_count:
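The loop above is cut off by the hunk boundary, but its evident effect is to scale the 32 threads of one warp by every entry of warp_count; a minimal sketch of the presumed computation with illustrative numbers:

# Presumed completion of the truncated loop: one warp has 32 threads,
# and warp_count gives the warps per threadblock along M, N, K.
warp_count = [2, 2, 1]
num_threads = 32
for cnt in warp_count:
    num_threads *= cnt
print(num_threads)  # 128 threads per threadblock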
619
tools/library/scripts/pycutlass/src/pycutlass/parser.py
Normal file
@ -0,0 +1,619 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################

from typing import Generic, TypeVar
from treelib import Tree
import numpy as np

from pycutlass import *
import pycutlass

import ast
import textwrap
import inspect

################################################################################
# Type annotation for input arguments
################################################################################

Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")

class NDArray(np.ndarray, Generic[Ttype, Dtype]):
    pass

################################################################################
# Operations
################################################################################

operators = {
    ast.Add: "Add",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Mult: "Mult"
}
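The operators table is keyed by ast node classes, so a parsed expression maps to pycutlass operator names directly; a standalone sketch (standard library only):

import ast

operators = {ast.Add: "Add", ast.Div: "Div", ast.Eq: "Equal", ast.Mult: "Mult"}

expr = ast.parse("alpha * accum + beta * c", mode="eval").body
print(operators[type(expr.op)])       # Add   (the top-level BinOp)
print(operators[type(expr.left.op)])  # Mult  (the alpha * accum subtree)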
################################################################################
# AST Node abstractions
################################################################################
class UnaryNode:
    cnt = 0
    # Concept: this is created by the BinOp Node in python ast
    def __init__(self,
                 element_accumulator, element_compute, elements_per_access,
                 node, args) -> None:
        if isinstance(node, BinOpNode):
            self.op = node.op
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                self.op = node.func.id
            elif isinstance(node.func, ast.Attribute):
                self.op = node.func.value.id
            else:
                raise TypeError
        else:
            raise TypeError
        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
        self.id = self.op + str(UnaryNode.cnt)
        self.args = args
        UnaryNode.cnt += 1

        self.type = "tensor"

        self.epilogue_op = getattr(pycutlass, self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = UnaryOp(
            self.element_accumulator, self.element_compute,
            self.elements_per_access, *visitors, self.epilogue_op)

    def get_argument(self, visitor_args, kwargs):
        epilogue_ops = []
        for arg in self.args:
            try:
                epilogue_ops.append(kwargs[arg])
            except:
                epilogue_ops.append(arg)  # direct arguments like constant
        self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(*epilogue_ops), *visitor_args)


class BinOpNode:
    cnt = 0
    # Concept: this is created by the BinOp Node in python ast
    def __init__(self,
                 element_accumulator, element_compute, elements_per_access,
                 node) -> None:
        self.op = operators[type(node.op)]
        self.tag = "Binary" + self.op + str(BinOpNode.cnt)
        self.id = self.op + str(BinOpNode.cnt)
        self.args = None
        BinOpNode.cnt += 1

        self.type = "tensor"

        self.epilogue_op = getattr(pycutlass, "Vector" + self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = BinaryOp(
            self.element_accumulator, self.element_compute,
            self.elements_per_access, *visitors, self.epilogue_op)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(self.args), *visitor_args)


class NameNode:
    # Concept: this is created by the Name Node in python ast
    def __init__(self, node) -> None:
        try:
            self.id = node.id
        except:
            self.id = node.targets[0].id
        self.tag = self.id

class ScalarInputNode(NameNode):
    # Concept: scalar
    def __init__(self, node) -> None:
        super().__init__(node)
        self.tag = "Scalar:" + self.tag
        self.type = "scalar"

class AccumulatorNode(NameNode):
    # Concept: VisitorOpAccumulator
    def __init__(self,
                 element_accumulator, elements_per_access, node) -> None:
        super().__init__(node)
        self.tag = "Accum:" + self.tag
        self.type = "tensor"

        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = AccumulatorOp(
            self.element_accumulator, self.elements_per_access)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type()

class TensorInputNode(NameNode):
    # Concept: VisitorOpTensorInput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = "TensorInput:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, *args):
        self.epilogue_node = TensorInputOp(self.element_accumulator)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"], kwargs["problem_size"][1],
            kwargs["problem_size"][0] * kwargs["problem_size"][1])

class RowBroadcastNode(NameNode):
    # Concept: VisitorOpRowBroadcast
    def __init__(self, element_accumulator, element_fragment, node) -> None:
        super().__init__(node)
        #
        self.tag = "RowBroadcast:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        self.epilogue_node = RowBroadcastOp(
            self.element_accumulator, self.element_fragment)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1])

class ColumnBroadcastNode(NameNode):
    # Concept: VisitorOpColumnBroadcast
    def __init__(self, element_accumulator, element_fragment, node) -> None:
        super().__init__(node)
        self.tag = "ColumnBroadcast:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        self.epilogue_node = ColumnBroadcastOp(
            self.element_accumulator, self.element_fragment)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][0])

class TensorOutputNode(NameNode):
    # Concept: VisitorOpTensorOutput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = "TensorOutput:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, visitors):
        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1], *visitor_args, kwargs["problem_size"][0] * kwargs["problem_size"][1])

class RowReductionNode:
    # Concept: RowReductionOp
    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, id, factor) -> None:
        #
        self.id = id
        self.tag = "RowReduction:" + self.id
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.factor = factor

    def get_epilogue_node(self, visitors):
        self.epilogue_node = RowReductionOp(
            self.element_accumulator, self.element_reduction,
            self.element_reduction_accumulator, *visitors)

    def get_batch_stride(self, problem_size):
        return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], *visitor_args, self.get_batch_stride(kwargs["problem_size"]))

class ColumnReductionNode:
    # Concept: ColumnReductionOp
    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, id, factor) -> None:
        #
        self.id = id
        self.tag = "ColumnReduction:" + self.id
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.factor = factor

    def get_epilogue_node(self, visitors):
        self.epilogue_node = ColumnReductionOp(
            self.element_accumulator, self.element_reduction,
            self.element_reduction_accumulator, *visitors)

    def get_batch_stride(self, problem_size):
        return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(kwargs[self.id + '_ptr'], *visitor_args, self.get_batch_stride(kwargs["problem_size"]))

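Both get_batch_stride methods round the reduced extent up to a whole number of threadblock tiles before multiplying by the other extent; a quick worked example with illustrative values:

# RowReductionNode.get_batch_stride with illustrative numbers:
problem_size = [1024, 520]   # (M, N)
factor = 128                 # threadblock tile extent along N
# ceil(520 / 128) = 5, so the per-batch stride is 1024 * 5 = 5120
batch_stride = problem_size[0] * ((problem_size[1] + factor - 1) // factor)
print(batch_stride)          # 5120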
################################################################################
# Epilogue parser function
################################################################################
class EpilogueAST(ast.NodeVisitor):
    def __init__(self, epilogue,
                 tile_description,
                 element_accumulator, elements_per_access,
                 element_compute, element_output) -> None:
        #
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        self.epilogue = epilogue

        self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
        self.ast_tree = ast.parse(self.source)
        self.epilogue_tree = Tree()

        # print(ast.dump(self.ast_tree, indent=4))  # for debugging

        # input arguments
        self.input_args = {}
        # return nodes
        self.returns = []
        # reduction source nodes
        self.reduction_source = {}

        # stack used to keep the parent node id
        self.stack = []

        # visit the AST
        self.visit(self.ast_tree)

    # visit the name node
    def visit_Name(self, node):
        # append the return ids into self.returns
        if self.stack[-1] == "return":
            self.returns.append(node.id)
        else:
            # accum is produced from accumulator node
            if node.id == "accum":
                name_node = AccumulatorNode(
                    self.element_accumulator, self.elements_per_access, node)
            else:
                # for input nodes
                if node.id in self.input_args.keys():
                    type = self.input_args[node.id][0]
                    if type == "tensor":
                        name_node = TensorInputNode(self.element_accumulator, node)
                    elif type == "row":
                        name_node = RowBroadcastNode(self.element_accumulator, self.element_compute, node)
                    elif type == "column":
                        name_node = ColumnBroadcastNode(self.element_accumulator, self.element_compute, node)
                    elif type == "scalar":
                        name_node = ScalarInputNode(node)
                    else:
                        raise ValueError(type)
                # for output nodes
                else:
                    name_node = TensorOutputNode(self.element_accumulator, node)
            self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node, parent=self.stack[-1])

    def visit_Assign(self, node):
        pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
        if pre_assign_node is None:
            # The assign is to a root node
            # skip the reduction nodes
            if isinstance(node.value, ast.Call):
                if isinstance(node.value.func, ast.Name):
                    func_type = node.value.func.id
                elif isinstance(node.value.func, ast.Attribute):
                    func_type = node.value.func.value.id
                else:
                    raise TypeError
                if func_type == 'reduction_op':
                    self.reduction_source[node.value.args[0].id] = [node.value.args[1].value, node.value.args[2].value, node.targets[0].id]
                    return
            name_node = TensorOutputNode(self.element_accumulator, node)
            self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node)
            self.stack.append(name_node.id)
        else:
            if node.targets[0].id in self.returns or node.targets[0].id in self.reduction_source.keys():
                self.stack.append(node.targets[0].id)
            else:
                self.stack.append(pre_assign_node.predecessor(self.epilogue_tree.identifier))
                self.epilogue_tree.remove_node(node.targets[0].id)

        # get child tag
        self.visit(node.value)
        self.stack.pop()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            func_type = node.func.id
        elif isinstance(node.func, ast.Attribute):
            func_type = node.func.value.id
        else:
            raise TypeError
        if func_type == "reduction_op":
            self.visit(node.args[0])
        else:
            arg_list = []
            for idx, arg in enumerate(node.args):
                if idx == 0: continue
                if isinstance(arg, ast.Constant):
                    arg_list.append(arg.value)
                elif isinstance(arg, ast.Name):
                    arg_list.append(arg.id)
                else:
                    raise TypeError

            unary_node = UnaryNode(self.element_accumulator, self.element_compute, self.elements_per_access, node, arg_list)
            self.epilogue_tree.create_node(unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node)
            self.stack.append(unary_node.id)
            self.visit(node.args[0])
            self.stack.pop()

    def visit_BinOp(self, node):
        binop = BinOpNode(self.element_accumulator, self.element_compute,
                          self.elements_per_access, node)
        self.epilogue_tree.create_node(binop.tag, binop.id, data=binop, parent=self.stack[-1])
        self.stack.append(binop.id)
        self.visit(node.left)
        self.visit(node.right)
        self.stack.pop()

    def visit_Return(self, node):
        self.stack.append("return")
        self.visit(node.value)
        self.stack.pop()

    # A function definition
    def visit_FunctionDef(self, node: ast.FunctionDef):
        # visit args
        for arg in node.args.args:
            if arg.arg == "self": continue
            if isinstance(arg.annotation, ast.Constant):
                self.input_args[arg.arg] = [arg.annotation.value, ]
        # visit the assign in the reverse order
        for idx in range(len(node.body)):
            self.visit(node.body[-1-idx])
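For orientation, here is the kind of user-written epilogue this visitor is designed to parse (hypothetical names, not from this commit): string annotations such as 'tensor', 'row', or 'scalar' populate input_args, reduction_op calls are recorded in reduction_source, and unannotated assignments become tensor outputs.

# Hypothetical epilogue __call__ in the style the parser expects;
# reduction_op is the marker call visit_Assign/visit_Call look for,
# so this snippet is not runnable on its own.
def __call__(self, accum, c: 'tensor', bias: 'row', alpha: 'scalar'):
    d = alpha * accum + c + bias
    d_red = reduction_op(d, 'row', 'Add', 128)
    return d, d_red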

    #
    # Tree optimization pass
    #

    # pass 1: lower Binary to Unary
    def pass_binary_2_unary(self, tree, nid):
        node = tree.get_node(nid)
        if isinstance(node.data, BinOpNode):
            lhs_node = tree.get_node(node.successors(tree.identifier)[0])
            left_type = lhs_node.data.type
            rhs_node = tree.get_node(node.successors(tree.identifier)[1])
            right_type = rhs_node.data.type

            if left_type == "scalar" and right_type == "tensor":
                node.data = UnaryNode(
                    self.element_accumulator, self.element_compute,
                    self.elements_per_access,
                    node.data, [lhs_node.data.id,])
                node.tag = node.data.tag
                tree.remove_node(lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)

            elif left_type == "tensor" and right_type == "scalar":
                node.data = UnaryNode(
                    self.element_accumulator, self.element_compute,
                    self.elements_per_access,
                    node.data, [rhs_node.id,])
                node.tag = node.data.tag
                tree.remove_node(rhs_node.data.id)
                self.pass_binary_2_unary(tree, lhs_node.data.id)

            else:
                self.pass_binary_2_unary(tree, lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)
        else:
            for child in node.successors(tree.identifier):
                self.pass_binary_2_unary(tree, child)
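In effect, pass 1 folds a scalar operand of a binary node into the surviving tensor side as a runtime argument; schematically (tags illustrative):

# Before pass 1:                 After pass 1:
#   BinaryMult0                    UnaryMult1(args=['alpha'])
#   +-- Scalar:alpha       ->      +-- Accum:accum
#   +-- Accum:accum
#
# The scalar leaf leaves the tree and re-enters later through
# get_argument, which looks the name up in kwargs at launch time.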
    # pass 2: inject reduction nodes
    def pass_inject_reduction(self, tree, nid):
        node = tree.get_node(nid)
        if isinstance(node.data, TensorOutputNode):
            if node.data.id in self.reduction_source.keys():
                direction = self.reduction_source[node.data.id][0]
                target = self.reduction_source[node.data.id][-1]
                if direction == 'row':
                    reduction_node = RowReductionNode(
                        self.element_accumulator, self.element_output,
                        self.element_accumulator, target, self.tile_description.threadblock_shape[1])
                elif direction == "column":
                    reduction_node = ColumnReductionNode(
                        self.element_accumulator, self.element_output,
                        self.element_accumulator, target, self.tile_description.threadblock_shape[0])
                else:
                    raise ValueError(direction)
                child_nid = node.successors(tree.identifier)[0]
                # if this output node is injected only for reduction
                if node.data.id not in self.returns:
                    # get reduction config from disc
                    node.data = reduction_node
                    node.tag = reduction_node.tag
                    self.pass_inject_reduction(tree, child_nid)
                # if this output node is also a tensor output, inject reduction as its children
                else:
                    # get child node
                    tree.create_node(reduction_node.tag, reduction_node.id, data=reduction_node, parent=node.data.id)
                    tree.move_node(child_nid, reduction_node.id)
                    child = tree.get_node(child_nid)
                    for grand_child in child.successors(tree.identifier):
                        self.pass_inject_reduction(tree, grand_child)
            else:
                for child in node.successors(tree.identifier):
                    self.pass_inject_reduction(tree, child)
        else:
            for child in node.successors(tree.identifier):
                self.pass_inject_reduction(tree, child)

    def pass_inject_epilogue_op(self, tree, nid):
        node = tree.get_node(nid)
        visitors = []
        for child in node.successors(tree.identifier):
            visitors.append(self.pass_inject_epilogue_op(tree, child))

        node.data.get_epilogue_node(visitors)
        return node.data.epilogue_node

    def get_arguments(self, tree, nid, kwargs):
        node = tree.get_node(nid)
        visitor_args = []
        for child in node.successors(tree.identifier):
            visitor_args.append(self.get_arguments(tree, child, kwargs))

        node.data.get_argument(visitor_args, kwargs)
        return node.data.argument

class EpilogueVisitTree:
    KernelTemplate = """
${visitor}

using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""
    def __init__(self, elementwise_functor, tile_description,
                 element_accumulator, elements_per_access,
                 element_compute, element_output) -> None:
        #
        # data types
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        # TODO: deprecate this
        self.elementwise_functor = elementwise_functor
        pass

    def initialize(self):
        function = EpilogueAST(self, self.tile_description,
                               self.element_accumulator, self.elements_per_access,
                               self.element_compute, self.element_output)
        #
        tree = function.epilogue_tree
        self.tree = tree
        # self.tree.show()  # for debug
        function.pass_binary_2_unary(self.tree, self.tree.root)
        # self.tree.show()  # for debug
        function.pass_inject_reduction(self.tree, self.tree.root)
        # self.tree.show()  # for debug
        function.pass_inject_epilogue_op(self.tree, self.tree.root)

        visitor = self.tree.get_node(self.tree.root).data.epilogue_node
        self.visitor = visitor

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("visitor_arg", visitor.argument_type)
            ]
            def __init__(self, **kwargs) -> None:
                # process input args
                _kwargs = {}
                for input_key in function.input_args.keys():
                    if input_key == "accum":
                        continue
                    if function.input_args[input_key][0] == "scalar":
                        # _kwargs[input_key] = kwargs[input_key]
                        continue
                    # tensor input
                    else:
                        setattr(self, "buffer_tensor_" + input_key, NumpyFrontend.argument(kwargs[input_key], False))
                        setattr(self, input_key + "_ptr", int(getattr(self, "buffer_tensor_" + input_key).ptr))
                        _kwargs[input_key + "_ptr"] = getattr(self, input_key + "_ptr")
                # process the return args
                for ret in function.returns:
                    setattr(self, "buffer_tensor_" + ret, NumpyFrontend.argument(kwargs[ret], True))
                    setattr(self, ret + "_ptr", int(getattr(self, "buffer_tensor_" + ret).ptr))
                    _kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr")
                    setattr(self, "host_tensor_" + ret, kwargs[ret])

                _kwargs.update(kwargs)
                function.get_arguments(tree, tree.root, _kwargs)
                self.visitor_arg = tree.get_node(tree.root).data.argument

            def sync(self, stream_sync=True):
                if stream_sync:
                    err, = cudart.cudaDeviceSynchronize()
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

                for ret in function.returns:
                    err, = cuda.cuMemcpyDtoH(
                        getattr(self, "host_tensor_" + ret), cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
                        getattr(self, "host_tensor_" + ret).size * getattr(self, "host_tensor_" + ret).itemsize
                    )
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))
                pass

        self.epilogue_type = _Argument

    def emit(self, operation):
        values = {
            'visitor': self.visitor.emit(operation),
            'operation_name': operation.procedural_name(),
            'visitor_name': self.visitor.instance_name
        }
        return SubstituteTemplate(self.KernelTemplate, values)
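Pieced together from the code above, the intended flow is: subclass, define __call__, call initialize(), then use epilogue_type per launch and emit() at compile time. A sketch (illustrative parameters; not runnable without a configured pycutlass environment):

# Hypothetical subclass; EpilogueAST parses this __call__ via inspect.
class MyEpilogue(EpilogueVisitTree):
    def __call__(self, accum, c: 'tensor', alpha: 'scalar'):
        d = alpha * accum + c
        return d

visit_tree = MyEpilogue(
    elementwise_functor=None,           # marked "TODO: deprecate" above
    tile_description=tile_description,  # assumed from surrounding setup
    element_accumulator=cutlass.float32, elements_per_access=8,
    element_compute=cutlass.float32, element_output=cutlass.float16)
visit_tree.initialize()  # parse, run the three passes, build the visitor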
@ -58,6 +58,13 @@ class ReductionArguments:
                 destination: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]',
                 source: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', **kwargs) -> None:

        # tensor_C can be interpreted as the bias with bias=True in keyword args
        if "bias" in kwargs.keys():
            self.bias = kwargs["bias"]
        else:
            # by default, tensor_C is not bias
            self.bias = False

        self.operation = operation
        #: pointer to the workspace
        self.ptr_workspace = workspace
@ -89,11 +96,9 @@ class ReductionArguments:
            problem_size[1] * DataTypeSize[operation.C.element] // 8

        if "output_op" in kwargs.keys():
            self.alpha = kwargs["output_op"].alpha
            self.beta = kwargs["output_op"].beta
            self.output_op = kwargs['output_op']
        else:
            self.alpha = 1.0
            self.beta = 0.0
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        # get arguments
        self.get_arguments()
@ -109,31 +114,25 @@ class ReductionArguments:
        ref_workspace = ReductionArguments.get_tensor_ref(
            extent=[self.problem_size.row, self.problem_size.column],
            device_ptr=self.ptr_workspace, layout=cutlass.RowMajor)

        ref_source = ReductionArguments.get_tensor_ref(
            extent=[self.problem_size.row, self.problem_size.column],
            device_ptr=self.ptr_source, layout=cutlass.RowMajor)
        if self.bias:
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[0, 0],
                device_ptr=self.ptr_source, layout=cutlass.RowMajor)
        else:
            ref_source = ReductionArguments.get_tensor_ref(
                extent=[self.problem_size.row, self.problem_size.column],
                device_ptr=self.ptr_source, layout=cutlass.RowMajor)

        ref_destination = ReductionArguments.get_tensor_ref(
            extent=[self.problem_size.row, self.problem_size.column],
            device_ptr=self.ptr_destination, layout=cutlass.RowMajor)

        argument_type, epilogue_type = get_reduction_params(
            self.operation.element_compute)

        if self.operation.element_compute == cutlass.float16:
            self.alpha = cutlass.float16(self.alpha).storage
            self.beta = cutlass.float16(self.beta).storage
        elif self.operation.element_compute == cutlass.int32:
            self.alpha = int(self.alpha)
            self.beta = int(self.beta)

        output_op = epilogue_type(self.alpha, self.beta, 0, 0)
        self.c_arguments = argument_type(
        self.c_arguments = self.operation.argument_type(
            self.problem_size, self.partitions,
            self.partition_stride, ref_workspace,
            ref_destination, ref_source,
            output_op
            self.output_op
        )

        params_ = self.operation.rt_module.get_args(
@ -210,8 +209,8 @@ extern "C" {
        self.emitter = EmitReductionInstance('_type')

        self.elements_per_access = self.operation.count
        self.argtype = [ctypes.POINTER(
            get_reduction_params(operation.element_compute)[0])]
        self.argument_type, self.epilogue_type = get_reduction_params(operation.epilogue_functor)
        self.argtype = [ctypes.POINTER(self.argument_type)]

    def emit(self):
        return self.emitter.emit(self.operation)
@ -247,14 +246,14 @@ class ReductionOperation:

    def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription,
                 element_accumulator, element_workspace=None,
                 element_compute=None, epilogue_functor: EpilogueFunctor = EpilogueFunctor.LinearCombination,
                 element_compute=None, epilogue_functor=None,
                 count: int = 1, partitions_per_stage: int = 4) -> None:
        """ Constructor
        """

        self.shape = shape
        #: epilogue functor (default: LinearCombination)
        self.epilogue_functor: EpilogueFunctor = epilogue_functor
        self.epilogue_functor = epilogue_functor
        #: datatype of accumulator
        self.element_accumulator = element_accumulator

@ -285,6 +284,8 @@ class ReductionOperation:
        self.partitions_per_stage: int = partitions_per_stage

        self.rt_module: ReductionRT = ReductionRT(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    #
    def extended_name(self):
@ -363,12 +364,7 @@ class EmitReductionInstance:
using ${operation_name}_base =
  typename cutlass::reduction::kernel::ReduceSplitK<
    cutlass::MatrixShape<${shape_row}, ${shape_column}>,
    ${epilogue_functor}<
      ${element_output},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_compute}
    >,
    ${epilogue_functor},
    cutlass::reduction::thread::ReduceAdd<
      ${element_accumulator},
      ${element_output},
@ -389,7 +385,7 @@ struct ${operation_name}${operation_suffix}:
            'operation_suffix': self.operation_suffix,
            'shape_row': str(operation.shape.row()),
            'shape_column': str(operation.shape.column()),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'epilogue_functor': operation.epilogue_functor.emit(),
            'element_output': DataTypeTag[operation.element_output],
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_accumulator': DataTypeTag[operation.element_accumulator],
@ -68,4 +68,3 @@ class TensorRef:
        # the dtype(0) is used to overload between different data types
        # with the same layout
        self.tensor_ref = cutlass.get_tensor_ref(int(ptr), dtype(0), layout)

@ -124,7 +124,7 @@ class Conv2dLauncher:
        self.reduction_operation = ReductionOperation(
            shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
            C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
            element_compute=operation.element_epilogue,
            element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
            count=operation.C.alignment
        )

@ -183,7 +183,7 @@ class Conv2dLauncher:
        # Get the host reference function
        #

        self.element_compute = operation.element_epilogue
        self.element_compute = operation.epilogue_functor.element_epilogue

        self.host_conv2d = cutlass.test.conv.host.conv2d

@ -441,7 +441,7 @@ class Conv2dLauncher:
        arguments = Conv2dArguments(
            operation=self.operation, problem_size=problem_size, A=tensor_A,
            B=tensor_B, C=tensor_C, D=tensor_D,
            output_op = LinearCombinationFunctorArguments(alpha, beta),
            output_op = self.operation.epilogue_type(alpha, beta),
            split_k_slices=problem_size.split_k_slices,
            split_k_mode=split_k_mode
        )
@ -454,7 +454,7 @@ class Conv2dLauncher:
            workspace=arguments.ptr_D,
            destination=tensor_D,
            source=tensor_C,
            output_op = LinearCombinationFunctorArguments(alpha, beta)
            output_op = self.reduction_operation.epilogue_type(alpha, beta)
        )

        self.operation.run(arguments)
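The recurring edit across these launchers replaces the old LinearCombinationFunctorArguments(alpha, beta) with an argument object obtained from the operation itself, so the argument layout always matches the epilogue the operation was built with. Condensed pattern, as used above:

# New pattern in this commit: epilogue arguments come from the operation.
output_op = self.operation.epilogue_type(alpha, beta)
# ... and likewise for the split-K reduction:
reduction_output_op = self.reduction_operation.epilogue_type(alpha, beta)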
@ -68,7 +68,7 @@ class TestbedGrouped:
        self.scope_min = -8

        #: compute type
        self.compute_type = operation.element_epilogue
        self.compute_type = operation.epilogue_functor.element_epilogue

        self.accumulator_type = operation.tile_description.math_instruction.element_accumulator

@ -176,7 +176,7 @@ class TestbedGrouped:
        arguments = GemmGroupedArguments(
            operation=self.operation, problem_sizes=problem_sizes,
            A=tensor_As, B=tensor_Bs, C=tensor_Cs, D=tensor_Ds,
            output_op=LinearCombinationFunctorArguments(alpha, beta)
            output_op=self.operation.epilogue_type(alpha, beta)
        )

        self.operation.run(arguments)
@ -143,7 +143,7 @@ class GemmUniversalLauncher:
        self.reduction_operation: ReductionOperation = ReductionOperation(
            shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment),
            C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator,
            element_compute=operation.element_epilogue,
            element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor,
            count=operation.C.alignment
        )

@ -200,7 +200,7 @@ class GemmUniversalLauncher:
        self.interleaved = interleaved

        #: compute type
        self.compute_type = operation.element_epilogue
        self.compute_type = operation.epilogue_functor.element_epilogue
        self.accumulator_type = operation.tile_description.math_instruction.element_accumulator

    def print_problem_size(self, p, mode, batch_count):
@ -391,7 +391,7 @@ class GemmUniversalLauncher:
        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
            output_op=LinearCombinationFunctorArguments(alpha, beta),
            output_op=self.operation.epilogue_type(alpha, beta),
            gemm_mode=mode, split_k_slices=batch_count
        )
@ -403,7 +403,7 @@ class GemmUniversalLauncher:
            workspace=arguments.ptr_D,
            destination=tensor_D,
            source=tensor_C,
            output_op=LinearCombinationFunctorArguments(alpha, beta)
            output_op=self.reduction_operation.epilogue_type(alpha, beta)
        )

        self.operation.run(arguments)
@ -34,6 +34,7 @@ import numpy as np
import cutlass
from pycutlass.library import TensorDescription
from typing import Union
from bfloat16 import bfloat16
try:
    import torch
    torch_available = True
@ -46,7 +47,7 @@ class ReferenceModule:
        self.layout_B = B.layout
        self.layout_C = C.layout

    def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0):
    def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0, bias=False, batch=1):
        """
        Compute the reference result on CPU
        Args:
@ -57,27 +58,38 @@ class ReferenceModule:
        M, N, K = problem_size.m(), problem_size.n(), problem_size.k()
        if isinstance(A, np.ndarray):
            if self.layout_A == cutlass.RowMajor:
                A_row = np.reshape(A, newshape=(M, K))
                A_row = np.reshape(A, newshape=(batch, M, K))
            else:
                A_col = np.reshape(A, newshape=(K, M))
                A_row = np.transpose(A_col, axes=(1, 0))
                A_col = np.reshape(A, newshape=(batch, K, M))
                A_row = np.transpose(A_col, axes=(0, 2, 1))

            if self.layout_B == cutlass.RowMajor:
                B_row = np.reshape(B, newshape=(K, N))
                B_row = np.reshape(B, newshape=(batch, K, N))
            else:
                B_col = np.reshape(B, newshape=(N, K))
                B_row = np.transpose(B_col, axes=(1, 0))
                B_col = np.reshape(B, newshape=(batch, N, K))
                B_row = np.transpose(B_col, axes=(0, 2, 1))

            if self.layout_C == cutlass.RowMajor:
                C_row = np.reshape(C, newshape=(M, N))
                if bias:
                    C_row = np.reshape(C, newshape=(batch, 1, N))
                else:
                    C_row = np.reshape(C, newshape=(batch, M, N))
            else:
                C_col = np.reshape(C, newshape=(N, M))
                C_row = np.transpose(C_col, axes=(1, 0))
                if bias:
                    C_row = np.reshape(C, newshape=(batch, M, 1))
                else:
                    C_col = np.reshape(C, newshape=(batch, N, M))
                    C_row = np.transpose(C_col, axes=(0, 2, 1))

            out_row = np.matmul(A_row, B_row) * alpha + C_row * beta
            if A_row.dtype == bfloat16:
                # numpy's einsum doesn't support bfloat16
                out_row = np.einsum("bik,bkj->bij", A_row.astype(np.float32), B_row.astype(np.float32)) * alpha + C_row * beta
                out_row = out_row.astype(C_row.dtype)
            else:
                out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta

            if self.layout_C == cutlass.ColumnMajor:
                out = np.transpose(out_row, axes=(1, 0))
                out = np.transpose(out_row, axes=(0, 2, 1))
            else:
                out = out_row
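The reference path above now computes a batched GEMM via np.einsum("bik,bkj->bij", ...); a standalone sketch of the same reduction (pure NumPy, illustrative shapes):

import numpy as np

# Batched reference GEMM in the style used above: b batches of (M,K)x(K,N).
batch, M, N, K = 2, 4, 3, 5
A_row = np.random.rand(batch, M, K).astype(np.float32)
B_row = np.random.rand(batch, K, N).astype(np.float32)
C_row = np.random.rand(batch, M, N).astype(np.float32)
alpha, beta = 1.5, 0.5

out = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta
# Same result as np.matmul, which also broadcasts over the batch axis:
assert np.allclose(out, np.matmul(A_row, B_row) * alpha + C_row * beta)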
@ -128,7 +140,7 @@ if torch_available:
        def run(self,
                A: Union[np.ndarray, torch.Tensor],
                B: Union[np.ndarray, torch.Tensor],
                C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0) -> np.ndarray:
                C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0, bias=False) -> np.ndarray:
            """
            Compute the reference result on CPU
            """
@ -184,7 +196,10 @@ if torch_available:
                B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

                if self.layout_C == cutlass.TensorNHWC:
                    C_nhwc = C.view((k, r, s, c))
                    if bias:
                        C_nhwc = C.view((1, 1, 1, c))
                    else:
                        C_nhwc = C.view((k, r, s, c))
                    C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
            elif self.kind == cutlass.conv.Operator.dgrad:
                if self.layout_A == cutlass.TensorNHWC:
@ -196,7 +211,10 @@ if torch_available:
                B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

                if self.layout_C == cutlass.TensorNHWC:
                    C_nhwc = C.view((n, h, w, c))
                    if bias:
                        C_nhwc = C.view((1, 1, 1, c))
                    else:
                        C_nhwc = C.view((n, h, w, c))
                    C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))
            else:
                if self.layout_A == cutlass.TensorNHWC:
@ -208,7 +226,10 @@ if torch_available:
                B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2))

                if self.layout_C == cutlass.TensorNHWC:
                    C_nhwc = C.view((n, p, q, k))
                    if bias:
                        C_nhwc = C.view((1, 1, 1, k))
                    else:
                        C_nhwc = C.view((n, p, q, k))
                    C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2))

            if self.kind == cutlass.conv.Operator.fprop:
@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float16)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float16)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -106,15 +112,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float16)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -156,15 +165,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float16)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
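The same three-part migration repeats in every test below and above: TileDescription drops min_compute/max_compute, a LinearCombination object is constructed up front, and Conv2dOperation takes that object instead of the EpilogueFunctor enum plus element_epilogue. The post-change shape, condensed from the hunks:

# After this commit, each test constructs its operation like this:
tile_description = TileDescription(
    threadblock_shape=[128, 128, 64], stages=3,
    warp_count=[2, 2, 1],
    math_instruction=math_inst)

epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, cutlass.float16)

operation = Conv2dOperation(
    conv_kind=cutlass.conv.Operator.dgrad,
    iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
    arch=80, tile_description=tile_description, A=A, B=B, C=C,
    stride_support=StrideSupport.Unity,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1)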
@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 32], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 32], stages=4,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -105,15 +111,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -143,15 +152,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=4,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 8], stages=4,
            warp_count=[4, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 8], stages=4,
            warp_count=[2, 4, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 16], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 16], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Unity,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -97,15 +97,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -135,15 +138,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 32], stages=2,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -79,15 +79,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -117,15 +120,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -155,15 +161,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test
        tile_description = TileDescription(
            threadblock_shape=[128, 128, 64], stages=3,
            warp_count=[2, 2, 1],
            math_instruction=math_inst,
            min_compute=80, max_compute=80
            math_instruction=math_inst
        )

        epilogue_functor = LinearCombination(
            C.element, C.alignment,
            math_inst.element_accumulator, cutlass.float32)

        operation = Conv2dOperation(
            conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels,
            arch=80, tile_description=tile_description, A=A, B=B, C=C,
            element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
            epilogue_functor=EpilogueFunctor.LinearCombination,
            stride_support=StrideSupport.Strided,
            epilogue_functor=epilogue_functor,
            swizzling_functor=cutlass.IdentitySwizzle1
        )
@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -173,15 +182,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -241,15 +253,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float16)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -30,15 +30,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle2
|
||||
)
|
||||
|
||||
@ -68,15 +71,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64], stages=3,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -155,15 +164,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -193,15 +205,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.StridedDgradIdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment, math_inst.element_accumulator,
|
||||
cutlass.float16
|
||||
)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +71,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment, math_inst.element_accumulator,
|
||||
cutlass.float16
|
||||
)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -105,15 +111,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 256, 32], stages=3,
|
||||
warp_count=[1, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -143,15 +152,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -193,15 +205,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -30,15 +30,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -68,15 +71,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase)
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 8], stages=4,
|
||||
warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 16], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32], stages=3,
|
||||
warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst,
|
||||
min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=80, tile_description=tile_description, A=A, B=B, C=C,
|
||||
element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
)
|
||||
|
||||
|
||||
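Every conv2d hunk above applies the same 2.10 interface migration: TileDescription no longer accepts min_compute/max_compute, the element_epilogue argument disappears from Conv2dOperation, and epilogue_functor changes from the EpilogueFunctor.LinearCombination enum value to a constructed LinearCombination object. Below is a minimal sketch of the new-style construction, assuming the pycutlass imports these unit tests use; the MathInstruction arguments and the cutlass.TensorNHWC layout are assumptions modeled on the fp16 tensor-op tests and the CLI invocations that follow, not taken verbatim from this diff.

import cutlass            # assumed: the compiled cutlass bindings used by pycutlass
from pycutlass import *   # assumed: TileDescription, TensorDescription, Conv2dOperation, ...

# Assumed MathInstruction arguments, modeled on the fp16 tensor-op tests above.
math_inst = MathInstruction(
    [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float16,
    cutlass.OpClass.TensorOp, MathOperation.multiply_add
)

tile_description = TileDescription(
    threadblock_shape=[128, 128, 64], stages=3,
    warp_count=[2, 2, 1],
    math_instruction=math_inst        # min_compute/max_compute are gone in 2.10
)

# NHWC fp16 operands with 8-element alignment (illustrative choices).
A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, 8)
B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, 8)
C = TensorDescription(cutlass.float16, cutlass.TensorNHWC, 8)

# The epilogue is now a constructed object rather than an EpilogueFunctor enum.
epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, cutlass.float16)

operation = Conv2dOperation(
    conv_kind=cutlass.conv.Operator.fprop,
    iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
    arch=80, tile_description=tile_description, A=A, B=B, C=C,
    stride_support=StrideSupport.Strided,   # element_epilogue argument removed
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1
)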
@ -0,0 +1,80 @@
pushd $CUTLASS_PATH/examples/40_cutlass_py/

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3

python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5

python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1

python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device

python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host

python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device

python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0

python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh

python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid

python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu

python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish

python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3

python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3

popd
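Read left to right, each invocation above encodes one kernel configuration. Summarizing the flag scheme from the commands themselves (the authoritative definitions live in the example 40 argument parsers): -i is the instruction shape, -b the threadblock shape, -s the stage count, -w the warp count, and -cc the compute capability; -ta/-tb/-tc/-tacc give the element types of A, B, C, and the accumulator, -la/-lb/-lc their layouts, and -aa/-ab/-ac their alignments; -te and -ep (plus -epv for the visitor variants) select the epilogue, -sw the threadblock swizzle, -p the problem size, -gm the GEMM mode, and -k the number of split-K slices. The conv2d.py runs add -co (conv kind), -ia (iterator algorithm), -st (stride support), -sm (split-K mode), and -nhwc/-krsc/-pad/-stride/-dilation for the convolution problem, while -bias, -activ, and -batch exercise the fused-epilogue and batched paths.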
@ -49,7 +49,7 @@ class Test_Frontend(unittest.TestCase):
tile_description = TileDescription(
    [128, 128, 8], 4, [2, 4, 1],
    math_inst, 80, 80
    math_inst
)

A = TensorDescription(

@ -64,10 +64,14 @@ class Test_Frontend(unittest.TestCase):
    cutlass.float32, cutlass.RowMajor, 1
)

epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, cutlass.float32)

self.operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=cutlass.float32,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor,
    swizzling_functor=cutlass.IdentitySwizzle1
)

@ -89,7 +93,7 @@ class Test_Frontend(unittest.TestCase):
arguments = GemmArguments(
    operation=self.operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=LinearCombinationFunctorArguments(alpha, beta),
    output_op=self.operation.epilogue_type(alpha, beta),
    gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
)

@ -119,7 +123,7 @@ class Test_Frontend(unittest.TestCase):
arguments = GemmArguments(
    operation=self.operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=LinearCombinationFunctorArguments(alpha, beta),
    output_op=self.operation.epilogue_type(alpha, beta),
    gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
)
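The two GemmArguments hunks above are the caller-side half of the epilogue change: the free-standing LinearCombinationFunctorArguments(alpha, beta) helper is replaced by self.operation.epilogue_type(alpha, beta), so alpha and beta are packed to match whichever epilogue functor the operation was constructed with. A minimal sketch, assuming operation, the device tensors, and problem_size are set up as in the surrounding test:

alpha, beta = 1.0, 0.0

arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    # epilogue_type is now a property of the operation itself, so the packed
    # arguments always agree with the operation's own epilogue functor:
    output_op=operation.epilogue_type(alpha, beta),
    gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1
)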
@ -17,7 +17,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
tile_description = TileDescription(
    threadblock_shape=[64, 128, 64],
    stages=4, warp_count=[2, 2, 1],
    math_instruction=math_inst, min_compute=80, max_compute=80
    math_instruction=math_inst
)

A = TensorDescription(

@ -33,15 +33,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
    alignment=4
)

element_epilogue = cutlass.float32

epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, cutlass.float32)

swizzling_functor = cutlass.IdentitySwizzle1

operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=element_epilogue,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

@ -58,7 +58,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
tile_description = TileDescription(
    threadblock_shape=[64, 128, 32],
    stages=6, warp_count=[2, 2, 1],
    math_instruction=math_inst, min_compute=80, max_compute=80
    math_instruction=math_inst
)

A = TensorDescription(

@ -74,15 +74,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase):
    alignment=8
)

element_epilogue = cutlass.float32

epilogue_functor = EpilogueFunctor.LinearCombination
epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, cutlass.float32)

swizzling_functor = cutlass.IdentitySwizzle1

operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=element_epilogue,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
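The remaining GEMM hunks repeat the operation-side pattern already seen above: element_epilogue is dropped from the GemmOperationUniversal constructor, with the epilogue element instead baked into the LinearCombination object. A condensed sketch of the post-change construction, assuming math_inst, tile_description, and the A/B/C TensorDescriptions from the surrounding test:

element_epilogue = cutlass.float32

# The epilogue element type now lives here, not on the operation:
epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, element_epilogue)

swizzling_functor = cutlass.IdentitySwizzle1

operation = GemmOperationUniversal(
    arch=80, tile_description=tile_description,
    A=A, B=B, C=C,   # element_epilogue keyword removed in 2.10
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)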
@ -18,7 +18,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 32],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -36,13 +36,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.BatchedIdentitySwizzle
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
|
||||
direct_store=True
|
||||
)
|
||||
@ -60,7 +62,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 128, 64],
|
||||
stages=3, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -78,13 +80,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -101,7 +105,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64],
|
||||
stages=3, warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -119,13 +123,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -142,7 +148,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[256, 128, 64],
|
||||
stages=3, warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -160,13 +166,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -183,7 +191,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 64, 64],
|
||||
stages=3, warp_count=[2, 1, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -201,13 +209,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float16
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -224,7 +234,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[64, 64, 32],
|
||||
stages=10, warp_count=[2, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -242,13 +252,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float16
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -265,7 +277,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[256, 128, 64],
|
||||
stages=3, warp_count=[4, 2, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -283,13 +295,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -306,7 +320,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 64, 64],
|
||||
stages=3, warp_count=[2, 1, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -324,13 +338,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -347,7 +363,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64],
|
||||
stages=3, warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -365,13 +381,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
@ -388,7 +406,7 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
tile_description = TileDescription(
|
||||
threadblock_shape=[128, 256, 64],
|
||||
stages=3, warp_count=[2, 4, 1],
|
||||
math_instruction=math_inst, min_compute=80, max_compute=80
|
||||
math_instruction=math_inst
|
||||
)
|
||||
|
||||
A = TensorDescription(
|
||||
@ -406,13 +424,15 @@ class GemmF16Sm80(unittest.TestCase):
|
||||
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination
|
||||
epilogue_functor = LinearCombination(
|
||||
C.element, C.alignment,
|
||||
math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = cutlass.IdentitySwizzle1
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
A=A, B=B, C=C, element_epilogue=element_epilogue,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 32],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -37,13 +37,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):

         element_epilogue = cutlass.float32

-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -61,7 +63,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 32],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -79,13 +81,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):

         element_epilogue = cutlass.float32

-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -102,7 +106,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[64, 64, 32],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -120,13 +124,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):

         element_epilogue = cutlass.float32

-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )
@@ -17,7 +17,7 @@ class GemmF64TensorOpSm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[32, 32, 16],
             stages=4, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         # alignment 1 restricted for double
@@ -36,13 +36,15 @@ class GemmF64TensorOpSm80(unittest.TestCase):

         element_epilogue = cutlass.float64

-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -59,7 +61,7 @@ class GemmF64TensorOpSm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[64, 64, 16],
             stages=4, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         # alignment 1 restricted for double
@@ -78,13 +80,15 @@ class GemmF64TensorOpSm80(unittest.TestCase):

         element_epilogue = cutlass.float64

-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )
@@ -18,7 +18,7 @@ class GemmGroupedSm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 32],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -37,14 +37,15 @@ class GemmGroupedSm80(unittest.TestCase):
         )

         element_epilogue = cutlass.float32
-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)
         swizzling_functor = cutlass.BatchedIdentitySwizzle

         for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
             operation = GemmOperationGrouped(
-                tile_description.minimum_compute_capability,
+                80,
                 tile_description, A, B, C,
-                element_epilogue,
                 epilogue_functor, swizzling_functor,
                 precompute_mode=precompute_mode
             )
@@ -64,7 +65,7 @@ class GemmGroupedSm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[64, 64, 16],
             stages=4, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -83,14 +84,15 @@ class GemmGroupedSm80(unittest.TestCase):
         )

         element_epilogue = cutlass.float64
-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)
         swizzling_functor = cutlass.BatchedIdentitySwizzle

         for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
             operation = GemmOperationGrouped(
-                tile_description.minimum_compute_capability,
+                80,
                 tile_description, A, B, C,
-                element_epilogue,
                 epilogue_functor, swizzling_functor,
                 precompute_mode=precompute_mode
             )
@@ -110,7 +112,7 @@ class GemmGroupedSm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 64, 8],
             stages=4, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -129,14 +131,15 @@ class GemmGroupedSm80(unittest.TestCase):
         )

         element_epilogue = cutlass.float32
-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)
         swizzling_functor = cutlass.BatchedIdentitySwizzle

         for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
             operation = GemmOperationGrouped(
-                tile_description.minimum_compute_capability,
+                80,
                 tile_description, A, B, C,
-                element_epilogue,
                 epilogue_functor, swizzling_functor,
                 precompute_mode=precompute_mode
             )
@@ -156,7 +159,7 @@ class GemmGroupedSm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 32],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -175,14 +178,15 @@ class GemmGroupedSm80(unittest.TestCase):
         )

         element_epilogue = cutlass.float32
-        epilogue_functor = EpilogueFunctor.LinearCombination
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)
         swizzling_functor = cutlass.BatchedIdentitySwizzle

         for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
             operation = GemmOperationGrouped(
-                tile_description.minimum_compute_capability,
+                80,
                 tile_description, A, B, C,
-                element_epilogue,
                 epilogue_functor, swizzling_functor,
                 precompute_mode=precompute_mode
             )
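The grouped-GEMM hunks make the same migration, and additionally change `GemmOperationGrouped` to take the target architecture directly (`80` here) instead of `tile_description.minimum_compute_capability`, while still exercising each configuration under both scheduler precompute modes. A condensed sketch of the new call shape, assuming `tile_description`, `A`, `B`, `C`, and `math_inst` are set up as in the tests above:

```python
# Condensed sketch of grouped-GEMM construction after this change;
# tile_description, A, B, C, and math_inst are assumed from the tests above.
epilogue_functor = LinearCombination(
    C.element, C.alignment,
    math_inst.element_accumulator, cutlass.float32)

for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]:
    operation = GemmOperationGrouped(
        80,                              # arch, passed directly now
        tile_description, A, B, C,
        epilogue_functor,                # element_epilogue argument removed
        cutlass.BatchedIdentitySwizzle,  # batched swizzle for grouped GEMM
        precompute_mode=precompute_mode)
```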
@@ -1,5 +1,6 @@
 import pycutlass
 from pycutlass import *
+from pycutlass.epilogue import LinearCombinationClamp
 from pycutlass.test import *
 import unittest

@@ -17,7 +18,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[64, 64, 64],
             stages=6, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -33,15 +34,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
             alignment=8
         )

-        element_epilogue = cutlass.float32
-
-        epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
+        epilogue_functor = FastLinearCombinationClamp(
+            C.element, C.alignment
+        )

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -58,7 +59,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 128],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -74,15 +75,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
             alignment=16
         )

-        element_epilogue = cutlass.float32
-
-        epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
+        epilogue_functor = FastLinearCombinationClamp(
+            C.element, C.alignment
+        )

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -99,7 +100,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 128],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -115,15 +116,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
             alignment=16
         )

-        element_epilogue = cutlass.float32
-
-        epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp
+        epilogue_functor = FastLinearCombinationClamp(
+            C.element, C.alignment
+        )

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -140,7 +141,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 128],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -158,13 +159,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):

         element_epilogue = cutlass.int32

-        epilogue_functor = EpilogueFunctor.LinearCombinationClamp
+        epilogue_functor = LinearCombinationClamp(
+            C.element, C.alignment, math_inst.element_accumulator,
+            element_epilogue
+        )

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )

@@ -181,7 +185,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):
         tile_description = TileDescription(
             threadblock_shape=[128, 128, 128],
             stages=3, warp_count=[2, 2, 1],
-            math_instruction=math_inst, min_compute=80, max_compute=80
+            math_instruction=math_inst
         )

         A = TensorDescription(
@@ -199,13 +203,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase):

         element_epilogue = cutlass.int32

-        epilogue_functor = EpilogueFunctor.LinearCombinationClamp
+        epilogue_functor = LinearCombinationClamp(
+            C.element, C.alignment, math_inst.element_accumulator,
+            element_epilogue
+        )

         swizzling_functor = cutlass.IdentitySwizzle1

         operation = GemmOperationUniversal(
             arch=80, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )
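The int8 tests show both clamped-epilogue forms under the new object-based API: `FastLinearCombinationClamp` is constructed from just the output element and alignment, while `LinearCombinationClamp` keeps the accumulator and epilogue compute types explicit (and is now imported from `pycutlass.epilogue`). A side-by-side sketch, assuming `C` and `math_inst` from the surrounding tests:

```python
from pycutlass.epilogue import LinearCombinationClamp

# C and math_inst are assumed from the surrounding int8 tests.
# Fast clamp: only the output element and alignment are configurable.
epilogue_fast = FastLinearCombinationClamp(C.element, C.alignment)

# Full clamp: accumulator and epilogue compute types stay explicit.
epilogue_full = LinearCombinationClamp(
    C.element, C.alignment, math_inst.element_accumulator, cutlass.int32)
```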
@@ -348,3 +348,16 @@ conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1
 conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 164942943 4259285988 984016853 888753301
 conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2823094147 1681845497 4242738907 3244428635
 conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 s8nhwc_s8nhwc_inhwc_i_i 4060010502 2881035321 3927119619 3311661122
+conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 1030377090 3211227145
+conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1479940693 2379046159 2482639965
+conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 1871463331 2718290800 1797658305
+conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 664160900 3954982568 985899371
+conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1924855848 1728786974 3821277575
+conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 1764715518 3998637379 2782670608
+conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 666906244 2107859856 831363691
+conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1575210381 2486552517 3268706408
+conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 2316839359 1729888024 2308314800
+conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 544154978
+conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 3191247524
+conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 554790212 956712535 1281779197
+conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3184127693 835105643 4011933753 3207244654
@@ -42,8 +42,7 @@ import unittest
 #

 def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False,
-                     epilogue_functor = EpilogueFunctor.LinearCombination,
-                     swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
+                     epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
     """
     Test GEMM Operation based on configuration
     """
@@ -68,7 +67,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe

     tile_description = TileDescription(
         tiling[0], tiling[1], tiling[2],
-        math_inst, arch, arch
+        math_inst
     )

     A = TensorDescription(
@@ -84,11 +83,15 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
     )

     element_epilogue = data_type[3]
+    if epilogue_functor is None:
+        epilogue_functor = LinearCombination(
+            C.element, C.alignment,
+            math_inst.element_accumulator, element_epilogue)

     if gemm_kind == GemmKind.Universal:
         operation = GemmOperationUniversal(
             arch=arch, tile_description=tile_description,
-            A=A, B=B, C=C, element_epilogue=element_epilogue,
+            A=A, B=B, C=C,
             epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
         )
         if A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]:
@@ -99,7 +102,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe
     elif gemm_kind == GemmKind.Grouped:
         operation = GemmOperationGrouped(
             arch, tile_description, A, B, C,
-            element_epilogue, epilogue_functor, swizzling_functor,
+            epilogue_functor, swizzling_functor,
             precompute_mode=kwargs["precompute_mode"]
         )
         testbed = TestbedGrouped(operation=operation)
@@ -110,7 +113,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe

 def TestConv2dOperator(math_inst, alignment, tiling, arch,
                        stride_supports=[StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided],
-                       epilogue_functor=EpilogueFunctor.LinearCombination,
+                       epilogue_functor=None,
                        swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs):
     """
     Test Conv2d Operation based on configurations
@@ -167,20 +170,24 @@ def TestConv2dOperator(math_inst, alignment, tiling, arch,
         tile_description = TileDescription(
             threadblock_shape=tiling[0], stages=tiling[1],
             warp_count=tiling[2],
-            math_instruction=math_inst,
-            min_compute=arch, max_compute=arch
+            math_instruction=math_inst
         )

         if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided:
             swizzling_functor = cutlass.StridedDgradIdentitySwizzle1
         else:
             swizzling_functor = default_swizzling_functor

+        if epilogue_functor is None:
+            epilogue_functor_ = LinearCombination(
+                C.element, C.alignment,
+                math_inst.element_accumulator, data_type[3])
+
         operation = Conv2dOperation(
             conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
             arch=arch, tile_description=tile_description, A=A, B=B, C=C,
-            element_epilogue=data_type[3], stride_support=stride_support,
-            epilogue_functor=epilogue_functor,
+            stride_support=stride_support,
+            epilogue_functor=epilogue_functor_,
             swizzling_functor=swizzling_functor
         )

@@ -369,7 +376,11 @@ class Test_SM80(unittest.TestCase):
         tiling = ([256, 64, 64], 4, [4, 1, 1])
         data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32]

-        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=EpilogueFunctor.FastLinearCombinationClamp))
+        epilogue_functor = FastLinearCombinationClamp(
+            data_type_mixed[2], alignment_mixed[2]
+        )
+
+        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=epilogue_functor))
         stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
         layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32]
         results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True)
@@ -378,59 +389,59 @@

     def SM80_SparseTensorOp_16832(self):
         pass
-    def test_SM80_PlanarComplexTensorOp_16816(self):
+    def SM80_PlanarComplexTensorOp_16816(self):
         pass
-    def test_SM80_SparseTensorOp_16816_fast_math(self):
+    def SM80_SparseTensorOp_16816_fast_math(self):
         pass
-    def test_SM80_TensorOp_1688_complex(self):
+    def SM80_TensorOp_1688_complex(self):
         pass
-    def test_SM80_TensorOp_1688_fast_fp32_math_complex(self):
+    def SM80_TensorOp_1688_fast_fp32_math_complex(self):
         pass
-    def test_SM80_TensorOp_1688_rank_k(self):
+    def SM80_TensorOp_1688_rank_k(self):
        pass
-    def test_SM80_TensorOp_1688_rank_k_complex(self):
+    def SM80_TensorOp_1688_rank_k_complex(self):
         pass
-    def test_SM80_TensorOp_1688_trmm(self):
+    def SM80_TensorOp_1688_trmm(self):
         pass
-    def test_SM80_TensorOp_1688_trmm_complex(self):
+    def SM80_TensorOp_1688_trmm_complex(self):
         pass
-    def test_SM80_TensorOp_1688_symm(self):
+    def SM80_TensorOp_1688_symm(self):
         pass
-    def test_SM80_TensorOp_1688_symm_complex(self):
+    def SM80_TensorOp_1688_symm_complex(self):
         pass
-    def test_SM80_TensorOp_884_complex(self):
+    def SM80_TensorOp_884_complex(self):
         pass
-    def test_SM80_TensorOp_884_complex_gaussian(self):
+    def SM80_TensorOp_884_complex_gaussian(self):
         pass
-    def test_SM80_TensorOp_884_rank_k(self):
+    def SM80_TensorOp_884_rank_k(self):
         pass
-    def test_SM80_TensorOp_884_rank_k_complex(self):
+    def SM80_TensorOp_884_rank_k_complex(self):
         pass
-    def test_SM80_TensorOp_884_rank_k_complex_gaussian(self):
+    def SM80_TensorOp_884_rank_k_complex_gaussian(self):
         pass
-    def test_SM80_TensorOp_884_trmm(self):
+    def SM80_TensorOp_884_trmm(self):
         pass
-    def test_SM80_TensorOp_884_trmm_complex(self):
+    def SM80_TensorOp_884_trmm_complex(self):
         pass
-    def test_SM80_TensorOp_884_trmm_complex_gaussian(self):
+    def SM80_TensorOp_884_trmm_complex_gaussian(self):
         pass
-    def test_SM80_TensorOp_884_symm(self):
+    def SM80_TensorOp_884_symm(self):
         pass
-    def test_SM80_TensorOp_884_symm_complex(self):
+    def SM80_TensorOp_884_symm_complex(self):
         pass
-    def test_SM80_TensorOp_884_symm_complex_gaussian(self):
+    def SM80_TensorOp_884_symm_complex_gaussian(self):
         pass
-    def test_SM80_SparseTensorOp_16864_TN(self):
+    def SM80_SparseTensorOp_16864_TN(self):
         pass
-    def test_SM80_TensorOp_16864_TN(self):
+    def SM80_TensorOp_16864_TN(self):
         pass
-    def test_SM80_SparseTensorOp_168128_TN(self):
+    def SM80_SparseTensorOp_168128_TN(self):
         pass
-    def test_SM80_TensorOp_16864_Interleaved(self):
+    def SM80_TensorOp_16864_Interleaved(self):
         pass
-    def test_SM80_TensorOp_168256(self):
+    def SM80_TensorOp_168256(self):
         pass
-    def test_SM80_Simt_complex(self):
+    def SM80_Simt_complex(self):
         pass
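With `epilogue_functor` now defaulting to `None`, the helpers above build a `LinearCombination` from `data_type[3]` on demand, so most call sites stay unchanged, and callers that need a clamped epilogue pass a constructed instance, as `Test_SM80` does with `FastLinearCombinationClamp`. A sketch of a default-path call under the new signature, assuming a `math_inst`/`layout`/`alignment`/`tiling` setup like those elsewhere in this file:

```python
# Default path: no epilogue_functor argument, so the helper constructs
# LinearCombination(C.element, C.alignment, accumulator, data_type[3]).
# math_inst, layout, alignment, tiling are assumed from the surrounding tests.
data_type = [cutlass.float16, cutlass.float16, cutlass.float32, cutlass.float32]
self.assertTrue(TestGemmOperator(
    GemmKind.Universal, math_inst, layout, alignment, tiling, 80,
    data_type=data_type))
```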
@@ -1195,13 +1195,13 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
     break;
   case NumericTypeID::kBF16:
   {
-    float tmp = *reinterpret_cast<bfloat16_t *>(bytes.data());;
+    float tmp = *reinterpret_cast<bfloat16_t *>(bytes.data());
     ss << tmp;
   }
     break;
   case NumericTypeID::kTF32:
   {
-    float tmp = *reinterpret_cast<tfloat32_t *>(bytes.data());;
+    float tmp = *reinterpret_cast<tfloat32_t *>(bytes.data());
     ss << tmp;
   }
     break;
@@ -183,7 +183,7 @@ __global__ void GemmPlanarComplex(
       ComplexC d_ij;

       d_ij.real() = convert_op(result.real());
-      d_ij.imag() = convert_op(result.imag());;
+      d_ij.imag() = convert_op(result.imag());

       tensor_d.at(coord) = d_ij;
     }

@@ -172,7 +172,7 @@ void GemmPlanarComplex(
       complex<ScalarType> result = alpha * acc + beta * src;

       d_ij.real() = convert_op(result.real());
-      d_ij.imag() = convert_op(result.imag());;
+      d_ij.imag() = convert_op(result.imag());

       tensor_d.at(coord) = d_ij;
     }