diff --git a/CHANGELOG.md b/CHANGELOG.md index cecffb5e..04af9524 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,19 @@ # NVIDIA CUTLASS Changelog ## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23) -* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) -* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) -* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel -* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) -* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) +* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. +* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too. +* Optimizations for [GEMM+Softmax](examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance. +* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing. +* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues. +* [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue. +* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes: + * kSingleGroup: output channel per group is multiple of Threadblock tile N. + * kMultipleGroup: Threadblock tile N is multiple of output channel per group. +* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number. +* Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels. +* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension. +* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads) * Updates and bugfixes from the community (thanks!) * **Deprecation announcement:** CUTLASS plans to deprecate the following: @@ -47,7 +55,7 @@ * New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h). * [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs. * [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler. -* Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads) +* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads) * Updates and bugfixes from the community (thanks!) diff --git a/README.md b/README.md index e884c735..5bf9f4b3 100644 --- a/README.md +++ b/README.md @@ -39,11 +39,16 @@ supported at each level of the execution model hierarchy. # What's New in CUTLASS 2.10 CUTLASS 2.10 is an update to CUTLASS adding: -- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) -- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) -- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel -- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) -- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) +- [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. +- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. It can move some scheduling into the host side if applicable. +- Optimizations for [GEMM+Softmax](examples/35_gemm_softmax). +- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) is a general MHA that does not require equal sequence length in every GEMM. +- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) can fuse the layernorm into GEMMs before and after. +- [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can permute the GEMM output before storing. +- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. +- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. +- Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels. +- [Back-to-back GEMM](examples/13_two_tensor_op_fusion) enhancements. - Updates and bugfixes from the community (thanks!) - **Deprecation announcement:** CUTLASS plans to deprecate the following: - Maxwell and Pascal GPU architectures diff --git a/examples/24_gemm_grouped/gemm_grouped.cu b/examples/24_gemm_grouped/gemm_grouped.cu index 1000f359..f8e0ede3 100644 --- a/examples/24_gemm_grouped/gemm_grouped.cu +++ b/examples/24_gemm_grouped/gemm_grouped.cu @@ -1528,6 +1528,9 @@ int main(int argc, char const **args) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 4>::GemmKernel; diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index afab8344..ada9dda7 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -187,8 +187,8 @@ struct Options { // elements in input matrices. using ElementAccumulator = float; // <- data type of accumulator using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -using ElementInputA = cutlass::half_t;; // <- data type of elements in input matrix A -using ElementInputB = cutlass::half_t;; // <- data type of elements in input matrix B +using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A +using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D // The code section below describes matrix layout of input and output matrices. @@ -252,7 +252,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal; + + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + const int kStages = 4; const bool kSplitKSerial = false; using Operator = cutlass::arch::OpMultiplyAdd; diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md index 9f9b8cac..dcbcee2a 100644 --- a/examples/40_cutlass_py/README.md +++ b/examples/40_cutlass_py/README.md @@ -92,25 +92,35 @@ Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128 ```python python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 ``` + +### Batched & Array GEMM +Example 1: Batched GEMM +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 +``` +Example 2: Array GEMM +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2 +``` *** ## GEMM Grouped Examples The GEMM Grouped examples use numpy to create input tensors and verify the results. Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule ```python -python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device ``` Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule ```python -python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host ``` Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule ```python -python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device ``` Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule ```python -python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device ``` *** ## Conv2d Example @@ -160,3 +170,61 @@ Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nh ```python python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 ``` + +## Epilogue +### Bias +To replace C with a bias vector, add `-bias` flag. +### Activation function +Example 1: ReLU +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu +``` +Example 2: leaky ReLU +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2 +``` +Example 3: tanh (alpha=0 to avoid saturation) +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh +``` +Example 4: sigmoid +```python +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid +``` +Example 5: SiLU +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu +``` +Example 6: HardSwish +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish +``` +Example 7: GELU +```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu +``` +### Epilogue Visitor Tree +Example 1: +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 3: +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 4: +```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 5: +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 +``` +Example 6: +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3 +``` diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py index 687cfdc4..8be1f4d9 100644 --- a/examples/40_cutlass_py/conv2d.py +++ b/examples/40_cutlass_py/conv2d.py @@ -33,6 +33,7 @@ import pycutlass from pycutlass import * from pycutlass.conv2d_operation import * from pycutlass.utils import reference_model +import torch.nn.functional as F import argparse @@ -127,6 +128,13 @@ parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stri parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") + parser.add_argument('--print_cuda', action="store_true", help="print the underlying CUDA kernel") @@ -138,6 +146,8 @@ except: pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +np.random.seed(0) + element_a = getattr(cutlass, args.element_a) element_b = getattr(cutlass, args.element_b) element_c = getattr(cutlass, args.element_c) @@ -152,7 +162,7 @@ math_inst = MathInstruction( tile_description = TileDescription( args.threadblock_shape, args.stages, args.warp_count, - math_inst, args.compute_capability, args.compute_capability + math_inst ) layout_a = getattr(cutlass, args.layout_a) @@ -172,7 +182,16 @@ C = TensorDescription( ) element_epilogue = getattr(cutlass, args.element_epilogue) -epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor) +if (args.activation_function == "identity" + or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm) swizzling_functor = getattr(cutlass, args.swizzling_functor) stride_support = getattr(StrideSupport, args.stride_support) @@ -181,7 +200,7 @@ conv_kind = getattr(cutlass.conv.Operator, args.conv_kind) operation = Conv2dOperation( conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, arch=args.compute_capability, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support, + A=A, B=B, C=C, stride_support=stride_support, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -191,10 +210,18 @@ if args.print_cuda: operations = [operation,] if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) reduction_operation = ReductionOperation( shape=cutlass.MatrixCoord(4, 32 * C.alignment), C=C, element_accumulator=element_acc, element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, count=C.alignment ) operations.append(reduction_operation) @@ -219,9 +246,18 @@ tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size( tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size( conv_kind, problem_size ) -tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size( - conv_kind, problem_size -) +if args.bias: + tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ).at(3) +else: + tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) if args.element_a != "int8": tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) @@ -238,12 +274,12 @@ if args.element_c != "int8": else: tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) -tensor_D = torch.ones_like(tensor_C) +tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda") arguments = Conv2dArguments( operation=operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op = LinearCombinationFunctorArguments(args.alpha, args.beta), + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode), split_k_slices=problem_size.split_k_slices ) @@ -257,7 +293,8 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1: workspace=arguments.ptr_D, destination=tensor_D, source=tensor_C, - output_op = LinearCombinationFunctorArguments(args.alpha, args.beta) + output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias ) operation.run(arguments) @@ -270,8 +307,12 @@ else: reference_model = Conv2dReferenceModule(A, B, C, conv_kind) -tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta) - -assert torch.equal(tensor_D, tensor_D_ref) +tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias) +if (args.activation_function != "identity"): + tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args)) +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) print("Passed.") diff --git a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py index 8341d10d..b9b7fabc 100644 --- a/examples/40_cutlass_py/gemm.py +++ b/examples/40_cutlass_py/gemm.py @@ -99,9 +99,11 @@ parser.add_argument("-te", "--element_epilogue", default="float32", type=str, parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], help="This option describes the epilogue part of the kernel") +parser.add_argument("-epv", "--epilogue_visitor", default=None, + type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues") # swizzling parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ - "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], help="This option describes how thread blocks are scheduled on GPU") # Argument @@ -113,17 +115,22 @@ parser.add_argument("-alpha", "--alpha", default=1.0, type=float, parser.add_argument("-beta", "--beta", default=0.0, type=float, help="Scaling factor of C") parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, - choices=["Gemm", "GemmSplitKParallel"], + choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"], help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ GemmSplitKParallel is used for parallel splitK") parser.add_argument('-k', '--split_k_slices', default=1, type=int, help="Number of split-k partitions. (default 1)") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") parser.add_argument('--print_cuda', action="store_true", help="print the underlying CUDA kernel") -# parser.add_argument('-h', '--help', action="store_true", -# help="print help information") try: args = parser.parse_args() @@ -131,6 +138,9 @@ except: sys.exit(0) pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +pycutlass.compiler.nvcc() + +np.random.seed(0) element_a = getattr(cutlass, args.element_a) element_b = getattr(cutlass, args.element_b) @@ -146,7 +156,7 @@ math_inst = MathInstruction( tile_description = TileDescription( args.threadblock_shape, args.stages, args.warp_count, - math_inst, args.compute_capability, args.compute_capability + math_inst ) layout_a = getattr(cutlass, args.layout_a) @@ -166,13 +176,83 @@ C = TensorDescription( ) element_epilogue = getattr(cutlass, args.element_epilogue) -epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor) +if (args.activation_function == "identity" + or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + swizzling_functor = getattr(cutlass, args.swizzling_functor) +visitor = args.epilogue_visitor is not None + +if args.epilogue_visitor == "ColumnReduction": + class ColumnReduction_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + D = alpha * accum + beta * c + reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0]) + return D, reduction + epilogue_functor = ColumnReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() +elif args.epilogue_visitor == "RowReduction": + class RowReduction_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + D = alpha * accum + tanh.numpy(beta * c) + reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) + return D, reduction + epilogue_functor = RowReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() + +elif args.epilogue_visitor == "RowBroadcast": + class RowBroadcast_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + vector: 'row', alpha: 'scalar', beta: 'scalar'): + # + T = accum + vector + scale_T = alpha * T + Z = relu.numpy(scale_T + beta * c) + return Z, T + epilogue_functor = RowBroadcast_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() +elif args.epilogue_visitor == "ColumnBroadcast": + class ColumnBroadcast_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + vector: 'column', alpha: 'scalar', beta: 'scalar'): + # + T = accum + vector + scale_T = leaky_relu.numpy(alpha * T, 0.2) + Z = scale_T + beta * c + return Z, T + epilogue_functor = ColumnBroadcast_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() +else: + epilogue_functor = epilogue_functor + operation = GemmOperationUniversal( arch=args.compute_capability, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + visitor=visitor ) if args.print_cuda: @@ -181,10 +261,19 @@ if args.print_cuda: operations = [operation, ] if args.gemm_mode == "GemmSplitKParallel": + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + reduction_operation = ReductionOperation( shape=cutlass.MatrixCoord(4, 32 * C.alignment), C=C, element_accumulator=element_acc, - element_compute=element_epilogue, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, count=C.alignment ) operations.append(reduction_operation) @@ -196,47 +285,102 @@ pycutlass.compiler.add_module(operations) problem_size = cutlass.gemm.GemmCoord( args.problem_size[0], args.problem_size[1], args.problem_size[2]) +tensor_a_size = args.batch * problem_size.m() * problem_size.k() if args.element_a != "int8": if args.element_a == "bfloat16": - tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.k(),))).astype(bfloat16) + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(bfloat16) else: - tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.k(),))).astype(getattr(np, args.element_a)) + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(getattr(np, args.element_a)) else: - tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() - * problem_size.k(),)).astype(getattr(np, args.element_a)) + tensor_A = np.random.uniform( + low=-2, high=2,size=(tensor_a_size,) + ).astype(getattr(np, args.element_a)) +tensor_b_size = args.batch * problem_size.k() * problem_size.n() if args.element_b != "int8": if args.element_b == "bfloat16": - tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() - * problem_size.n(),))).astype(bfloat16) + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(bfloat16) else: - tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() - * problem_size.n(),))).astype(getattr(np, args.element_b)) + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(getattr(np, args.element_b)) else: - tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() - * problem_size.n(),)).astype(getattr(np, args.element_b)) + tensor_B = np.random.uniform( + low=-2, high=2, size=(tensor_b_size,) + ).astype(getattr(np, args.element_b)) if args.element_c != "int8": - if args.element_c == "bfloat16": - tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.n(),))).astype(bfloat16) + if args.bias: + if args.layout_c == "RowMajor": + tensor_c_size = args.batch * problem_size.n() + elif args.layout_c == "ColumnMajor": + tensor_c_size = args.batch * problem_size.m() + else: + raise ValueError(args.layout_c) else: - tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.n(),))).astype(getattr(np, args.element_c)) + tensor_c_size = args.batch * problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(getattr(np, args.element_c)) else: - tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m() - * problem_size.n(),)).astype(getattr(np, args.element_c)) + tensor_C = np.random.uniform( + low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) -tensor_D = np.ones_like(tensor_C) +tensor_D = np.zeros( + shape=(args.batch * problem_size.m() * problem_size.n(),) +).astype(getattr(np, args.element_c)) + +if args.epilogue_visitor == "RowReduction": + cta_n = args.threadblock_shape[1] + num_cta_n = (problem_size.n() + cta_n - 1) // cta_n + reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c)) + output_op = operation.epilogue_type( + D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] + ) +elif args.epilogue_visitor == "ColumnReduction": + cta_m = args.threadblock_shape[0] + num_cta_m = (problem_size.m() + cta_m - 1) // cta_m + reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c)) + output_op = operation.epilogue_type( + D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] + ) +elif args.epilogue_visitor == "RowBroadcast": + vector = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n())) + ).astype(getattr(np, args.element_c)) + tensor_t = np.empty_like(tensor_D) + output_op = operation.epilogue_type( + c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] + ) +elif args.epilogue_visitor == "ColumnBroadcast": + vector = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1)) + ).astype(getattr(np, args.element_c)) + tensor_t = np.empty_like(tensor_D) + output_op = operation.epilogue_type( + c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] + ) +else: + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) arguments = GemmArguments( operation=operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=LinearCombinationFunctorArguments(args.alpha, args.beta), + output_op=output_op, gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode), - split_k_slices=args.split_k_slices + split_k_slices=args.split_k_slices, batch=args.batch ) if args.gemm_mode == "GemmSplitKParallel": @@ -245,7 +389,8 @@ if args.gemm_mode == "GemmSplitKParallel": problem_size=[problem_size.m(), problem_size.n()], partitions=args.split_k_slices, workspace=arguments.ptr_D, destination=tensor_D, source=tensor_C, - output_op=LinearCombinationFunctorArguments(args.alpha, args.beta) + output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias ) operation.run(arguments) @@ -259,8 +404,42 @@ else: # run the host reference module reference = ReferenceModule(A, B, C) tensor_D_ref = reference.run( - tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta) + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) -assert np.array_equal(tensor_D, tensor_D_ref) +if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: + tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten() +tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) +if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]: + output_op.sync() + accum_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) + tensor_D_ref, reduction_ref = epilogue_functor( + accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), + tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), + args.alpha, args.beta + ) + tensor_D_ref = tensor_D_ref.flatten() + reduction_ref = reduction_ref.flatten() + assert np.allclose(reduction_ref, reduction, atol=1e-2) + +elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: + output_op.sync() + accum_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) + + tensor_D_ref, tensor_T_ref = epilogue_functor( + accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), + tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), + vector, args.alpha, args.beta) + + tensor_D_ref = tensor_D_ref.flatten() + tensor_T_ref = tensor_T_ref.flatten() + + assert np.array_equal(tensor_t, tensor_T_ref) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) print("Passed.") diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py index e26ecc97..46ea9fed 100644 --- a/examples/40_cutlass_py/gemm_grouped.py +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -99,7 +99,10 @@ parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", # swizzling parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], - help="This option describes how thread blocks are scheduled on GPU") + help="This option describes how thread blocks are scheduled on GPU. \ + NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \ + This parameter is passed in at present to match the APIs of other kernels. The parameter \ + is unused within the kernel") # precompute mode parser.add_argument("-pm", "--precompute_mode", default="Device", type=str, choices=["Host", "Device"], @@ -109,7 +112,13 @@ parser.add_argument("-p", "--problem_size_dir", type=str, help="path to the csv file contains the problem sizes") parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") parser.add_argument('--print_cuda', action="store_true", help="print the underlying CUDA kernel") @@ -120,6 +129,8 @@ except: pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +np.random.seed(0) + element_a = getattr(cutlass, args.element_a) element_b = getattr(cutlass, args.element_b) element_c = getattr(cutlass, args.element_c) @@ -134,7 +145,7 @@ math_inst = MathInstruction( tile_description = TileDescription( args.threadblock_shape, args.stages, args.warp_count, - math_inst, args.compute_capability, args.compute_capability + math_inst ) layout_a = getattr(cutlass, args.layout_a) @@ -154,13 +165,19 @@ C = TensorDescription( ) element_epilogue = getattr(cutlass, args.element_epilogue) -epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor) +if args.activation_function == "identity": + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) swizzling_functor = getattr(cutlass, args.swizzling_functor) precompute_mode = getattr(SchedulerMode, args.precompute_mode) operation = GemmOperationGrouped( arch=args.compute_capability, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, precompute_mode=precompute_mode ) @@ -214,28 +231,45 @@ for problem_size in problem_sizes: * problem_size.n(),)).astype(getattr(np, args.element_b)) if args.element_c != "int8": - if args.element_c == "bfloat16": - tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.n(),))).astype(bfloat16) + if args.bias: + if args.layout_c == "RowMajor": + c_size = problem_size.n() + elif args.layout_c == "ColumnMajor": + c_size = problem_size.m() + else: + raise ValueError(args.layout_c) else: - tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.n(),))).astype(getattr(np, args.element_c)) + c_size = problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(getattr(np, args.element_c)) else: - tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m() - * problem_size.n(),)).astype(getattr(np, args.element_c)) - tensor_D = np.zeros_like(tensor_C) + tensor_C = np.random.uniform( + low=-2, high=2, size=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + tensor_D = np.zeros( + shape=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) tensor_As.append(tensor_A) tensor_Bs.append(tensor_B) tensor_Cs.append(tensor_C) tensor_Ds.append(tensor_D) - tensor_D_refs.append(reference_module.run( - tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)) + tensor_D_ref = reference_module.run( + tensor_A, tensor_B, tensor_C, problem_size, + args.alpha, args.beta, args.bias) + tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + tensor_D_refs.append(tensor_D_ref) problem_sizes_coord.append(problem_size) arguments = GemmGroupedArguments( operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, - output_op=LinearCombinationFunctorArguments(args.alpha, args.beta) + output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) ) operation.run(arguments) @@ -243,6 +277,9 @@ operation.run(arguments) arguments.sync() for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): - assert np.array_equal(tensor_d, tensor_d_ref) + try: + assert np.array_equal(tensor_d, tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) print("Passed.") diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index a41110cb..a98114a1 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -69,7 +69,7 @@ struct global_load; ///////////////////////////////////////////////////////////////////////////////////////////////// // The redundant mov PTX instruction is used to enforce the compiler to -// initialize data to zero before ld.global +// keep the initializing code before ld.global template struct global_load +struct LinearCombinationGenericParams { + T alpha; ///< scales accumulators + T beta; ///< scales source tensor + T const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + T const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + + CUTLASS_HOST_DEVICE + LinearCombinationGenericParams(): + alpha(T(1)), + beta(T(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + LinearCombinationGenericParams( + T alpha, + T beta = T(0) + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + LinearCombinationGenericParams( + T const *alpha_ptr, + T const *beta_ptr = nullptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// ReLu operator - propagates NaNs @@ -79,6 +111,14 @@ struct ReLu { return mx(value, T(0)); } + + /// Host-constructable parameters structure + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T value, Params const ¶ms_) const { + return this->operator()(value); + } }; template @@ -96,20 +136,74 @@ struct ReLu> { maximum > mx; return mx(frag, T(0)); } + + /// Host-constructable parameters structure + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &frag, Params const ¶ms_) const { + return this->operator()(frag); + } }; // Leaky Relu operator template struct LeakyReLU { + + struct Params: LinearCombinationGenericParams { + T leaky_alpha; ///< leaky_alpha + + // Methods + using LinearCombinationGenericParams::LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Params(): + LinearCombinationGenericParams(), + leaky_alpha(T(1)) {} + + CUTLASS_HOST_DEVICE + Params( + T alpha, + T beta, + T leaky_alpha = T(1) + ): LinearCombinationGenericParams(alpha, beta), leaky_alpha(leaky_alpha) {} + }; + CUTLASS_HOST_DEVICE T operator()(T const &value, T const & alpha_recip) const { T res = value > T(0) ? value : value * alpha_recip; return res; } + + CUTLASS_HOST_DEVICE + T operator()(T const &value, Params const ¶ms_) const { + this->operator()(value, params_.leaky_alpha); + } }; template struct LeakyReLU > { + + struct Params: LinearCombinationGenericParams { + T leaky_alpha; ///< leaky_alpha + using LinearCombinationGenericParams::LinearCombinationGenericParams; + + // Methods + + CUTLASS_HOST_DEVICE + Params(): + LinearCombinationGenericParams(), + leaky_alpha(T(1)) {} + + CUTLASS_HOST_DEVICE + Params( + T alpha, + T beta, + T leaky_alpha = T(1) + ): LinearCombinationGenericParams(alpha, beta), leaky_alpha(leaky_alpha) {} + }; + + CUTLASS_HOST_DEVICE Array operator()(Array const &rhs, T const & alpha_recip) const { Array y; @@ -122,6 +216,11 @@ struct LeakyReLU > { return y; } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs, params_.leaky_alpha); + } }; // Tanh operator @@ -131,6 +230,13 @@ struct Tanh { T operator()(T const &scalar) const { return fast_tanh(scalar); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } }; template @@ -147,6 +253,13 @@ struct Tanh > { return y; } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } }; template @@ -159,6 +272,13 @@ struct Tanh> { return tanh(z); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } }; // Sigmoid operator @@ -168,6 +288,13 @@ struct Sigmoid { T operator()(T const &scalar) const { return T(1) / (T(1) + fast_exp(-scalar)); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } }; template @@ -184,6 +311,13 @@ struct Sigmoid > { return y; } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } }; template @@ -208,6 +342,12 @@ struct Sigmoid> { fast_exp(neg(z)))); #endif } + + using Params = LinearCombinationGenericParams; + + Array operator()(Array const &z, Params const ¶ms_) const { + return this->operator()(z); + } }; // SiLu (swish) operator introduced by Elfwing et al. in the following paper @@ -222,6 +362,13 @@ struct SiLu { Sigmoid sigmoid; return scalar * sigmoid(scalar); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } }; template @@ -232,6 +379,13 @@ struct SiLu> { multiplies> mul; return mul(rhs, sigmoid_op(rhs)); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } }; // Hardswish operator introduced by Howard et al. in the following paper @@ -248,6 +402,13 @@ struct HardSwish { T relu6 = mn(mx(x + T(3), T(0)), T(6)); return x * relu6 / T(6); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &x, Params const ¶ms_) const { + return this->operator()(x); + } }; template <> @@ -261,6 +422,13 @@ struct HardSwish { T relu6 = mn(mx(x + T(3), T(0)), T(6)); return x * relu6 * 0.16666667f; } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &x, Params const ¶ms_) const { + return this->operator()(x); + } }; template @@ -277,6 +445,13 @@ struct HardSwish > { return y; } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &x, Params const ¶ms_) const { + return this->operator()(x); + } }; template @@ -292,6 +467,13 @@ struct HardSwish > { return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f)); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &x, Params const ¶ms_) const { + return this->operator()(x); + } }; // @@ -311,6 +493,13 @@ struct GELU { return T(cutlass::constants::half() * scalar * (cutlass::constants::one() + (T)erff((float)(scalar / cutlass::constants::root_two())))); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + T operator()(T const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } }; template <> @@ -320,6 +509,13 @@ struct GELU { return cutlass::constants::half() * scalar * (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() )); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + float operator()(float const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } }; template <> @@ -329,6 +525,13 @@ struct GELU { return cutlass::constants::half() * scalar * (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + double operator()(double const &scalar, Params const ¶ms_) const { + return this->operator()(scalar); + } }; template @@ -345,6 +548,13 @@ struct GELU > { return y; } + + using Params = LinearCombinationGenericParams; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, Params const ¶ms_) const { + return this->operator()(rhs); + } }; // GELU operator implemented using the Taylor series approximation @@ -360,6 +570,9 @@ struct GELU_taylor { return T(cutlass::constants::half() * z * (cutlass::constants::one() + fast_tanh(k0 * z * (cutlass::constants::one() + k1 * z * z)))); } + + using Params = LinearCombinationGenericParams; + }; template @@ -386,6 +599,8 @@ struct GELU_taylor > { return y; } + + using Params = LinearCombinationGenericParams; }; template @@ -403,6 +618,8 @@ struct GELU_taylor > { return y; } + + using Params = LinearCombinationGenericParams; }; /// Computes backwards pass for GELU operator assuming d_t is the layer gradient and diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index 2b083d71..48ef6cf9 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -40,6 +40,7 @@ #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" #include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -76,22 +77,23 @@ public: using FragmentAccumulator = Array; using ComputeFragment = Array; + using ParamsBase = LinearCombinationParams; + static FloatRoundStyle const kRound = Round; /// Host-constructable parameters structure - struct Params { - + struct Params : ParamsBase{ ElementCompute alpha; ///< scales accumulators ElementCompute beta; ///< scales source tensor ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory - // - // Methods - // - CUTLASS_HOST_DEVICE Params(): + ParamsBase( + ElementCompute(1), + ElementCompute(0) + ), alpha(ElementCompute(1)), beta(ElementCompute(0)), alpha_ptr(nullptr), @@ -101,30 +103,43 @@ public: Params( ElementCompute alpha, ElementCompute beta - ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { - - } + ): + ParamsBase(alpha, beta), + alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha - ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { - - } + ): + ParamsBase(alpha, ElementCompute(0)), + alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { - - } + ): + ParamsBase(*alpha_ptr, *beta_ptr), + alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { + ): + ParamsBase(*alpha_ptr, ElementCompute(0)), + alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { } + CUTLASS_HOST_DEVICE + Params( + ParamsBase const& base + ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { + #if defined(__CUDA_ARCH__) + alpha = reinterpret_cast(base.alpha_data); + beta = reinterpret_cast(base.beta_data); + #else + memcpy( alpha, base.alpha_data, sizeof(ElementCompute) ); + memcpy( beta, base.alpha_data, sizeof(ElementCompute) ); + #endif } }; @@ -142,7 +157,6 @@ public: /// Constructs the function object, possibly loading from pointers in host memory CUTLASS_HOST_DEVICE LinearCombination(Params const ¶ms) { - alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); } diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h index 9f184f85..41dc2b9f 100644 --- a/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -83,40 +83,7 @@ public: static FloatRoundStyle const kRound = Round; /// Host-constructable parameters structure - struct Params { - - ElementCompute alpha; ///< scales accumulators - ElementCompute beta; ///< scales source tensor - ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory - ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): - alpha(ElementCompute(1)), - beta(ElementCompute(0)), - alpha_ptr(nullptr), - beta_ptr(nullptr) { } - - CUTLASS_HOST_DEVICE - Params( - ElementCompute alpha, - ElementCompute beta = ElementCompute(0) - ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { - - } - - CUTLASS_HOST_DEVICE - Params( - ElementCompute const *alpha_ptr, - ElementCompute const *beta_ptr = nullptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { - - } - }; + using Params = typename ActivationFunctor::Params; private: @@ -124,8 +91,7 @@ private: // Data members // - ElementCompute alpha_; - ElementCompute beta_; + Params params_; bool skip_elementwise_; public: @@ -133,9 +99,9 @@ public: /// Constructs the function object, possibly loading from pointers in host memory CUTLASS_HOST_DEVICE LinearCombinationGeneric(Params const ¶ms) { - - alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); - beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + params_ = params; + params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + params_.beta = (params.beta_ptr ? *params.beta_ptr : params.beta); skip_elementwise_ = false; } @@ -148,14 +114,14 @@ public: if (Scale == ScaleType::Nothing) return false; - return beta_ != ElementCompute(0); + return params_.beta != ElementCompute(0); } /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE void set_k_partition(int k_partition, int k_partition_count) { if (k_partition) { - beta_ = ElementCompute(1); + params_.beta = ElementCompute(1); } if (k_partition != k_partition_count - 1) { @@ -186,15 +152,15 @@ public: if (Scale == ScaleType::NoBetaScaling) { intermediate = converted_source; - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X } else if (Scale == ScaleType::Nothing) { intermediate = converted_accumulator; } else { - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + intermediate = mul_add_source(params_.beta, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X } - intermediate = skip_elementwise_ ? intermediate : activation(intermediate); + intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -222,10 +188,10 @@ public: if (Scale == ScaleType::Nothing) { intermediate = converted_accumulator; } else { - intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + intermediate = mul_add_accumulator(params_.alpha, converted_accumulator); // D = alpha * Accum } - intermediate = skip_elementwise_ ? intermediate : activation(intermediate); + intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_); // Convert to destination numeric type NumericArrayConverter destination_converter; diff --git a/include/cutlass/epilogue/thread/linear_combination_params.h b/include/cutlass/epilogue/thread/linear_combination_params.h new file mode 100644 index 00000000..8097af54 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_params.h @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct LinearCombinationParams { + uint64_t alpha_data[2]; + uint64_t beta_data[2]; + + CUTLASS_HOST_DEVICE + LinearCombinationParams() + : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} + { } + + template + CUTLASS_HOST_DEVICE + LinearCombinationParams(ElementCompute alpha, ElementCompute beta) + : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} + { + #if defined(__CUDA_ARCH__) + reinterpret_cast(alpha_data) = alpha; + reinterpret_cast(beta_data) = beta; + #else + memcpy( alpha_data, &alpha, sizeof(ElementCompute) ); + memcpy( beta_data, &beta, sizeof(ElementCompute) ); + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h new file mode 100644 index 00000000..f0cdc3a0 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/fast_math.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Rank +> +struct PredicatedTileIteratorAffineLayoutRankNParams { + using Layout = layout::AffineRankN; + using TensorCoord = typename Layout::TensorCoord; + + static bool const kBigEndian = false; + + // + // Data members + // + + Layout layout; + + /// Stride in units of bytes along M modes + Coord stride_m; + + /// Stride in units of bytes along N modes + Coord stride_n; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + int64_t rank2_inc_col; + int64_t rank2_inc_row; + + // + // Methods + // + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineLayoutRankNParams() { } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineLayoutRankNParams(TensorCoord const &extent, + Layout const &layout_, + int64_t element_sizeof_bits) + : layout(layout_) + { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); + } + + if (kBigEndian) { + // "Big Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i + 1]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); + } + } + else { + // "Little Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); + } + } + + #if 0 + // + // Debug print statements to verify extents and strides are passed correctly. + // + printf("PredicatedTileIteratorAffine::Params() entered\n"); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + printf(" extent[%d]: %d\n", i, extent[i]); + } + for (int i = 0; i < Layout::kRank; ++i) { + printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); + } + printf("PredicatedTileIteratorAffine::Params() returning\n"); + #endif + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineLayoutRankNParams(Layout const &layout_, + int32_t threadmap_delta_kColumn, + int32_t threadmap_delta_kRow, + int64_t element_sizeof_bits) + : layout(layout_) + { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); + } + + rank2_inc_col = threadmap_delta_kColumn * stride_n[0]; + rank2_inc_row = threadmap_delta_kRow * stride_m[0]; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 99f64c67..aabfbdc3 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -561,6 +561,17 @@ CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) { } } +CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index, int64_t element_sizeof_bits) { + if (element_sizeof_bits >= 8) { + return index * (element_sizeof_bits / 8); + } + else { + int64_t const kElementsPerByte = ((8 / element_sizeof_bits) + ((element_sizeof_bits >= 8) ? 1 : 0)); + return index / kElementsPerByte; + } +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Min/Max ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 95196ae2..69f5ba5e 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -88,12 +88,12 @@ public: using ElementA = typename MapArguments::ElementA; using LayoutA = typename MapArguments::LayoutA; static ComplexTransform const kTransformA = MapArguments::kTransformA; - static int const kAlignmentA = GemmKernel::kAlignmentA; + static int const kAlignmentA = MapArguments::kAlignmentA; using ElementB = typename MapArguments::ElementB; using LayoutB = typename MapArguments::LayoutB; static ComplexTransform const kTransformB = MapArguments::kTransformB; - static int const kAlignmentB = GemmKernel::kAlignmentB; + static int const kAlignmentB = MapArguments::kAlignmentB; using ElementC = typename GemmKernel::ElementC; using LayoutC = typename MapArguments::LayoutC; diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index ad628e8a..f708f2f6 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -40,7 +40,7 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" -#include "cutlass/transform/thread/unaryOp.h" +#include "cutlass/transform/thread/unary_op.h" #include "cutlass/array.h" #include "cutlass/half.h" @@ -1273,6 +1273,40 @@ struct NumericArrayConverter { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +/// Conversion is performed with saturation regardless of setting of +/// the `Round` template parameter. +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + // Convert float to int + Array temporary; + + NumericArrayConverter compute_converter; + temporary = compute_converter(source); + + // Convert to int to int8_t + NumericArrayConverter destination_converter; + return destination_converter(temporary); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ ((__CUDACC_VER_MAJOR__ > 10) || \ ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) diff --git a/include/cutlass/transform/thread/unaryOp.h b/include/cutlass/transform/thread/unary_op.h similarity index 100% rename from include/cutlass/transform/thread/unaryOp.h rename to include/cutlass/transform/thread/unary_op.h diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index a3003e47..6eb7fc74 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -189,3 +189,35 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(NumericConversion, f32x8_to_s8x8_rn) { + + int const kN = 8; + using Source = float; + using Destination = int8_t; + + dim3 grid(1, 1); + dim3 block(1, 1); + + cutlass::HostTensor destination({1, kN}); + cutlass::HostTensor source({1, kN}); + + for (int i = 0; i < kN; ++i) { + source.host_data()[i] = float(i); + } + + source.sync_device(); + + test::core::kernel::convert<<< grid, block >>>( + reinterpret_cast *>(destination.device_data()), + reinterpret_cast const *>(source.device_data()) + ); + + destination.sync_host(); + + for (int i = 0; i < kN; ++i) { + EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/README.md b/tools/library/scripts/pycutlass/README.md index 8d4f9279..7b079d36 100644 --- a/tools/library/scripts/pycutlass/README.md +++ b/tools/library/scripts/pycutlass/README.md @@ -16,7 +16,7 @@ math_inst = MathInstruction( tile_description = TileDescription( [128, 128, 8], 4, [2, 4, 1], - math_inst, 80, 80 + math_inst ) A = TensorDescription( @@ -31,10 +31,12 @@ C = TensorDescription( cutlass.float32, cutlass.RowMajor, 1 ) +epilogue_functor = LinearCombination(cutlass.float32, 1, cutlass.float32, cutlass.float32) + operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=cutlass.float32, - epilogue_functor=EpilogueFunctor.LinearCombination, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -54,7 +56,7 @@ beta = 0.0 arguments = GemmArguments( operation=operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=LinearCombinationFunctorArguments(alpha, beta), + output_op=operation.epilogue_type(alpha, beta), gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 ) @@ -68,6 +70,14 @@ assert torch.equal(tensor_D, tensor_D_ref) ``` PyCUTLASS also provides infrastructures for profiling, compiled artifact management, and pool memory manager +## Supported Features +PyCUTLASS currently supports following operations: +* GEMM with mode {Serial, Parallel Split K, Batched GEMM, Array GEMM}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor, Row/ColumnMajorInterleaved<32> for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, swizzling functions {IdentitySwizzle<1,2,4,8>, HorizontalSwizzle, BatchedIdentitySwizzle}, and epilogue {LinearCombination, LinearCombinationClamp} +* GEMM grouped with op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, scheduling mode {Host, Device}, and epilogue {LinearCombination, LinearCombinationClamp}. +* Conv2d with {Fprop, Dgrad, Wgrad}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {Tensor NHWC, TensorNC32HW32 and TensorC32RSK for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, split-k mode {Parallel, Serial}, and epilogue {LinearCombination, LinearCombinationClamp} + +The tiling size of above operations can also be customized. + ## Installation ### Using Docker @@ -94,12 +104,19 @@ cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh ``` ## Examples -Examples can be found in `$CUTLASS_PATH/examples/40_cutlass_py` +Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutlass_py) ## Test The test cases are listed in `$CUTLASS_PATH//tools/library/scripts/pycutlass/test`. The unit test can be run with ```shell cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && run_all_example.sh +``` + +## build documentation +Run +```shell +bash build_doc.sh ``` diff --git a/tools/library/scripts/pycutlass/build_doc.sh b/tools/library/scripts/pycutlass/build_doc.sh index def7c773..aa7ef7c7 100644 --- a/tools/library/scripts/pycutlass/build_doc.sh +++ b/tools/library/scripts/pycutlass/build_doc.sh @@ -1,2 +1,4 @@ -python setup.py develop +pip install enum-tools +pip install sphinx-toolbox +pip install m2r2 sphinx-build -b html docs/source/ docs/build/html diff --git a/tools/library/scripts/pycutlass/docs/source/conf.py b/tools/library/scripts/pycutlass/docs/source/conf.py index 73ec0687..19a71eae 100644 --- a/tools/library/scripts/pycutlass/docs/source/conf.py +++ b/tools/library/scripts/pycutlass/docs/source/conf.py @@ -50,7 +50,7 @@ # -- Project information ----------------------------------------------------- project = 'PyCutlass' -copyright = '2022, Andrew Kerr; Zhaodong Chen; Haicheng Wu; Szymon Migacz; Graham Markall' +copyright = '2022, Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall' author = 'Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall' @@ -65,9 +65,12 @@ extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'enum_tools.autoenum', - 'sphinx.ext.autosummary' + 'sphinx.ext.autosummary', + 'm2r2' ] +source_suffix = [".rst", ".md"] + autosummary_generate = True autosummary_imported_members = True @@ -85,7 +88,7 @@ exclude_patterns = [] # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'classic' +html_theme = 'bizstyle' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/tools/library/scripts/pycutlass/docs/source/cutlass.rst b/tools/library/scripts/pycutlass/docs/source/cutlass.rst index 6ec68253..43c13e5e 100644 --- a/tools/library/scripts/pycutlass/docs/source/cutlass.rst +++ b/tools/library/scripts/pycutlass/docs/source/cutlass.rst @@ -1,2 +1,100 @@ cutlass ======= + +.. rubric:: Operator Classification + +.. autoclass:: cutlass.OpClass + :members: + +.. rubric:: GEMM Layout + +.. autoclass:: cutlass.RowMajor + :members: + +.. autoclass:: cutlass.ColumnMajor + :members: + +.. autoclass:: cutlass.RowMajorInterleaved32 + :members: + +.. autoclass:: cutlass.ColumnMajorInterleaved32 + :members: + +.. rubric:: Conv Layout + +.. autoclass:: cutlass.TensorNHWC + :members: + +.. autoclass:: cutlass.TensorNC32HW32 + :members: + +.. autoclass:: cutlass.TensorC32RSK32 + :members: + +.. rubric:: Threadblock Swizzle + +.. autoclass:: cutlass.dim3 + :special-members: + :members: + +.. autoclass:: cutlass.IdentitySwizzle1 + :special-members: + :members: + +.. autoclass:: cutlass.IdentitySwizzle2 + :special-members: + :members: + +.. autoclass:: cutlass.IdentitySwizzle4 + :special-members: + :members: + +.. autoclass:: cutlass.IdentitySwizzle8 + :special-members: + :members: + +.. autoclass:: cutlass.HorizontalSwizzle + :special-members: + :members: + +.. autoclass:: cutlass.BatchedIdentitySwizzle + :special-members: + :members: + +.. autoclass:: cutlass.StridedDgradIdentitySwizzle1 + :special-members: + :members: + +.. autoclass:: cutlass.StridedDgradIdentitySwizzle4 + :special-members: + :members: + +.. autoclass:: cutlass.StridedDgradHorizontalSwizzle + :special-members: + :members: + +.. rubric:: Coordinates + +.. autoclass:: cutlass.Tensor4DCoord + :special-members: + :members: + +.. autoclass:: cutlass.Tensor3DCoord + :special-members: + :members: + +.. autoclass:: cutlass.MatrixCoord + :special-members: + :members: + + +.. rubric:: Convolution + +.. autoclass:: cutlass.conv.Operator + :members: + +.. autoclass:: cutlass.conv.IteratorAlgorithm + :members: + +.. autoclass:: cutlass.conv.StrideSupport + :members: diff --git a/tools/library/scripts/pycutlass/docs/source/descriptor.rst b/tools/library/scripts/pycutlass/docs/source/descriptor.rst deleted file mode 100644 index cd0a3b98..00000000 --- a/tools/library/scripts/pycutlass/docs/source/descriptor.rst +++ /dev/null @@ -1,6 +0,0 @@ -Descriptions -============== - -.. autoclass:: pycutlass.TileDescription - :special-members: - :members: diff --git a/tools/library/scripts/pycutlass/docs/source/frontend.rst b/tools/library/scripts/pycutlass/docs/source/frontend.rst deleted file mode 100644 index 1da97eeb..00000000 --- a/tools/library/scripts/pycutlass/docs/source/frontend.rst +++ /dev/null @@ -1,5 +0,0 @@ -Frontend -============== - -.. autoclass:: pycutlass.NumpyFrontend - :members: diff --git a/tools/library/scripts/pycutlass/docs/source/index.rst b/tools/library/scripts/pycutlass/docs/source/index.rst index 5e2fa7ad..b8a16e16 100644 --- a/tools/library/scripts/pycutlass/docs/source/index.rst +++ b/tools/library/scripts/pycutlass/docs/source/index.rst @@ -3,27 +3,29 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to PyCutlass's documentation! +CUTLASS Python Project Documentation ===================================== - +.. mdinclude:: ../../README.md + .. toctree:: :maxdepth: 2 :caption: Contents: -Indices and tables +.. Indices and tables +.. ================== + +.. * :ref:`genindex` +.. * :ref:`modindex` +.. * :ref:`search` + + +Indices ================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - - .. toctree:: - types - cutlass - descriptor - frontend + user_guide + visitor_tree gemm_op conv2d_op + cutlass diff --git a/tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md b/tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md new file mode 100644 index 00000000..7cda6873 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md @@ -0,0 +1,225 @@ +# Epilogue Visitor Tree +The Epilogue Visitor Tree is an experimental feature that directly generates epilogues from user-provide python functions. + +## Usage + +The Epilogue Visitor tree support many different operations. + +### Unary functions +Epilogue Visitor Tree supports unary functions like activation functions. For example, +```python +class UnaryEpilogue_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + T = leaky_relu.numpy(accum, 0.2) + Z = alpha * T + beta * c + return Z +epilogue_functor = UnaryEpilogue_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) +``` + +### Broadcast Operation +Epilogue Visitor Tree supports broadcasting row and column vectors to the whole output matrix. To use broadcast, you just need to specify whether the source vector is a `row` vector or a `column` vector. Here is an example. +```python +class ColumnBroadcast_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + vector: 'column', alpha: 'scalar', beta: 'scalar'): + # + T = accum + vector + scale_T = leaky_relu.numpy(alpha * T, 0.2) + Z = scale_T + beta * c + return Z, T +epilogue_functor = ColumnBroadcast_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) +``` + +### Reduction Operation + +Epilogue Visitor Tree also supports row and column-wise reduction in each threadblock tile. The syntax for reduction is +```python +{reduction_output} = reduction_op({input_tensor}, {row|column}, {Add}, {threadblock_shape.n|threadblock_shape.m}) +``` +The `{row|column}` indicates whether the `row` vectors are reduced or the `column` vectors are reduction. The `{Add}` specifies the reduction operation. The `{threadblock_shape.n|threadblock_shape.m}` are the reduction lengths. + +**Constraint** +* The `{input_tensor}` can only be the name of source or intermediate result. `reduction_op(A + B, ...)` will not work, please use `C = A + B`, `reduction_op(C, ...)` instead. +* The `{reduction_output}` cannot be used in the epilogue. It will be directly written to global memory after the reduction is done. +```python +class RowReduction_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + D = alpha * accum + tanh.numpy(beta * c) + reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) + return D, reduction +epilogue_functor = RowReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) +epilogue_functor.initialize() +``` + +## Get output_op + +As shown in the user guide, an `output_op` is required by the argument wrapper. We will take the `RowReduction_` as an example to show how to get `output_op`. +```python +class RowReduction_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + D = alpha * accum + tanh.numpy(beta * c) + reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) + return D, reduction +epilogue_functor = RowReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) +epilogue_functor.initialize() + +cta_n = args.threadblock_shape[1] +num_cta_n = (problem_size.n() + cta_n - 1) // cta_n +reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, element_c)) +# get output op +output_op = operation.epilogue_type( + D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] +) +``` +Like other epilogue functors such as `LinearCombination`, the output op for EpilogueVisitorTree is also created with `operation.epilogue_type(*)`. However, there are two differences: +* The arguments need to be passed as keyword-arguments. The keywords are the argument names in `def __call__`. +* An additional `problem_size=[problem_size.m(), problem_size.n()]` is required. + + +## Add new Unary Operation (e.g. Activation Function) +To add additional unary operation into epilogue visitor tree, a new unary op +should be created for `VisitorOpUnary`. We will take `tanh` as an example. + +### Step 1: define TanhVisitor + +The visitor defines the parameters and computation required by the unary option. +The unary operations are registered in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h). But you can define it in any header file and include the header file in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h). + + +* Two template arguments are required: + * `T`: data type used to compute the unary operation + * `N`: compute fragment length +* We also need to provide the `Arguments` and `Params` structures. The `Arguments` will be assembled by [ctypes](https://docs.python.org/3/library/ctypes.html), the `Params` will be generated from `Arguments` automatically. If the unary function takes no argument, an integer like `int tmp` can be provide to ensure the correctness of ctypes. +* The constructor can only take the `params` as the single argument. +* The operation is defined in `Array operator()(Array const &frag) const `. On common way to do that is first define a scalar computation, and them use it for the fragment computation with an unrolled for-loop. +* A guard function is required. If it returns `true`, it will disable all the children nodes of the unary node and return zeros to parent node. This is very helpful for multiplication with scalar while the scalar is `0`. For general cases, you can just return `true`. +```c++ +// T: data type used to compute the unary operation +// N: compute fragment length +template +struct TanhVisitor { + /// Argument + struct Arguments { + // a placeholder argument to ensure correctness of ctypes + int tmp; + + CUTLASS_HOST_DEVICE + Arguments(): tmp(0) { }; + + CUTLASS_HOST_DEVICE + Arguments(int tmp): tmp(tmp) { }; + }; + + /// Param + struct Params { + CUTLASS_HOST_DEVICE + Params(){ }; + Params(Arguments const &args) { } + }; + + /// Constructor + CUTLASS_HOST_DEVICE + TanhVisitor(Params const ¶ms) { } + + // scalar operator + CUTLASS_HOST_DEVICE + T tanh_op(T const &scalar) const { + return fast_tanh(scalar); + } + + /// vector operator + CUTLASS_HOST_DEVICE + Array operator()(Array const &frag) const { + Array y; + + CUTLASS_PRAGMA_UNROLL + for (int i=0; i < N; ++i) { + y[i] = tanh_op(frag[i]); + } + + return y; + } + + // Guard + CUTLASS_HOST_DEVICE + bool guard() { + return true; + } +}; +``` + +### Step 2: register Tanh function +After defining the function in C++, we need to register it in python. The class below gives an example. +* The init function takes the data type `element_compute`, which will be the `T` in the C++ template. +In the init function, we also generate the `_Arguments` class as a `ctypes.Structure`. It includes all the data members in the `TanhVisitor::Arguments`. +* The `_Arguments` need to be registered as `self.argument_type` of `tanh` class. +* A `emit` function is required to emit the namespace and typename of `TanhVisitor`. +* A staticmethod as numpy reference is required to implement the python code to parse. + +The built-in functions are defined in [pycutlass/src/pycutlass/epilogue.py](tools/library/scripts/pycutlass/src/pycutlass/epilogue.py). You can defined yours in any file as long as it can be found by [/pycutlass/src/pycutlass/parser.py](tools/library/scripts/pycutlass/src/pycutlass/parser.py). +```python +class tanh(ActivationFunctor): + def __init__(self, element_compute) -> None: + super().__init__() + class _Arguments(ctypes.Structure): + _fields_ = [ + ("tmp", ctypes.c_int) + ] + def __init__(self, *args) -> None: + self.tmp = 0 + self.argument_type = _Arguments + + def emit(self): + return "cutlass::TanhVisitor" + + @staticmethod + def numpy(x: np.ndarray): + return np.tanh(x) +``` + +### Step 3: Run the function +Now the new unary op is ready to use. An epilogue visitor tree can be built with +```python +class RowReduction_(EpilogueVisitTree): + def __call__( + self, accum: NDArray['tensor', 'float32'], c: NDArray['tensor', 'float32'], + alpha: 'float32', beta: 'float32'): + # + D = alpha * accum + tanh.numpy(beta * c) + reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) + return D, reduction +epilogue_functor = RowReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) +epilogue_functor.initialize() +``` + +## Limitations and Future work + +Although the Epilogue Visitor Tree brings great flexibility to epilogue construction, as the epilogue is formulated as a single tree, there are several limitations. +* [Future Work] Serial and Parallel Split-K GEMM are not supported yet. + * To support serial split-k, additional tree transformation pass is required to inject a `binaryOpNode(Add)` + `TensorInputNode` before each `TensorOutputNode` to fetch the partial sum back. The `semaphore` also needs to be passed into epilogue. + * To support parallel split-k, an Reduction with visitor kernel is required. +* [Future Work] Convolution and GEMM Grouped are not supported yet. + * To support Conv2d and GEMM Grouped, corresponding *_with_visitor kernels are required. + +* [Limitation] If the same node is used by two operations (except that one of them is reduction), the node and all its offsprings will be executed twice. +* [Limitation] The result of reduction can only be used as the return value. diff --git a/tools/library/scripts/pycutlass/docs/source/md/basic_idea.md b/tools/library/scripts/pycutlass/docs/source/md/basic_idea.md new file mode 100644 index 00000000..655caa39 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/md/basic_idea.md @@ -0,0 +1,283 @@ +# Basics of PyCUTLASS + +PyCUTLASS handles the following things when launch the CUTLASS kernels +* Memory management +* Operation Description +* Code emission and compilation +* Arguments preprocessing +* Kernel launching +* Result Synchronization + +## Memory management + +PyCUTLASS uses [RMM](https://github.com/rapidsai/rmm) to manage device memory. At the begining of the program, call +```python +pycutlass.get_memory_pool({init_pool_size_in_bytes}, {max_pool_size_in_bytes}) +``` +We also provide functions to query the allocated size. +```python +bytes = get_allocated_size() +``` + + +## Operation Description +PyCUTLASS provides operation description for GEMM, GEMM Grouped and Conv2d operations. These operation descriptions are assembled from four foundamental concepts +* Math Instruction: math instruction executed in GPU cores +* Tile Description: tiling sizes and pipeline stages +* Operand Description: data type, layout, memory alignment +* Epilogue Functor: epilogue function + +### Math Instruction + +The math instruction is defined as follows: +```python +math_inst = MathInstruction( + {instruction_shape}, {element_a}, {element_b}, + {element_acc}, {opclass}, {math_operation} +) +``` +The `{instruction_shape}` and `{opclass}` defines the instruction size and type. The table below lists valid combinations. `{element_a}`, `{element_b}` define the source operand data type for each instructions, and `{element_acc}` defines the accumulator type. The `{math_operation}` defines the math operation applied. + +|Opclass | element_a/element_b | element_acc | instruction_shape | math_operation | +| -- | -- | -- | -- | -- | +| cutlass.OpClass.TensorOp | cutlass.float64 | cutlass.float64 | [8, 8, 4] | MathOperation.multiply_add| +| | cutass.float32 cutlass.tfloat32, cutlass.float16 cutlass.bfloat16 | cutlass.float32 | [16, 8, 8] | MathOperation.multiply_add MathOperation.multiply_add_fast_f32 MathOperation.multiply_add_fast_f16 MathOperation.multiply_add_fast_bf16 | +| | cutlass.float16 | cutlass.float16/cutlass.float32|[16, 8, 16]| MathOperation.multiply_add | +| | cutlass.bfloat_16 | cutlass.float32 | [16, 8, 16]|MathOperation.multiply_add | +| | cutlass.int8 | cutlass.int32 | [16, 8, 32] | MathOperation.multiply_add_saturate| +|cutlass.OpClass.Simt| cutlass.float64 | cutlass.float64 | [1, 1, 1] | MathOperation.multiply_add | +| | cutlass.float32 | cutlass.float32 | [1, 1, 1] | MathOperation.multiply_add | + +The `cutlass.OpClass.TensorOp` indicates that the tensor core is used, while `cutlass.OpClass.Simt` uses the SIMT Core. + +The `multiply_add_fast_f32` emulates fast accurate SGEMM kernel which is accelerated +using Ampere Tensor Cores. More details can be found in [examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm](examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm). + +### Tile Description +The tile description describes the threadblock and warp tiling sizes, as well as the pipeline stages. +```python +tile_description = TileDescription( + {threadblock_shape}, {stages}, {warp_count}, + math_inst +) +``` +The `{threadblock_shape}` is a list of 3 integers `[Tile_M, Tile_N, Tile_K]` that defines the threadblock tiling size. `{stages}` defines the number of software pipeline stages ([detail](https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/)). `{warp_count}` defines the number of warps along `M`, `N`, and `K` dimension. I.e., with `{threadblock_shape}=[Tile_M, Tile_N, Tile_K]` and `{warp_count}=[W_M, W_N, W_K]`, the warp tile size would be `[Tile_M / W_M, Tile_N / W_N, Tile_K / W_K]`. + +### Operand Description +The Operand Description defines the data type, layout, and memory alignment of input tensor A, B, and C. The output D shares the same attributes with C. The description is as follows: +```python +A = TensorDescription( + {element_a}, {layout_a}, {alignment_a} +) + +B = TensorDescription( + {element_b}, {layout_b}, {alignment_b} +) + +C = TensorDescription( + {element_c}, {layout_c}, {alignment_c} +) +``` +The table below lists the supported layout and data types for each operation +| Operation | data type | layout | +| -- | -- | -- | +| GEMM, GEMM Grouped | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.RowMajor, cutlass.ColumnMajor | +| | cutlass.int8 | cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32| +| Conv2d Fprop, Dgrad, Wgrad | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.TensorNHWC | +| Conv2d Fprop | cutlass.int8 | cutlass.TensorNHWC, cutlass.TensorNC32HW32, cutlass.TensorC32RSK32| + +### Epilogue Functor +The epilogue functor defines the epilogue executed after mainloop. +We expose the following epilogue functors. +| Epilogue Functor | Remark | +| -- | -- | +| LinearCombination | $D=\alpha \times Accm + \beta \times C$ | +| LinearCombinationClamp | $D=\alpha \times Accm + \beta \times C$, Output is clamped to the maximum value of the data type output | +| FastLinearCombinationClamp | $D=\alpha \times Accm + \beta \times C$, only used for problem size $K\le 256$ for cutlass.int8, with accumulator data type `cutlass.int32` and epilogue compute data type `cutlass.float32` | +| LinearCombinationGeneric | $D = activation(\alpha \times Accm + \beta \times C)$, available activations include `relu`, `leaky_relu`, `tanh`, `sigmoid`, `silu`, `hardswish`, and `gelu` | + +The epilogue functors can be created as follows +```python +# LinearCombination +epilogue_functor = LinearCombination( + element_C, alignment_c, element_acc, element_epilogue_compute +) + +# LinearCombinationClamp +epilogue_functor = LinearCombinationClamp( + element_C, alignment_c, element_acc, element_epilogue_compute +) + +# FastLinearCombinationClamp +epilogue_functor = FastLinearCombinationClamp( + element_C, alignment_c +) + +# LinearCombinationGeneric +epilogue_functor = LinearCombinationGeneric( + relu(element_epilogue_compute), element_C, alignment_c, + element_acc, element_epilogue_compute +) +``` + +We also provides an experimental feature "Epilogue Visitor Tree" for GEMM operation. The details can be found in [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md). + + +### GEMM Operation + +The GEMM Operation description can be created with +```python +operation = GemmOperationUniversal( + {compute_capability}, tile_description, + A, B, C, epilogue_functor, + {swizzling_functor}, {visitor} +) +``` +* `{compute_capability}` is an integer indicates the compute capability of the GPU. For A100, it is 80. +* `{swizzling_functor}` describes how threadblocks are scheduled on GPU. This is used to improve the L2 Locality ([detail](https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/)). Currently we support `cutlass.{IdentitySwizzle1|IdentitySwizzle2|IdentitySwizzle4|IdentitySwizzle8|BatchedIdentitySwizzle}`. The last one is used for batched or array GEMM. +* `{visitor}`: a bool variable indicates whether the epilogue visitor tree is used. + +### GEMM Grouped Operation +The GEMM Grouped Operation description can be created with +```python +operation = GemmOperationGrouped( + compute_capability, tile_description, + A, B, C, epilogue_functor, + swizzling_functor, {precompute_mode} +) +``` +* `{precompute_mode}`: It could be `SchedulerMode.Host` or `SchedulerMode.Device`. See [examples/24_gemm_grouped](examples/24_gemm_grouped) for more details. + + +### Conv2d Operation +The Conv2d Operation description can be created with +```python +operation = Conv2dOperation( + {conv_kind}, {iterator_algorithm}, + compute_capability, tile_description, + A, B, C, {stride_support}, + epilogue_functor, swizzling_functor +) +``` +* `{conv_kind}` defines which convolution is executed. Available options include `fprop`, `dgrad`, and `wgrad`. +* `{iterator_algorithm}` specifies the iterator algorithm used by the implicit GEMM in convolution. The options are as follows: + * `analytic`: functionally correct in all cases but lower performance + * `optimized`: optimized for R <= 32, S <= 32 and unity-stride dgrad + * `fixed_channels`: analytic algorithm optimized for fixed channel count (C == AccessSize) + * `few_channels`: Analytic algorithm optimized for few channels (C divisible by AccessSize) +* `{stride_support}`: distinguishes among partial specializations that accelerate certain problems where convolution +stride is unit. + * `strided`: arbitrary convolution stride + * `unity`: unit convolution stride + +*** +## Code Emission and Compilation +After implementing the operation description, the related host and device code can be compiled with +```python +import pycutlass + +pycutlass.compiler.add_module([operation,]) +``` +Several operations can be compiled togather. The `nvcc` at `$CUDA_INSTALL_PATH/bin` is used by default as the compiler backend. But you can also switch to [CUDA Python](https://nvidia.github.io/cuda-python/overview.html)'s `nvrtc` with +```python +pycutlass.compiler.nvrtc() +``` +We also have an internal compiled artifact manager that caches the compiled kernel in both memory and disk. The `compiled_cache.db` at your workspace is the database that contains the binary files. You can delete the file if you want to recompile the kernels. +*** +## Argument Processing +We provide argument wrapper to convert python tensors to the kernel parameters. Currently it supports [torch.Tensor](https://pytorch.org/), [numpy.ndarray](https://numpy.org/), and [cupy.ndarray](https://cupy.dev/). +### GEMM Arguments +The Gemm arguments can be created with +```python +arguments = GemmArguments( + operation=operation, problem_size={problem_size}, + A={tensor_A}, B={tensor_B}, C={tensor_C}, D={tensor_D}, + output_op={output_op}, + gemm_mode={gemm_mode}, + split_k_slices={split_k_slices}, batch={batch} +) +``` +* `problem_size` is a `cutlass.gemm.GemmCoord(M, N, K)` object that defines $M\times N\times K$ matrix multiplication. +* `tensor_X`: user-provide tensors. +* `output_op`: the params for the epilogue functor. +* `gemm_mode`, `split_k_slices`, and `batch`: + +|gemm_mode| split_k_slices | batch | remark| +|--|--|--|--| +|cutlass.gemm.Mode.Gemm | number of split-K slices | - | the ordinary GEMM or GEMM with serial split-K| +|cutlass.gemm.Mode.GemmSplitKParallel | number of split-K slices | - | GEMM Split-K Parallel| +|cutlass.gemm.Mode.Batched | - | batch size | Batched GEMM | +|cutlass.gemm.Mode.Array | - | batch size | Array GEMM | + +### GEMM Grouped Arguments +The GEMM grouped arguments can be created with +```python +arguments = GemmGroupedArguments( + operation, {problem_sizes_coord}, {tensor_As}, {tensor_Bs}, {tensor_Cs}, {tensor_Ds}, + output_op=output_op) +) +``` +* `problem_size_coord` is a list of `cutlass.gemm.GemmCoord(M, N, K)` for each problem size. +* `tensor_Xs` is a list of user-provide tensors. +* `output_op`: the params of the epilogue functor + +### Conv2d Arguments +The Conv2d arguments can be created with +```python +arguments = Conv2dArguments( + operation, {problem_size}, {tensor_A}, + {tensor_B}, {tensor_C}, {tensor_D}, + {output_op}, + {split_k_mode}, + {split_k_slices} +) +``` +* `problem_size`: it can be constructed with + ```python + problem_size = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(N, H, W, C), + cutlass.Tensor4DCoord(K, R, S, C), + cutlass.Tensor4DCoord(pad[0], pad[1], pad[2], pad[3]), + cutlass.MatrixCoord(stride[0], stride[1]), + cutlass.MatrixCoord(dilation[0], dilation[1]), + cutlass.conv.Mode.cross_correlation, + split_k_slices, 1 + ) + ``` +* `tensor_X` are user-provide tensors +* `output_op`: the params of the epilogue functor +* `split_k_mode`: currently we support `cutlass.conv.SplitKMode.Serial` and `cutlass.conv.SplitKMode.Parallel`. +* `split_k_slice`: number of split-k slices + +For ordianry conv2d, just use `cutlass.conv.SplitKMode.Serial` with `split_k_slice=1`. + +### Getting output_op +The way to create output_op is listed below +```python +output_op = operation.epilogue_type(*([alpha, beta] + args.activation_args)), +``` +It is a list of arguments start with the scaling factor `alpha` and `beta`. +The `output_op` of EpilogueVisitorTree is slightly different. Please check [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md) for details. + + +## Kernel Launching + +With the arguments and operations, the kernel can be launched simply with +```python +operation.run(arguments) +``` + +## Sync results + +We also provide function to synchronize the kernel execution. If you use `numpy`, it will also copy the result back to host. To do that, run +```python +arguments.sync() +``` +If you use EpilogueVisitorTree, please call +```python +output_op.sync() +``` + +## Reduction Kernel behind Parallel Split-K + +If you use parallel-split-K in GEMM or Conv2d, an additional reduction kernel is required. Please check [examples/40_cutlass_py](examples/40_cutlass_py) for detail. diff --git a/tools/library/scripts/pycutlass/docs/source/types.rst b/tools/library/scripts/pycutlass/docs/source/types.rst deleted file mode 100644 index 15893511..00000000 --- a/tools/library/scripts/pycutlass/docs/source/types.rst +++ /dev/null @@ -1,6 +0,0 @@ -Types -======== - - -.. autoenum:: pycutlass.OperationKind - :members: diff --git a/tools/library/scripts/pycutlass/docs/source/user_guide.rst b/tools/library/scripts/pycutlass/docs/source/user_guide.rst new file mode 100644 index 00000000..3db70dbb --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/user_guide.rst @@ -0,0 +1,4 @@ +User Guide +===================================== + +.. mdinclude:: ./md/basic_idea.md diff --git a/tools/library/scripts/pycutlass/docs/source/visitor_tree.rst b/tools/library/scripts/pycutlass/docs/source/visitor_tree.rst new file mode 100644 index 00000000..c48cdba3 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/visitor_tree.rst @@ -0,0 +1,4 @@ +User Guide +===================================== + +.. mdinclude:: ./md/EpilogueVisitorTree.md diff --git a/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py b/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py index 16e04cca..9093db83 100644 --- a/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +++ b/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py @@ -32,6 +32,7 @@ from pycutlass import * import pycutlass +from pycutlass.epilogue import LinearCombination from pycutlass.test.conv2d_testbed import Conv2dLauncher @@ -62,15 +63,16 @@ if __name__ == "__main__": tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py index 31f52546..caff35c4 100644 --- a/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +++ b/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py @@ -49,7 +49,7 @@ if __name__ == '__main__': tile_description = TileDescription( threadblock_shape=[256, 128, 32], stages=3, warp_count=[4, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -67,7 +67,7 @@ if __name__ == '__main__': element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32) swizzling_functor = cutlass.IdentitySwizzle1 diff --git a/tools/library/scripts/pycutlass/setup.py b/tools/library/scripts/pycutlass/setup.py index c3933455..219face0 100644 --- a/tools/library/scripts/pycutlass/setup.py +++ b/tools/library/scripts/pycutlass/setup.py @@ -43,7 +43,7 @@ try: Pybind11Extension("cutlass", ["src/cpp/cutlass.cpp"], include_dirs=include_dirs, - extra_compile_args=["-fpermissive"]) + extra_compile_args=["-fpermissive", "-w"]) ] except ImportError: pass @@ -69,7 +69,8 @@ setup( 'typeguard', 'bfloat16', 'typing', - 'scikit-build' + 'scikit-build', + 'treelib' ], cmdclass={ 'rmm': BuildRMM diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h new file mode 100644 index 00000000..cac334b3 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h @@ -0,0 +1,225 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor with CTA row-wise broadcast + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" + +#include "epilogue_visitor_op/visitor_op_linear_combination.h" +#include "epilogue_visitor_op/visitor_op_tensor_input.h" +#include "epilogue_visitor_op/visitor_op_accumulator.h" +#include "epilogue_visitor_op/visitor_op_row_broadcast.h" +#include "epilogue_visitor_op/visitor_op_tensor_output.h" +#include "epilogue_visitor_op/visitor_op_column_reduction.h" +#include "epilogue_visitor_op/visitor_op_row_reduction.h" +#include "epilogue_visitor_op/visitor_op_column_broadcast.h" +#include "epilogue_visitor_op/visitor_op_unary.h" +#include "epilogue_visitor_op/visitor_op_binary.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic Epilogue Visitor. +template < + typename OutputOp_ +> +class EpilogueVisitorGeneric { +public: + + using OutputOp = OutputOp_; + using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType; + static int const kElementsPerAccess = OutputOp::kElementsPerAccess; + using ElementOutput = typename OutputOp::ElementOutput; + using OutputTileIterator = typename OutputOp::OutputTileIterator; + + static int const kIterations = OutputTileIterator::kIterations; + + /// + /// End Epilogue Tree + /// + + /// Additional SMEM bufer is not required in the broadcast epilogue visitor + struct SharedStorage { + + typename OutputOp::SharedStorage output_smem; + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + +public: + + /// Argument structure + struct Arguments { + typename OutputOp::Arguments output_op_args; + // + // Methods + // + Arguments() { } + + Arguments( + typename OutputOp::Arguments output_op_args + ): + output_op_args(output_op_args) + { + + } + }; + + struct Params { + typename OutputOp::Params output_op_params; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + output_op_params(args.output_op_args) + { + + } + }; + + + +private: + + OutputOp output_op; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueVisitorGeneric( + Params const ¶ms, ///< Parameters routed to the epilogue + SharedStorage &shared_storage, ///< Shared storage needed by the functors here + MatrixCoord threadblock_offset, + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + MatrixCoord problem_size + ): + output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size) + { } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + output_op.set_batch_index(batch_idx); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + output_op.begin_epilogue(); + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + output_op.begin_step(step_idx); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + output_op.begin_row(row_idx); + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum) { + output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + output_op.end_row(row_idx); + + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + output_op.end_step(step_idx); + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + output_op.end_epilogue(); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h new file mode 100644 index 00000000..e13624a4 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the binary ops +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Scalar multiplication +template +struct VectorAdd { + + struct Arguments { + int tmp; + + CUTLASS_HOST_DEVICE + Arguments():tmp(0){ } + + CUTLASS_HOST_DEVICE + Arguments(int tmp): tmp(tmp) { } + }; + + struct Params { + + CUTLASS_HOST_DEVICE + Params(Arguments const &args) { } + }; + + CUTLASS_HOST_DEVICE + VectorAdd( + Params const ¶ms + ) { } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + cutlass::plus> add_op; + return add_op(lhs, rhs); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h new file mode 100644 index 00000000..09679db6 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h @@ -0,0 +1,233 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the unary ops +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Scalar multiplication +template +struct Mult { + + struct Arguments { + T alpha; + + CUTLASS_HOST_DEVICE + Arguments():alpha(T(1.0)){ } + + CUTLASS_HOST_DEVICE + Arguments(T alpha): alpha(alpha) { } + }; + + struct Params { + T alpha; ///< scales accumulators + + CUTLASS_HOST_DEVICE + Params():alpha(T(1.0)){ } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): alpha(args.alpha) { } + }; + + T alpha_; + + CUTLASS_HOST_DEVICE + Mult( + Params const ¶ms + ): + alpha_(params.alpha) + { } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) const { + cutlass::multiplies> multiply_op; + return multiply_op(source, alpha_); + } + + CUTLASS_HOST_DEVICE + bool guard() { + return alpha_ != T(0); + } + +}; + + +/// ReLU +template +struct ReLUVisitor { + struct Arguments { + T threshold; + + CUTLASS_HOST_DEVICE + Arguments():threshold(T(0.0)) { } + + CUTLASS_HOST_DEVICE + Arguments(T threshold): threshold(threshold) { } + }; + + struct Params { + T threshold; + + CUTLASS_HOST_DEVICE + Params():threshold(T(0.0)) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): threshold(args.threshold) { } + }; + + T threshold_; + + CUTLASS_HOST_DEVICE + ReLUVisitor(Params const ¶ms): + threshold_(params.threshold) { } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &frag) const { + maximum> mx; + return mx(frag, threshold_); + } + + CUTLASS_HOST_DEVICE + bool guard() { + return true; + } +}; + +/// leakyReLU +template +struct LeakyReLUVisitor { + struct Arguments { + T leaky_alpha; + + CUTLASS_HOST_DEVICE + Arguments():leaky_alpha(T(0.0)) { } + + CUTLASS_HOST_DEVICE + Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { } + }; + + struct Params { + T leaky_alpha; + + CUTLASS_HOST_DEVICE + Params():leaky_alpha(T(0.0)) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { } + }; + + T leaky_alpha_; + + CUTLASS_HOST_DEVICE + LeakyReLUVisitor(Params const ¶ms): + leaky_alpha_(params.leaky_alpha) { } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &frag) const { + cutlass::epilogue::thread::LeakyReLU> leaky_op; + return leaky_op(frag, leaky_alpha_); + } + + CUTLASS_HOST_DEVICE + bool guard() { + return true; + } + +}; + +/// Tanh +template +struct TanhVisitor { + /// Argument + struct Arguments { + // a placeholder argument to ensure correctness of ctypes + int tmp; + + CUTLASS_HOST_DEVICE + Arguments(): tmp(0) { }; + + CUTLASS_HOST_DEVICE + Arguments(int tmp): tmp(tmp) { }; + }; + + /// Param + struct Params { + CUTLASS_HOST_DEVICE + Params(){ }; + Params(Arguments const &args) { } + }; + + /// Constructor + CUTLASS_HOST_DEVICE + TanhVisitor(Params const ¶ms) { } + + // scalar operator + CUTLASS_HOST_DEVICE + T tanh_op(T const &scalar) const { + return fast_tanh(scalar); + } + + /// vector operator + CUTLASS_HOST_DEVICE + Array operator()(Array const &frag) const { + Array y; + + CUTLASS_PRAGMA_UNROLL + for (int i=0; i < N; ++i) { + y[i] = tanh_op(frag[i]); + } + + return y; + } + + CUTLASS_HOST_DEVICE + bool guard() { + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h new file mode 100644 index 00000000..75f83c4e --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h @@ -0,0 +1,148 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with accumulator +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue Visitor operator for the following Computation +/// +/// ElementAccumulator accum; +/// return accum; +/// +/// It can only be the leaf node of the epilogue tree + +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + int kElementsPerAccess_ ///< Number of elements computed per operation +> +class VisitorOpAccumulator{ +public: + using ElementAccumulator = ElementAccumulator_; + static int const kElementsPerAccess = kElementsPerAccess_; + + /// Fragment type for Accumulator + using AccumulatorAccessType = Array; + + /// Fragment type returned by this visitor + using VisitAccessType = AccumulatorAccessType; + + /// SMEM buffer class required in the epilogue visitor + struct SharedStorage { + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + /// Host-constructable Arguments structure + struct Arguments { + // Note: it is strange that ctypes will return issue with empty arguments + int tmp; + + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments(int tmp): tmp(tmp) { } + }; + + /// Parameter structure + struct Params { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args) { } + }; + +public: + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpAccumulator( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ) { } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { } + + CUTLASS_DEVICE + void begin_epilogue() { } + + CUTLASS_DEVICE + void begin_step(int step_idx) { } + + CUTLASS_DEVICE + void begin_row(int row_idx) { } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + return accum; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { } + + CUTLASS_DEVICE + void end_step(int step_idx) { } + + CUTLASS_DEVICE + void end_epilogue() { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h new file mode 100644 index 00000000..124a9fd8 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with Binary op +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "binary_ops.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue Visitor operator for the following computation: +/// +/// ElementCompute alpha; +/// ElementCompute beta; +/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B) +/// Return C; +/// +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementCompute_, ///< Data type used to compute linear combination + int kElementsPerAccess_, ///< Number of elements computed per operation + typename VisitorA_, ///< Child node A + typename VisitorB_, ///< Child node B + template typename BinaryOp_ +> +class VisitorOpBinary{ +public: + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + static int const kElementsPerAccess = kElementsPerAccess_; + + using VisitorA = VisitorA_; + using VisitorB = VisitorB_; + + /// Fragment type returned from VisitorA.visit + using VisitAccessTypeA = typename VisitorA::VisitAccessType; + using ElementA = typename VisitAccessTypeA::Element; + + /// Fragment type returned from VisitorB.visit + using VisitAccessTypeB = typename VisitorB::VisitAccessType; + using ElementB = typename VisitAccessTypeB::Element; + + /// Fragment type returned by this visitor + using VisitAccessType = Array; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Combination Op TODO: generalize this + using BinaryOp = BinaryOp_; + + static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); + static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B"); + + /// SMEM buffer class required in the epilogue visitor + struct SharedStorage { + typename VisitorA::SharedStorage storage_a; + typename VisitorB::SharedStorage storage_b; + + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + + /// Host-constructable Arguments structure + struct Arguments { + typename BinaryOp::Arguments binary_arg; + typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a + typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b + + // + // Methods + // + CUTLASS_HOST_DEVICE + Arguments():binary_arg() { } + + CUTLASS_HOST_DEVICE + Arguments( + typename BinaryOp::Arguments binary_arg, + typename VisitorA::Arguments visitor_a_arg, + typename VisitorB::Arguments visitor_b_arg + ): + binary_arg(binary_arg), + visitor_a_arg(visitor_a_arg), + visitor_b_arg(visitor_b_arg) + { } + }; + + /// Parameter structure + struct Params { + typename BinaryOp::Params binary_param; + typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a + typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + binary_param(args.binary_arg), + visitor_a_param(args.visitor_a_arg), + visitor_b_param(args.visitor_b_arg) + { } + }; + +private: + // + // Data members + // + + BinaryOp binary_op; + + VisitorA visitor_a_op; + VisitorB visitor_b_op; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpBinary( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + binary_op(params.binary_param), + visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size), + visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size) + { } + + + CUTLASS_DEVICE + void begin_epilogue() { + visitor_a_op.begin_epilogue(); + visitor_b_op.begin_epilogue(); + } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + visitor_a_op.set_batch_index(batch_idx); + visitor_b_op.set_batch_index(batch_idx); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + visitor_a_op.begin_step(step_idx); + visitor_b_op.begin_step(step_idx); + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { + visitor_a_op.begin_row(row_idx); + visitor_b_op.begin_row(row_idx); + } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + /// Get result from visitor A and visitor B + VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + + /// Type conversion + NumericArrayConverter source_converter_A; + NumericArrayConverter source_converter_B; + + return binary_op( + source_converter_A(result_A), + source_converter_B(result_B) + ); + } + + CUTLASS_DEVICE + void end_row(int row_idx) { + visitor_a_op.end_row(row_idx); + visitor_b_op.end_row(row_idx); + } + + CUTLASS_DEVICE + void end_step(int step_idx) { + visitor_a_op.end_step(step_idx); + visitor_b_op.end_step(step_idx); + } + + CUTLASS_DEVICE + void end_epilogue() { + visitor_a_op.end_epilogue(); + visitor_b_op.end_epilogue(); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h new file mode 100644 index 00000000..d631b27e --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with broadcasting vector to all columns +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Visitor operator for the following computation: +/// +/// ElementVector T[i][j] <- device-memory Td[i] +/// +/// It can only be a leaf node in the epilogue tree +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementFragment_, ///< Data type used to cache vector in register + typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor +> +class VisitorOpColumnBroadcast { +public: + using InputTileIterator = InputTileIterator_; + + static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; + using ElementAccumulator = ElementAccumulator_; + using ElementVector = typename InputTileIterator::Element; + using ElementFragment = ElementFragment_; + + using VisitAccessType = Array; + + /// Thread map used by input tile iterators + using ThreadMap = typename InputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = Array< + ElementFragment, kElementsPerAccess>; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Used for the broadcast + struct BroadcastDetail { + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = ThreadMap::kThreads; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + // /// Number of iterations (accesses) the threadblock takes to reduce a row + // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + }; + + // using ComputeFragmentType = Array; + + struct SharedStorage { + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + + /// Host-constructable Argument structure + struct Arguments { + ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand + int64_t batch_stride; + + /// Methods + CUTLASS_HOST_DEVICE + Arguments(): + broadcast_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementVector *broadcast_ptr, + int64_t batch_stride + ): + broadcast_ptr(broadcast_ptr), + batch_stride(batch_stride) { } + }; + + /// Param structure + struct Params { + ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand + int64_t batch_stride; + + /// Method + CUTLASS_HOST_DEVICE + Params(): + broadcast_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + broadcast_ptr(args.broadcast_ptr), + batch_stride(args.batch_stride) { } + }; + +private: + ElementVector *broadcast_ptr; + BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment + MatrixCoord threadblock_offset_; + int thread_idx_; + MatrixCoord problem_size; + + int thread_start_row_; + int state_[3]; + int thread_offset_row_; + + int64_t batch_stride_; + +public: + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpColumnBroadcast( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + broadcast_ptr(params.broadcast_ptr), + threadblock_offset_(threadblock_offset), + thread_idx_(thread_idx), + problem_size(problem_size), + thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()), + batch_stride_(params.batch_stride) + { + state_[0] = state_[1] = state_[2] = 0; + } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + broadcast_ptr += batch_idx * batch_stride_; + } + + CUTLASS_DEVICE + void begin_epilogue() { } + + CUTLASS_DEVICE + void begin_step(int step_idx) {} + + CUTLASS_DEVICE + void begin_row(int row_idx) {} + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + // get pointer + thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row(); + + ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_)); + + broadcast_fragment.fill(broadcast_data); + + return broadcast_fragment; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { } + + CUTLASS_DEVICE + void end_step(int step_idx) { + // run operator ++ + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + } + + CUTLASS_DEVICE + void end_epilogue() { } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h new file mode 100644 index 00000000..5f671675 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h @@ -0,0 +1,342 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with reduction over columns in CTA +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Visitor operator for the following computation: +/// +/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j]) +/// device memory <- ElementReduction(R[j]) +/// +template < + typename ThreadblockShape_, /// Threadblock shape + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementReduction_, ///< Data type of the output reduction in device memory + typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register + typename OutputTileIterator_, ///< Tile Iterator type + typename Visitor_ ///< preceeding visitor op +> +class VisitorOpColumnReduction { +public: + using ElementAccumulator = ElementAccumulator_; + using ElementReductionAccumulator = ElementReductionAccumulator_; + using ElementReduction = ElementReduction_; + using OutputTileIterator = OutputTileIterator_; + using ThreadblockShape = ThreadblockShape_; + using Visitor = Visitor_; + + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + // TODO: generalize the reduction op + using ReductionOp = cutlass::plus>; + using ReductionOpScalar = cutlass::plus; + using ElementOutput = typename OutputTileIterator::Element; + + + + /// Fragment type returned from Visitor + using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; + using ElementVisitor = typename VisitAccessTypeVisitor::Element; + + using VisitAccessType = VisitAccessTypeVisitor; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Fragment type of redcution + using ReductionAccumulatorAccessType = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + /// Used for the reduction + struct ReductionDetail { + + /// Number of threads per warp + static int const kWarpSize = 32; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = ThreadMap::kThreads; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread; + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// Number of iterations (accesses) the threadblock takes to reduce a row + static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount); + + using StorageShape = MatrixShape< + kThreadRows, + ThreadblockShape::kN + >; + }; + + using ReductionFragment = Array; + + /// Shared storage + struct SharedStorage { + typename Visitor::SharedStorage storage_visitor; + AlignedArray reduction; + + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + /// Host-constructable Argument structure + struct Arguments { + ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory + int64_t batch_stride; + typename Visitor::Arguments visitor_arg; ///< Argument type of visitor + + /// Method + CUTLASS_HOST_DEVICE + Arguments(): reduction_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementReduction *reduction_ptr, + int64_t batch_stride, + typename Visitor::Arguments visitor_arg + ): + reduction_ptr(reduction_ptr), + batch_stride(batch_stride), + visitor_arg(visitor_arg) + { } + }; + + /// Param structure + struct Params { + ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory + int64_t batch_stride; + typename Visitor::Params visitor_param; ///< Argument type of visitor + + /// Method + CUTLASS_HOST_DEVICE + Params(): reduction_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + reduction_ptr(args.reduction_ptr), + batch_stride(args.batch_stride), + visitor_param(args.visitor_arg) + { } + }; + +private: + ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory + ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory + ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction + Visitor visitor_; ///< visitor + int thread_idx_; + MatrixCoord threadblock_offset; + MatrixCoord problem_size_; + int64_t batch_stride_; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpColumnReduction( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + visitor_(params.visitor_param, shared_storage.storage_visitor, + thread_idx, threadblock_offset, problem_size), + reduction_smem_ptr_(shared_storage.reduction.data()), + reduction_output_ptr_(params.reduction_ptr), + thread_idx_(thread_idx), + threadblock_offset(threadblock_offset), + problem_size_(problem_size), + batch_stride_(params.batch_stride) + { } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + reduction_output_ptr_ += batch_idx * batch_stride_; + visitor_.set_batch_index(batch_idx); + } + + CUTLASS_DEVICE + void begin_epilogue() { + visitor_.begin_epilogue(); + + // clear the reduction fragment + reduction_fragment.clear(); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + visitor_.begin_step(step_idx); + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { + visitor_.begin_row(row_idx); + } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + /// Get result from visitor + VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + + NumericArrayConverter reduction_converter; + ReductionOp reduction_op; + ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast(&reduction_fragment); + reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result)); + + return result; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { + visitor_.end_row(row_idx); + } + + CUTLASS_DEVICE + void end_step(int step_idx) { + visitor_.end_step(step_idx); + } + + CUTLASS_DEVICE + void end_epilogue() { + visitor_.end_epilogue(); + // + // Store the partially reduced value to SMEM + // + + // Guard against uses of the existing SMEM tile + __syncthreads(); + + using AccessType = AlignedArray; + + // + // Determine a compact thread arrangement to store to SMEM + // + + MatrixCoord thread_offset( + thread_idx_ / ReductionDetail::kThreadsPerRow, + (thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess + ); + + // + // Each thread store its fragment to a SMEM + // + AccessType *aligned_reduction_ptr = reinterpret_cast( + &reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()] + ); + + AccessType const *frag_ptr = reinterpret_cast( + &reduction_fragment + ); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess; + + aligned_reduction_ptr[col_idx] = frag_ptr[column]; + } + + __syncthreads(); + + // + // Now, threads are assigned several columns of the output. The fetch over all rows from + // the compacted SMEM tile and perform a reduction. + // + + NumericConverter output_converter; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) { + int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount; + + ReductionOpScalar reduction_op; + ElementReductionAccumulator reduction_element = ElementReductionAccumulator(); + + int output_column_idx = threadblock_offset.column() + column_idx; + + if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ReductionDetail::kThreadRows; ++row) { + if (row) { + auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx]; + reduction_element = reduction_op(reduction_element, frag); + } + else { + + reduction_element = reduction_smem_ptr_[column_idx]; + } + } + + // Store + reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element); + } + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h new file mode 100644 index 00000000..da62808e --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h @@ -0,0 +1,266 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with Linear Combination +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue Visitor operator for the following computation: +/// +/// ElementCompute alpha; +/// ElementCompute beta; +/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B) +/// Return C; +/// +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementCompute_, ///< Data type used to compute linear combination + int kElementsPerAccess_, ///< Number of elements computed per operation + typename VisitorA_, ///< Child node A + typename VisitorB_ ///< Child node B +> +class VisitorOpLinearCombination{ +public: + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + static int const kElementsPerAccess = kElementsPerAccess_; + + using VisitorA = VisitorA_; + using VisitorB = VisitorB_; + + /// Fragment type returned from VisitorA.visit + using VisitAccessTypeA = typename VisitorA::VisitAccessType; + using ElementA = typename VisitAccessTypeA::Element; + + /// Fragment type returned from VisitorB.visit + using VisitAccessTypeB = typename VisitorB::VisitAccessType; + using ElementB = typename VisitAccessTypeB::Element; + + /// Fragment type returned by this visitor + using VisitAccessType = Array; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Combination Op TODO: generalize this + using CombinationOp = cutlass::plus; + + static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); + static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B"); + + /// SMEM buffer class required in the epilogue visitor + struct SharedStorage { + typename VisitorA::SharedStorage storage_a; + typename VisitorB::SharedStorage storage_b; + + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + + /// Host-constructable Arguments structure + struct Arguments { + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a + typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b + + // + // Methods + // + CUTLASS_HOST_DEVICE + Arguments(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)) + { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementCompute alpha, + ElementCompute beta, + typename VisitorA::Arguments visitor_a_arg, + typename VisitorB::Arguments visitor_b_arg + ): + alpha(alpha), + beta(beta), + visitor_a_arg(visitor_a_arg), + visitor_b_arg(visitor_b_arg) + { } + }; + + /// Parameter structure + struct Params { + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a + typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + alpha(args.alpha), + beta(args.beta), + visitor_a_param(args.visitor_a_arg), + visitor_b_param(args.visitor_b_arg) + { } + }; + +private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + VisitorA visitor_a_op; + VisitorB visitor_b_op; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpLinearCombination( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + alpha_(params.alpha), + beta_(params.beta), + visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size), + visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size) + { } + + + CUTLASS_DEVICE + void begin_epilogue() { + if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue(); + if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue(); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx); + if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx); + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { + if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx); + if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx); + } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + /// Get result from visitor A and visitor B + VisitAccessTypeA result_A; + VisitAccessTypeB result_B; + + if (alpha_ != ElementCompute(0)) { + result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + } else { + // Fill the result A with zeros + result_A.clear(); + } + + if (beta_ != ElementCompute(0)) { + result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + } else { + // Fill the result B with zeros + result_B.clear(); + } + + /// Type conversion + NumericArrayConverter source_converter_A; + NumericArrayConverter source_converter_B; + + CombinationOp combination_op; + + cutlass::multiplies multiply_op; + + return combination_op( + multiply_op(alpha_, source_converter_A(result_A)), + multiply_op(beta_, source_converter_B(result_B)) + ); + } + + CUTLASS_DEVICE + void end_row(int row_idx) { + if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx); + if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx); + } + + CUTLASS_DEVICE + void end_step(int step_idx) { + if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx); + if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx); + } + + CUTLASS_DEVICE + void end_epilogue() { + if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue(); + if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue(); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h new file mode 100644 index 00000000..b5a18127 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h @@ -0,0 +1,258 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with broadcasting vector to all rows +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Visitor operator for the following computation: +/// +/// ElementVector T[i][j] <- device-memory Td[j] +/// +/// It can only be a leaf node in the epilogue tree +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementFragment_, ///< Data type used to cache vector in register + typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor +> +class VisitorOpRowBroadcast { +public: + using InputTileIterator = InputTileIterator_; + + static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; + using ElementAccumulator = ElementAccumulator_; + using ElementVector = typename InputTileIterator::Element; + using ElementFragment = ElementFragment_; + + using VisitAccessType = Array; + + /// Thread map used by input tile iterators + using ThreadMap = typename InputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = Array< + ElementFragment, + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Used for the broadcast + struct BroadcastDetail { + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = ThreadMap::kThreads; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + // /// Number of iterations (accesses) the threadblock takes to reduce a row + // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + }; + + // using ComputeFragmentType = Array; + + struct SharedStorage { + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + + /// Host-constructable Argument structure + struct Arguments { + ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand + int64_t batch_stride; + + /// Methods + CUTLASS_HOST_DEVICE + Arguments(): + broadcast_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementVector *broadcast_ptr, + int64_t batch_stride + ): + broadcast_ptr(broadcast_ptr), + batch_stride(batch_stride) { } + }; + + /// Param structure + struct Params { + ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand + int64_t batch_stride; + + /// Method + CUTLASS_HOST_DEVICE + Params(): + broadcast_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + broadcast_ptr(args.broadcast_ptr), + batch_stride(args.batch_stride) { } + }; + +private: + ElementVector *broadcast_ptr; + BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment + MatrixCoord threadblock_offset_; + int thread_idx_; + MatrixCoord problem_size; + int64_t batch_stride_; + +public: + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpRowBroadcast( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()), + threadblock_offset_(threadblock_offset), + thread_idx_(thread_idx), + problem_size(problem_size), + batch_stride_(params.batch_stride) { } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + broadcast_ptr += batch_idx * batch_stride_; + } + + CUTLASS_DEVICE + void begin_epilogue() { + // load broadcast fragment + load_broadcast_fragment_(); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) {} + + CUTLASS_DEVICE + void begin_row(int row_idx) {} + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + VisitAccessType* broadcast_fragment_ = reinterpret_cast(&broadcast_fragment); + return broadcast_fragment_[column_idx]; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { } + + CUTLASS_DEVICE + void end_step(int step_idx) { } + + CUTLASS_DEVICE + void end_epilogue() { } + +private: + + CUTLASS_DEVICE + void load_broadcast_fragment_() { + + broadcast_fragment.clear(); + + // If no pointer is supplied, set with all zeros and avoid memory accesses + if (!broadcast_ptr) { + return; + } + + int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); + + int thread_column_idx = threadblock_offset_.column() + thread_initial_column; + broadcast_ptr += thread_initial_column; + + NumericArrayConverter converter; + using AccessType = AlignedArray; + using AccessFragmentType = Array; + + AccessFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { + + AccessType loaded; + + loaded.clear(); + + if (thread_column_idx < problem_size.column()) { + loaded = *reinterpret_cast(broadcast_ptr); + } + + AccessFragmentType cvt = converter(loaded); + frag_ptr[j] = cvt; + + thread_column_idx += ThreadMap::Delta::kColumn; + broadcast_ptr += ThreadMap::Delta::kColumn; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h new file mode 100644 index 00000000..f5387dc2 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with reduction over rows in CTA +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "stdio.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Visitor operator for the following computation: +/// +/// ElementReductionAccumulator R[i] = \sum_i ElementReductionAccumulator(T[i][j]) +/// device memory <- ElementReduction(R[i]) +/// +template < + typename ThreadblockShape_, /// Threadblock shape + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementReduction_, ///< Data type of the output reduction in device memory + typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register + typename OutputTileIterator_, ///< Tile Iterator type + typename Visitor_ ///< preceeding visitor op +> +class VisitorOpRowReduction { +public: + using ElementAccumulator = ElementAccumulator_; + using ElementReductionAccumulator = ElementReductionAccumulator_; + using ElementReduction = ElementReduction_; + using OutputTileIterator = OutputTileIterator_; + using ThreadblockShape = ThreadblockShape_; + using Visitor = Visitor_; + + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + // TODO: generalize the reduction op + using ReductionOp = cutlass::plus>; + using ReductionOpScalar = cutlass::plus; + using ElementOutput = typename OutputTileIterator::Element; + + /// Fragment type returned from Visitor + using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; + using ElementVisitor = typename VisitAccessTypeVisitor::Element; + + using VisitAccessType = VisitAccessTypeVisitor; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Fragment type of redcution + using ReductionAccumulatorAccessType = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + /// Used for the reduction + struct ReductionDetail { + + /// Number of threads per warp + static int const kWarpSize = 32; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = ThreadMap::kThreads; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread; + + /// Half number of threads per row used for cross-thread reduction + static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock + static int const kThreadRows = kThreadCount / kThreadsPerRow; + }; + + /// Shared storage + struct SharedStorage { + typename Visitor::SharedStorage storage_visitor; + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + + /// Host-constructable Argument structure + struct Arguments { + ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory + int64_t batch_stride; + typename Visitor::Arguments visitor_arg; ///< Argument type of visitor + + /// Method + CUTLASS_HOST_DEVICE + Arguments(): reduction_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementReduction *reduction_ptr, + int64_t batch_stride, + typename Visitor::Arguments visitor_arg + ): + reduction_ptr(reduction_ptr), + batch_stride(batch_stride), + visitor_arg(visitor_arg) + { } + }; + + /// Param structure + struct Params { + ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory + int64_t batch_stride; + typename Visitor::Params visitor_param; ///< Argument type of visitor + + /// Method + CUTLASS_HOST_DEVICE + Params(): reduction_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + reduction_ptr(args.reduction_ptr), + batch_stride(args.batch_stride), + visitor_param(args.visitor_arg) + { } + }; + +private: + ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory + ElementReductionAccumulator reduction_accum; + Visitor visitor_; ///< visitor + int thread_idx_; + MatrixCoord threadblock_offset; + MatrixCoord problem_size_; + + int thread_start_row_; /// used to identify + int state_[3]; /// used to track row iterator + int thread_offset_row_; + int64_t batch_stride_; +public: + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpRowReduction( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + visitor_(params.visitor_param, shared_storage.storage_visitor, + thread_idx, threadblock_offset, problem_size), + reduction_output_ptr_(params.reduction_ptr), + thread_idx_(thread_idx), + threadblock_offset(threadblock_offset), + problem_size_(problem_size), + thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()), + batch_stride_(params.batch_stride) + { + state_[0] = state_[1] = state_[2] = 0; + } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + reduction_output_ptr_ += batch_idx * batch_stride_; + visitor_.set_batch_index(batch_idx); + } + + CUTLASS_DEVICE + void begin_epilogue() { + visitor_.begin_epilogue(); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + visitor_.begin_step(step_idx); + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { + visitor_.begin_row(row_idx); + + reduction_accum = ElementReductionAccumulator(0); + } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + /// Get result from visitor + VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + + thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row(); + + ReductionOpScalar reduction_op; + + ElementReductionAccumulator reduction_accum_ = reduction(result); + + // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction + CUTLASS_PRAGMA_UNROLL + for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) { + reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i)); + } + reduction_accum = reduction_op(reduction_accum, reduction_accum_); + + return result; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { + visitor_.end_row(row_idx); + NumericConverter output_converter; + + bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0); + int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row(); + + ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset; + + arch::global_store( + output_converter(reduction_accum), + (void *)curr_ptr_reduction, + is_write_thread); + } + + CUTLASS_DEVICE + void end_step(int step_idx) { + visitor_.end_step(step_idx); + + // run operator ++ + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + + } + + CUTLASS_DEVICE + void end_epilogue() { + visitor_.end_epilogue(); + } + +private: + + CUTLASS_DEVICE + ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) { + ElementReductionAccumulator sum_ = ElementReductionAccumulator(0); + + ReductionOpScalar reduction_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) { + sum_ = reduction_op(sum_, result[i]); + } + + return sum_; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h new file mode 100644 index 00000000..5434912b --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h @@ -0,0 +1,188 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with Tensor Output +*/ + +#pragma once +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue Visitor operator for the following computation: +/// +/// ElementInput C <- device memory +/// +/// It can only be a leaf node in the epilogue tree +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename InputTileIterator_ ///< Tile iterator type to read the tensor +> +class VisitorOpTensorInput { +public: + using ElementAccumulator = ElementAccumulator_; + using InputTileIterator = InputTileIterator_; + + static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; + using ElementInput = typename InputTileIterator::Element; + + using VisitAccessType = Array; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + struct SharedStorage { + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + + /// Host-constructable Argument structure + struct Arguments { + ElementInput *input_ptr; ///< Pointer to the input tensor in device memory + int ldt; ///< Leading dimension of the input tensor operand + int64_t batch_stride; ///< batch stride for batched GEMM + + /// Methods + CUTLASS_HOST_DEVICE + Arguments(): input_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementInput *input_ptr, + int ldt, int64_t batch_stride + ): + input_ptr(input_ptr), + ldt(ldt), + batch_stride(batch_stride) + { } + }; + + /// Param structure + struct Params { + typename InputTileIterator::Params params_input; + ElementInput *input_ptr; + int64_t batch_stride; + + /// Method + CUTLASS_HOST_DEVICE + Params(): + input_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + params_input(args.ldt), + input_ptr(args.input_ptr), + batch_stride(args.batch_stride) + { } + }; + +private: + InputTileIterator iterator_T_; + typename InputTileIterator::Fragment fragment_T_; + MatrixCoord problem_size; + int64_t batch_stride_; + +public: + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpTensorInput( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + iterator_T_( + InputTileIterator( + params.params_input, + params.input_ptr, + problem_size, + thread_idx, + threadblock_offset + ) + ), + problem_size(problem_size), + batch_stride_(params.batch_stride) { } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_T_.add_pointer_offset(batch_idx * batch_stride_); + } + + CUTLASS_DEVICE + void begin_epilogue() { } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_T_.clear(); + iterator_T_.load(fragment_T_); + ++iterator_T_; + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + VisitAccessType source = reinterpret_cast(&fragment_T_)[frag_idx]; + return source; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { } + + CUTLASS_DEVICE + void end_step(int step_idx) { } + + CUTLASS_DEVICE + void end_epilogue() { } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h new file mode 100644 index 00000000..d2affd3c --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h @@ -0,0 +1,240 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with Tensor Output +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "stdio.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue Visitor operator for the following computation: +/// +/// ElementOutput T = ElementOutput(Visitor) +/// T-> device memory +/// +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename OutputTileIterator_, ///< Tile iterator type to write the tensor + typename Visitor_ ///< Child visitor that produces the output tensor +> +class VisitorOpTensorOutput { +public: + using ElementAccumulator = ElementAccumulator_; + using OutputTileIterator = OutputTileIterator_; + + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + using ElementOutput = typename OutputTileIterator::Element; + + using Visitor = Visitor_; + + /// Fragment type returned from Visitor + using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; + using ElementVisitor = typename VisitAccessTypeVisitor::Element; + + using VisitAccessType = VisitAccessTypeVisitor; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Fragment type of output + using OutputAccessType = Array; + + static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); + + struct SharedStorage { + typename Visitor::SharedStorage storage_visitor; + + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + + /// Host-constructable Argument structure + struct Arguments { + ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory + int ldt; ///< Leading dimension of the output tensor operand + int64_t batch_stride; ///< batch stride + typename Visitor::Arguments visitor_arg; ///< Argument type of visitor + + /// Methods + CUTLASS_HOST_DEVICE + Arguments(): output_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Arguments( + ElementOutput *output_ptr, + int ldt, + int64_t batch_stride, + typename Visitor::Arguments visitor_arg + ): + output_ptr(output_ptr), + ldt(ldt), + batch_stride(batch_stride), + visitor_arg(visitor_arg) + { } + }; + + /// Param structure + struct Params { + typename OutputTileIterator::Params params_output; + ElementOutput *output_ptr; + int64_t batch_stride; + typename Visitor::Params visitor_param; + + /// Method + CUTLASS_HOST_DEVICE + Params(): + output_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + params_output(args.ldt), + output_ptr(args.output_ptr), + batch_stride(args.batch_stride), + visitor_param(args.visitor_arg) + { } + }; + +private: + OutputTileIterator iterator_T_; + typename OutputTileIterator::Fragment fragment_T_; + MatrixCoord problem_size; + Visitor visitor_; + int64_t batch_stride_; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpTensorOutput( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size), + iterator_T_( + OutputTileIterator( + params.params_output, + params.output_ptr, + problem_size, + thread_idx, + threadblock_offset + ) + ), + problem_size(problem_size), + batch_stride_(params.batch_stride) { } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_T_.add_pointer_offset(batch_idx * batch_stride_); + visitor_.set_batch_index(batch_idx); + } + + CUTLASS_DEVICE + void begin_epilogue() { + visitor_.begin_epilogue(); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_T_.clear(); + visitor_.begin_step(step_idx); + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { + visitor_.begin_row(row_idx); + } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + /// Get result from visitor + VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + + // Column guard + MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + bool column_guard = (thread_offset_.column() < problem_size.column()); + + if (column_guard) { + NumericArrayConverter output_converter; + OutputAccessType &output = reinterpret_cast(&fragment_T_)[frag_idx]; + output = output_converter(result); + } + + return result; + } + + CUTLASS_DEVICE + void end_row(int row_idx) { + visitor_.end_row(row_idx); + } + + CUTLASS_DEVICE + void end_step(int step_idx) { + visitor_.end_step(step_idx); + iterator_T_.store(fragment_T_); + ++iterator_T_; + } + + CUTLASS_DEVICE + void end_epilogue() { + visitor_.end_epilogue(); + } + +}; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h new file mode 100644 index 00000000..aeab725c --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h @@ -0,0 +1,226 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + + \brief A file contains the epilogue visitor Op with Unary operation +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "unary_ops.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Epilogue Visitor operator for the following computation: +/// +/// ElementCompute alpha; +/// ElementCompute beta; +/// ElementCompute C = UnaryOp(ElementCompute(Visitor)) +/// Return C; +/// +template < + typename ElementAccumulator_, ///< Data type of the Accumulator + typename ElementCompute_, ///< Data type used to compute linear combination + int kElementsPerAccess_, ///< Number of elements computed per operation + typename Visitor_, ///< Child node + template typename UnaryOp_ +> +class VisitorOpUnary{ +public: + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + static int const kElementsPerAccess = kElementsPerAccess_; + + using Visitor = Visitor_; + + /// Fragment type returned from Visitor.visit + using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; + using ElementVisit = typename VisitAccessTypeVisitor::Element; + + /// Fragment type returned by this visitor + using VisitAccessType = Array; + + /// Fragment type of accumulator + using AccumulatorAccessType = Array; + + /// Combination Op TODO: generalize this + using UnaryOp = UnaryOp_; + + static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); + + /// SMEM buffer class required in the epilogue visitor + struct SharedStorage { + typename Visitor::SharedStorage storage_visitor; + + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + + /// Host-constructable Arguments structure + struct Arguments { + typename UnaryOp::Arguments unary_arg; + typename Visitor::Arguments visitor_arg; ///< Argument type for visitor + + // + // Methods + // + CUTLASS_HOST_DEVICE + Arguments():unary_arg() { } + + CUTLASS_HOST_DEVICE + Arguments( + typename UnaryOp::Arguments unary_arg, + typename Visitor::Arguments visitor_arg + ): + unary_arg(unary_arg), + visitor_arg(visitor_arg) + { } + }; + + /// Parameter structure + struct Params { + typename UnaryOp::Params unary_param; + typename Visitor::Params visitor_param; ///< Argument type for visitor + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params():unary_param() { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + unary_param(args.unary_arg), + visitor_param(args.visitor_arg) + { } + }; + +private: + // + // Data members + // + UnaryOp unary_op; + + Visitor visitor_op; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + VisitorOpUnary( + Params const ¶ms, + SharedStorage &shared_storage, + int thread_idx, + MatrixCoord threadblock_offset, + MatrixCoord problem_size + ): + unary_op(params.unary_param), + visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size) + { } + + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + visitor_op.set_batch_index(batch_idx); + } + + CUTLASS_DEVICE + void begin_epilogue() { + if (unary_op.guard()) visitor_op.begin_epilogue(); + } + + CUTLASS_DEVICE + void begin_step(int step_idx) { + if (unary_op.guard()) visitor_op.begin_step(step_idx); + } + + CUTLASS_DEVICE + void begin_row(int row_idx) { + if (unary_op.guard()) visitor_op.begin_row(row_idx); + } + + CUTLASS_DEVICE + VisitAccessType visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorAccessType const &accum + ) { + /// Get result from visitor A and visitor B + VisitAccessTypeVisitor result; + + if (unary_op.guard()){ + result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); + } else { + result.clear(); + } + + /// Type conversion + NumericArrayConverter source_converter; + + cutlass::multiplies multiply_op; + + return unary_op(source_converter(result)); + } + + CUTLASS_DEVICE + void end_row(int row_idx) { + if (unary_op.guard()) visitor_op.end_row(row_idx); + } + + CUTLASS_DEVICE + void end_step(int step_idx) { + if (unary_op.guard()) visitor_op.end_step(step_idx); + } + + CUTLASS_DEVICE + void end_epilogue() { + if (unary_op.guard()) visitor_op.end_epilogue(); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h new file mode 100644 index 00000000..67ea478c --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this layernormware without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A file contains all functioning classes needed by GemmLayernorm. + + GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) + + lightweight full reduction kernel (ApplyFinalReduction) + + GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion) + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ThreadblockShape_, + int ThreadCount, + typename OutputTileIterator_, + typename AccumulatorTile_, + typename ElementAccumulator_, + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementwiseFunctor_, + bool IsShiftedVariance_ = false +> +class EpilogueVisitorLayerNorm { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + + using AccumulatorTile = AccumulatorTile_; + + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; + + static int const kThreads = OutputTileIterator::ThreadMap::kThreads; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + using ElementOutput = typename OutputTileIterator::Element; + + static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; + + /// Array type used in Shift-K Layernorm + static int const kRowAccessCount = kIterations * kRowIterations; + + using ConvertedShiftFragment = Array; + + // Conducts manual transpose externally (already supported) for column major + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementAccumulator = ElementAccumulator_; + + using AccumulatorFragment = Array; + using LayernormFragment = Array; + using OutputVector = Array; + using TensorRefD = TensorRef; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static int const kThreadsInColumn = kThreads / kThreadsPerRow; + static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + MatrixCoord extent; + + // + // Methods + // + Arguments(): + ptr_Variance(nullptr), + ptr_Mean(nullptr), + ptr_Shifted_K(nullptr) + { + + } + + Arguments( + typename ElementwiseFunctor::Params elementwise_, + ElementVariance *ptr_Variance, + ElementMean *ptr_Mean_, + ElementOutput *ptr_Shifted_K_ = nullptr, + MatrixCoord extent = MatrixCoord(0, 0) + ): + elementwise(elementwise_), + ptr_Variance(ptr_Variance), + ptr_Mean(ptr_Mean_), + ptr_Shifted_K(ptr_Shifted_K_), + extent(extent) + { + + } + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + MatrixCoord extent; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(): + ptr_Variance(nullptr), + ptr_Mean(nullptr) + { + + } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + elementwise(args.elementwise), + ptr_Variance(args.ptr_Variance), + ptr_Mean(args.ptr_Mean), + ptr_Shifted_K(args.ptr_Shifted_K), + extent(args.extent) + { + + } + }; + + /// Shared storage + struct SharedStorage { + + }; + +private: + + Params const & params_; + SharedStorage & shared_storage_; + MatrixCoord extent_; + ElementwiseFunctor elementwise_; + + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator alpha_; + ElementAccumulator beta_; + ConvertedShiftFragment shift_k_frag_; + + ElementLayernormCompute accum_sum_square_; + ElementLayernormCompute accum_sum_element_; + int thread_idx_; + + MatrixCoord thread_offset_; + + gemm::GemmCoord threadblock_tile_offset_; + +public: + + CUTLASS_DEVICE + EpilogueVisitorLayerNorm( + Params const ¶ms, ///< Parameters routed to the epilogue + SharedStorage &shared_storage, ///< Shared storage needed by the functors here + MatrixCoord threadblock_offset, + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMMM + ): + params_(params), + shared_storage_(shared_storage), + elementwise_(params.elementwise), + extent_(params.extent), + iterator_C_(source_iterator), + iterator_D_(destination_iterator), + threadblock_tile_offset_(threadblock_tile_offset), + thread_idx_(thread_idx) + { + alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + + // If shift-K feature is enabled, we load shift-k fragment + // at the very beginning of an epilogue + if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { + shift_k_frag_.clear(); + int thread_offset_row_base = iterator_D_.thread_start_row(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { + int step_offset = iter_idx * OutputTileIterator::Shape::kRow; + CUTLASS_PRAGMA_UNROLL + for (int rid = 0; rid < kRowIterations; ++rid) { + int row_step_offset = rid * kDeltaRow; + int row_offset = thread_offset_row_base + step_offset + row_step_offset; + bool is_load = (row_offset < extent_.row()); + shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); + } + + } + + } + + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + fragment_C_.clear(); + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + /// set the accumulator to 0 + accum_sum_element_ = ElementLayernormCompute(0); + accum_sum_square_ = ElementLayernormCompute(0); + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const &accum) { + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + Minus minus; + Mul mul; + Exp exponential; + + LayernormFragment result; + + thread_offset_ = + iterator_D_.thread_start() + + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + + NumericArrayConverter source_converter; + OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; + + bool column_guard = (thread_offset_.column() < extent_.column()); + + if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + result = source_converter(elementwise_(accum)); + }else{ + result = source_converter(elementwise_(accum, source_vector)); + } + + + ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); + + // Fragment is cleared for non-reachable columns so no need to check against column guard + ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result); + + // Square sum is different. Non-reachable columns should've been computed for shift-k + // Otherwise we will incorrectly have some extra k^2 added into square sum. + ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0); + + if (column_guard) { + accum_sum_square_tmp = (kIsShiftedVariance) ? \ + square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ + square_sum_accumulator_(result); + } + + accum_sum_element_tmp *= inv_scalar; + accum_sum_square_tmp *= inv_scalar; + + // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction + CUTLASS_PRAGMA_UNROLL + for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { + accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i); + accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i); + } + accum_sum_element_ += accum_sum_element_tmp; + accum_sum_square_ += accum_sum_square_tmp; + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); + int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row(); + + ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; + ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; + + arch::global_store( + convert_variance_output(accum_sum_square_), + (void *)curr_ptr_sum_square, + is_write_thread); + + arch::global_store( + convert_mean_output(accum_sum_element_), + (void *)curr_ptr_element_sum, + is_write_thread); + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + + } + +private: + + CUTLASS_DEVICE + ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { + using ConvertShiftK = cutlass::NumericConverter; + ConvertShiftK convert_shift_k; + ElementOutput shift_k_val; + + // Computes the address to load shift_k element + ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; + // Conditionally loads from global memory + arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); + // Converts data type to return + ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); + + return converted_shift_k_val; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i]; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i] - shift_k_val; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + sum_ += accum[i]; + } + + return sum_; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h new file mode 100644 index 00000000..6a840ce1 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h @@ -0,0 +1,692 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmUniversalwithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + typename LayoutA::Stride stride_a; + typename LayoutB::Stride stride_b; + typename LayoutC::Stride stride_c; + typename LayoutC::Stride stride_d; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; + + int const * ptr_gather_A_indices; + int const * ptr_gather_B_indices; + int const * ptr_scatter_D_indices; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr) {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueVisitor::Arguments epilogue_visitor, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d, + int const *ptr_gather_A_indices = nullptr, + int const *ptr_gather_B_indices = nullptr, + int const *ptr_scatter_D_indices = nullptr + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue_visitor(epilogue_visitor), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), + ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { + lda = 0; + ldb = 0; + ldc = 0; + ldd = 0; + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueVisitor::Arguments epilogue_visitor, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + int const *ptr_gather_A_indices = nullptr, + int const *ptr_gather_B_indices = nullptr, + int const *ptr_scatter_D_indices = nullptr + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue_visitor(epilogue_visitor), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), + ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { + stride_a = make_Coord(lda); + stride_b = make_Coord(ldb); + stride_c = make_Coord(ldc); + stride_d = make_Coord(ldd); + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.stride_a, args.stride_b); + std::swap(args.batch_stride_A, args.batch_stride_B); + std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + typename EpilogueVisitor::Params epilogue_visitor; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + int * ptr_gather_A_indices; + int * ptr_gather_B_indices; + int * ptr_scatter_D_indices; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), + params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), + params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), + params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), + epilogue_visitor(args.epilogue_visitor), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), + ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)), + semaphore(static_cast(workspace)) { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); + ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); + ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + + epilogue_visitor = args.epilogue_visitor; + + semaphore = static_cast(workspace); + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmUniversalwithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.ptr_gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B, + params.ptr_gather_B_indices); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + // EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // if (params.mode == GemmUniversalMode::kGemm) { + + // // TODO: fix this order + // // If performing a reduction via split-K, fetch the initial synchronization + // if (params.grid_tiled_shape.k() > 1) { + + // // Fetch the synchronization lock initially but do not block. + // semaphore.fetch(); + + // // Indicate which position in a serial reduction the output operator is currently updating + // output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + // } + // } + + // Tile iterator loading from source tensor. + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.visitor, + threadblock_offset, + threadblock_tile_offset, + thread_idx, + params.problem_size.mn() + ); + + // if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + // ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + // } + if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + // TODO: ??? + // if (threadblock_tile_offset.k()) { + // iterator_C = iterator_D; + // } + + semaphore.wait(threadblock_tile_offset.k()); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h b/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h index 231a21d5..605b99e4 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h @@ -50,7 +50,13 @@ void bind_tensor_coord(py::module &m) { R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc") .def(py::init(), py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"), - R"pbdoc(Helper to construct from N, H, W, and C)pbdoc"); + R"pbdoc(Helper to construct from N, H, W, and C)pbdoc") + .def("at", py::overload_cast(&cutlass::Tensor4DCoord::at), + py::arg("dim"), + R"pbdoc(Gets the index of a given Coord element)pbdoc") + .def("size", [](const cutlass::Tensor4DCoord & coord) { + return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);}, + R"pbdoc(The size of the tensor coord)pbdoc"); py::class_>(m, "Tensor3DCoord", R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc") diff --git a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py index 40a19433..8972c6fa 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py @@ -1,7 +1,24 @@ -from pycutlass.type import * +import re + + +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + +from pycutlass.type_hint import * from pycutlass.tensor_ref import * from pycutlass.operation import * from pycutlass.epilogue import * +from pycutlass.parser import * from pycutlass.compiler import ArtifactManager from pycutlass.memory_manager import * from pycutlass.arguments import * diff --git a/tools/library/scripts/pycutlass/src/pycutlass/arguments.py b/tools/library/scripts/pycutlass/src/pycutlass/arguments.py index 9c6bc5d2..c0db206f 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/arguments.py @@ -60,6 +60,13 @@ class ArgumentBase: C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', **kwargs) -> None: + + # tensor_C can be interpreted as the bias with bias=True in keyword args + if "bias" in kwargs.keys(): + self.bias = kwargs["bias"] + else: + # by default, tensor_C is not bias + self.bias = False # preprocessing input tensors if isinstance(A, np.ndarray): @@ -72,21 +79,28 @@ class ArgumentBase: self.ptr_B = self.buffer_B.ptr self.ptr_C = self.buffer_C.ptr self.ptr_D = self.buffer_D.ptr + # number of elements in C + self.tensor_c_numel = C.size elif torch_available and isinstance(A, torch.Tensor): self.ptr_A = TorchFrontend.argument(A) self.ptr_B = TorchFrontend.argument(B) self.ptr_C = TorchFrontend.argument(C) self.ptr_D = TorchFrontend.argument(D) + # number of elements in C + self.tensor_c_numel = C.numel() elif isinstance(A, cuda.CUdeviceptr): self.ptr_A = A self.ptr_B = B self.ptr_C = C self.ptr_D = D + elif cupy_available and isinstance(A, cp.ndarray): self.ptr_A = CupyFrontend.argument(A) self.ptr_B = CupyFrontend.argument(B) self.ptr_C = CupyFrontend.argument(C) self.ptr_D = CupyFrontend.argument(D) + # number of elements in C + self.tensor_c_numel = C.size else: raise TypeError( "Unsupported Frontend. Only support numpy and torch") diff --git a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py index 1d6abdb2..a7b3af03 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py @@ -63,22 +63,9 @@ dtype2ctype = { } -def get_epilogue_output_op(element_compute_): - element_compute = dtype2ctype[element_compute_] +def get_gemm_arguments(epilogue_functor): - class _EpilogueOutputOpParams(ctypes.Structure): - _fields_ = [ - ("alpha", element_compute), - ("beta", element_compute), - ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p) - ] - return _EpilogueOutputOpParams - - -def get_gemm_arguments(element_compute_): - - _EpilogueOutputOpParams = get_epilogue_output_op(element_compute_) + _EpilogueOutputOpParams = epilogue_functor.epilogue_type class _GemmArguments(ctypes.Structure): _fields_ = [ @@ -116,8 +103,8 @@ def get_gemm_arguments(element_compute_): # include/cutlass/gemm/kernel/gemm_grouped.h -def get_gemm_grouped_arguments(element_compute_): - _EpilogueOutputOpParams = get_epilogue_output_op(element_compute_) +def get_gemm_grouped_arguments(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type class _GEMMGroupedArguments(ctypes.Structure): _fields_ = [ @@ -214,8 +201,8 @@ class TensorRef2D_(ctypes.Structure): # include/cutlass/conv/kernel/implicit_gemm_convolution.h # split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4 -def get_conv2d_arguments(element_compute_): - _EpilogueOutputOpParams = get_epilogue_output_op(element_compute_) +def get_conv2d_arguments(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type class _Conv2dArguments(ctypes.Structure): _fields_ = [ @@ -236,8 +223,8 @@ def get_conv2d_arguments(element_compute_): ############################################################################################ -def get_reduction_params(element_compute_): - _EpilogueOutputParams = get_epilogue_output_op(element_compute_) +def get_reduction_params(epilogue_functor): + _EpilogueOutputParams = epilogue_functor.epilogue_type class _ReductionParams(ctypes.Structure): _fields_ = [ diff --git a/tools/library/scripts/pycutlass/src/pycutlass/cache.py b/tools/library/scripts/pycutlass/src/pycutlass/cache.py deleted file mode 100644 index 322da90f..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/cache.py +++ /dev/null @@ -1,366 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -from pycutlass import * -from pycutlass.library import SubstituteTemplate -import cutlass -from cuda import cuda -from cuda import nvrtc -import tempfile -import os -import ctypes - -# -import json -import sqlite3 - - -IncludeTemplate = r'''#include "${include}" -''' - -# -class CompilationOptions: - ''' - Compilation options. - ''' - - # - def __init__(self, architectures = [80], include_paths = []): - self.includes = [] - self.include_paths = include_paths - self.flags = ['-std=c++11', '-default-device'] - self.architectures = architectures - - # - def get(self): - options = [] - - for flag in self.flags: - options.append(bytes(str.encode(flag))) - - for incl in self.include_paths: - options.append(bytes(str.encode('--include-path=%s' % incl))) - - arch_list = "-arch=" - for idx, arch in enumerate(self.architectures): - if idx: - arch_list += "," - arch_list += "sm_%d" % arch - - options.append(bytes(str.encode(arch_list))) - - return options - -def convertToBinaryData(filename): - with open(filename, 'rb') as file: - blobData = file.read() - return blobData - -def CDLLBin(host_binary): - tempfile.tempdir = "./" - temp_so = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True) - with open(temp_so.name, 'wb') as file: - file.write(host_binary) - host_lib = ctypes.CDLL(temp_so.name) - return host_lib - - -class ArtifactManager: - """ - Artifact manager - """ - def __init__(self) -> None: - try: - connection = sqlite3.connect("./compiled_cache.db") - cursor = connection.cursor() - sqlite_create_table_query = """CREATE TABLE compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)""" - cursor.execute(sqlite_create_table_query) - connection.commit() - cursor.close() - except: - pass - - self.compiled_cache_device = cutlass.CompileCache() - self.compiled_cache_host = cutlass.CompileCache() - - def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): - connection = sqlite3.connect("./compiled_cache.db") - cursor = connection.cursor() - sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" - - hostbin = convertToBinaryData(hostfile) - - data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)) - - cursor.execute(sqlite_insert_blob_query, data_tuple) - connection.commit() - cursor.close() - - def load_operation(self, op_key): - connection = sqlite3.connect("./compiled_cache.db") - cursor = connection.cursor() - sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" - # try: - cursor.execute(sqlite_fetch_blob_query, (op_key, )) - record = cursor.fetchall() - if len(record) == 0: - return False - for row in record: - key, cubin_image, host_binary, operation_name, op_attr = row - op_attr = json.loads(op_attr) - err, module = cuda.cuModuleLoadData(cubin_image) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) - - err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) - self.compiled_cache_device.insert(key, kernel) - - compiled_host_fns = {} - host_lib = CDLLBin(host_binary) - - func_name = operation_name + '_get_params' - func = getattr(host_lib, func_name) - func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0]) - compiled_host_fns['get_args'] = func - - func_name = operation_name + '_shared_memory_size' - func = getattr(host_lib, func_name) - compiled_host_fns['shared_memory_capacity'] = func() - - for attr in op_attr: - if isinstance(attr, str): - func_name = operation_name + '_' + attr - func = getattr(host_lib, func_name) - compiled_host_fns[attr] = func - - self.compiled_cache_host.insert(key, compiled_host_fns) - return True - - - def emit_compile_(self, operation_list, compilation_options): - """ - Compile a list of kernels and store them into database - """ - source_buffer_device = "" - source_buffer_host = "" - # 1. include - includes = [] - for operation in operation_list: - for incl in operation.emitter.includes: - if incl not in includes: - includes.append(incl) - - includes_host = [ - "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes - for incl in includes: - source_buffer_device += SubstituteTemplate(IncludeTemplate, {'include': incl}) - - for incl in includes_host: - if "/device/" not in incl: - source_buffer_host += SubstituteTemplate(IncludeTemplate, { 'include': incl} ) - - - # 2. Operations - for operation in operation_list: - source_buffer_device += operation.emit() - source_buffer_host += operation.emit() - values = { - 'operation_name': operation.name(), - 'operation_suffix': operation.emitter.operation_suffix - } - source_buffer_device += SubstituteTemplate(operation.KernelTemplate, values) - source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) - - # 3. compile - err, program = nvrtc.nvrtcCreateProgram( - str.encode(source_buffer_device), - bytes(str.encode("module.cu")), - 0, [], []) - - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) - - # Compile program - options = compilation_options.get() - - err, = nvrtc.nvrtcCompileProgram(program, len(options), options) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - - error_string = 'NVRTC Error: {}\n'.format(err) - - # Get log from compilation - err, logSize = nvrtc.nvrtcGetProgramLogSize(program) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) - - log = b' ' * logSize - err, = nvrtc.nvrtcGetProgramLog(program, log) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) - - raise RuntimeError(error_string + log.decode() + source_buffer_device) - - # Get data from compilation - err, dataSize = nvrtc.nvrtcGetCUBINSize(program) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) - - cubin_image = b' ' * dataSize - err, = nvrtc.nvrtcGetCUBIN(program, cubin_image) - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) - - # compile the host code - options = compilation_options.get() - cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host - for opt in options: - opt = opt.decode("utf-8") - if opt not in ['-default-device', '-std=c++11', '-arch=sm_80']: - if '--include-path=' in opt: - cmd += " " + opt.replace('--include-path=', '-I') - else: - cmd += " "+ opt - - tempfile.tempdir = "./" - temp = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True) - - cmd += ' - -shared -o %s' % temp.name - os.system(cmd) - host_lib = ctypes.CDLL(temp.name) - - return cubin_image, host_lib, temp - - - def add_module(self, operations, compile_options=None): - """ - Insert a new compiled device module - """ - if compile_options is None: - cutlass_path = os.getenv('CUTLASS_PATH') - assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." - cuda_install_path = os.getenv('CUDA_INSTALL_PATH') - assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." - architectures = [] - for operation in operations: - if hasattr(operation, "tile_description"): - cc = operation.tile_description.minimum_compute_capability - if cc not in architectures: - architectures.append(cc) - include_paths = [ - cuda_install_path + '/include', - cutlass_path + '/include', - cutlass_path + '/tools/util/include', - ] - compile_options = CompilationOptions(architectures, include_paths) - # save the cubin - operation_key = [] - operation_list = [] - for operation in operations: - # step 1: get kernel string as key - key = operation.rt_module.emit() + operation.procedural_name() - # step 1: check if the operation is in cache - compiled_kernel = self.compiled_cache_device.at(key) - - if compiled_kernel is None: - hit = self.load_operation(key) - if hit: - compiled_kernel = self.compiled_cache_device.at(key) - assert compiled_kernel is not None - if compiled_kernel is not None: - operation.rt_module.kernel = compiled_kernel - compiled_host_fns = self.compiled_cache_host.at(key) - assert compiled_host_fns is not None - for key in compiled_host_fns.keys(): - setattr(operation.rt_module, key, compiled_host_fns[key]) - operation.rt_module.initialize() - else: - operation_list.append(operation.rt_module) - operation_key.append(key) - if len(operation_list) > 0: - cubin_image, host_lib, host_file = self.emit_compile_(operation_list, compile_options) - - err, module = cuda.cuModuleLoadData(cubin_image) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) - - operation_name = [] - operation_attr = [] - for operation, key in zip(operation_list, operation_key): - # get device kernels - err, operation.kernel = cuda.cuModuleGetFunction( - module, - bytes(str.encode(operation.name())) - ) - operation_name.append(operation.name()) - self.compiled_cache_device.insert(key, operation.kernel) - # get host functions - compiled_host_fns = {} - op_attr = [] - - # get param size - func_name = operation.name() + '_get_param_size' - func = getattr(host_lib, func_name) - param_size = func() - - func_name = operation.name() + '_get_params' - func = getattr(host_lib, func_name) - func.argtype = operation.argtype - func.restype = ctypes.POINTER(ctypes.c_char * param_size) - setattr(operation, 'get_args', func) - compiled_host_fns['get_args'] = func - - # set shared memory size - func_name = operation.name() + '_shared_memory_size' - func = getattr(host_lib, func_name) - setattr(operation, 'shared_memory_capacity', func()) - compiled_host_fns['shared_memory_capacity'] = func() - # set the maximum dynamic shared size - operation.initialize() - - # get extra functions - op_attr.append(param_size) - - if hasattr(operation, "extra_funcs"): - for suffix in operation.extra_funcs: - func_name = operation.name() + '_' + suffix - func = getattr(host_lib, func_name) - setattr(operation, suffix, func) - compiled_host_fns[suffix] = func - op_attr.append(suffix) - - operation_attr.append(op_attr) - self.compiled_cache_host.insert(key, compiled_host_fns) - - for key, operation_name, operation_attr in zip(operation_key, operation_name, operation_attr): - self.insert_operation(key, cubin_image, host_file.name, operation_name, operation_attr) - - -artifact_manager = ArtifactManager() diff --git a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py index 158ff483..5b50c2f7 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py @@ -30,7 +30,6 @@ # ################################################################################################# from pycutlass import * -from pycutlass.library import SubstituteTemplate import cutlass from cuda import cuda from cuda import nvrtc @@ -132,13 +131,15 @@ class ArtifactManager: except: pass + self.nvcc() + self.compiled_cache_device = cutlass.CompileCache() + self.compiled_cache_host = cutlass.CompileCache() + + def nvrtc(self): self.backend = "nvrtc" self.default_compile_options = [ '-std=c++11', '-default-device', ] - self.compiled_cache_device = cutlass.CompileCache() - self.compiled_cache_host = cutlass.CompileCache() - def nvcc(self): self.backend = "nvcc" self.default_compile_options = [ @@ -335,13 +336,14 @@ class ArtifactManager: architectures = [] for operation in operations: if hasattr(operation, "tile_description"): - cc = operation.tile_description.minimum_compute_capability + cc = operation.arch if cc not in architectures: architectures.append(cc) include_paths = [ cuda_install_path + '/include', cutlass_path + '/include', cutlass_path + '/tools/util/include', + cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include' ] compile_options = CompilationOptions( self.default_compile_options, architectures, include_paths) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py index fed535b6..3915d76e 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py @@ -48,6 +48,9 @@ class Conv2dArguments(ArgumentBase): :param operation: the Conv2d operation to take the argument :type operation: :class:`pycutlass.Conv2dOperation` + :param problem_size: the Conv2d problem size + :type problem_size: :class:`cutlass.conv.Conv2dProblemSize` + :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -78,6 +81,7 @@ class Conv2dArguments(ArgumentBase): split_k_mode: 'cutlass.conv.SplitKMode' = cutlass.conv.SplitKMode.Serial, **kwargs) -> None: + self.operation = operation #: convolution kind self.conv_kind: cutlass.conv.Operator = operation.conv_kind self.layout_A: cutlass.layout = operation.A.layout @@ -93,15 +97,12 @@ class Conv2dArguments(ArgumentBase): super().__init__(A, B, C, D, **kwargs) # preprocessing output ops - if "output_op" in kwargs.keys() and \ + + if 'output_op' in kwargs.keys() and \ split_k_mode != cutlass.conv.SplitKMode.Parallel: - self.alpha = kwargs["output_op"].alpha - self.beta = kwargs["output_op"].beta + self.output_op = kwargs['output_op'] else: - self.alpha = 1.0 - self.beta = 0.0 - - self.element_compute = operation.element_epilogue + self.output_op = self.operation.epilogue_type(1.0, 0.0) if "split_k_slices" in kwargs.keys(): self.split_k_mode = split_k_mode @@ -114,7 +115,12 @@ class Conv2dArguments(ArgumentBase): self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size self.problem_size.split_k_slices = self.split_k_slices - self.operation = operation + if hasattr(self, "tensor_c_numel"): + c_coord = cutlass.conv.implicit_gemm_tensor_c_extent( + self.conv_kind, problem_size) + if (self.tensor_c_numel == c_coord.at(3) and + self.tensor_c_numel < c_coord.size()): + self.bias = True # # initialize the argument @@ -159,6 +165,9 @@ class Conv2dArguments(ArgumentBase): self.conv_kind, problem_size) else: raise ValueError("unknown operand: " + operand) + # Zero stride trick + if operand == "c" and self.bias: + tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0) layout = tensor_layout.packed(tensor_coord) @@ -174,24 +183,9 @@ class Conv2dArguments(ArgumentBase): ref_D = TensorRef_(self.get_tensor_ref( self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d")) - if self.element_compute == cutlass.float16: - alpha = cutlass.float16(self.alpha).storage - beta = cutlass.float16(self.beta).storage - elif self.element_compute == cutlass.int32: - alpha = int(self.alpha) - beta = int(self.beta) - else: - alpha = self.alpha - beta = self.beta - - argument_type, epilogue_type = get_conv2d_arguments( - self.operation.element_epilogue) - - output_op = epilogue_type(alpha, beta, 0, 0) - - self.c_arguments = argument_type( + self.c_arguments = self.operation.argument_type( Conv2DProblemSize(self.problem_size), - ref_A, ref_B, ref_C, ref_D, output_op, self.split_k_mode + ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode ) self.semaphore = semaphore @@ -296,9 +290,8 @@ extern "C" { def __init__(self, operation: 'Conv2dOperation'): super().__init__(operation) - - self.argtype = [ctypes.POINTER(get_conv2d_arguments( - operation.element_epilogue)[0]), ctypes.c_void_p] + self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p] self.conv_kind = operation.conv_kind self.operation: Conv2dOperation = operation @@ -410,9 +403,7 @@ class Conv2dOperation: iterator_algorithm: cutlass.conv.IteratorAlgorithm, arch: int, tile_description: TileDescription, A: TensorDescription, B: TensorDescription, C: TensorDescription, - element_epilogue: Union[cutlass.int8, cutlass.int32, cutlass.float16, - cutlass.bfloat16, cutlass.float32, cutlass.float64], - stride_support, epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support, epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1): self.operation_kind: OperationKind = OperationKind.Conv2d @@ -422,13 +413,14 @@ class Conv2dOperation: self.A: TensorDescription = A self.B: TensorDescription = B self.C: TensorDescription = C - self.element_epilogue = element_epilogue self.epilogue_functor = epilogue_functor self.iterator_algorithm = iterator_algorithm self.stride_support = stride_support self.swizzling_functor = swizzling_functor() self.rt_module: Conv2dRT = Conv2dRT(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type def run(self, arguments: Conv2dArguments) -> cuda.CUresult: """ @@ -577,12 +569,7 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, + ${epilogue_functor}, ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, ${stages}, ${math_operator}, @@ -629,8 +616,7 @@ struct ${operation_name}${operation_suffix}: 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), 'epilogue_vector_length': str(epilogue_vector_length), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': operation.epilogue_functor.emit(), 'swizzling_functor': operation.swizzling_functor.tag(), 'stages': str(operation.tile_description.stages), 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], diff --git a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py b/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py index 2eb65797..147fc68c 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py @@ -30,109 +30,997 @@ # ################################################################################ +from ast import Num +from audioop import mul +from pipes import Template import struct +from pycutlass.library import DataTypeTag +from pycutlass import * +import cutlass +from scipy.special import erf +from pycutlass.c_types import MatrixCoord_ +from pycutlass.frontend import NumpyFrontend -def MaxAlignment(fmt): - align = 1 - for x in fmt: - align = max(align, struct.calcsize(x)) - return align +from cuda import cuda +from cuda import cudart +dtype2ctype = { + cutlass.float16: ctypes.c_uint16, + cutlass.float32: ctypes.c_float, + cutlass.float64: ctypes.c_double, + cutlass.int32: ctypes.c_int32 +} -def AlignedOffset(offset, align): - remainder = (offset % align) - if remainder: - offset += (align - remainder) - return offset ################################################################################################# # -# Functors +# Epilogue Functors # ################################################################################################# -# +class EpilogueFunctorBase: + """ + Base class for thread-level epilogue functors + """ + def __init__(self) -> None: + pass + + def emit(self, tag, template_argument): + template = """${tag}<${arguments}>""" + arguments = "" + for idx, arg in enumerate(template_argument): + arguments += arg + if idx < len(template_argument) - 1: + arguments += ", " + values = { + "tag": tag, + "arguments": arguments + } + + return SubstituteTemplate(template, values) + -class Functor: - def __init__(self): - self.decl = '' - self.definition = '' - self.fmt = '' - self.identifier = '' +class LinearCombination(EpilogueFunctorBase): + """ + Apply a linear combination operator to an array of elements + D = alpha * accumulator + beta * source - # - def emit_declaration(self): - return self.decl + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes + when there are not enough data to store - # - def emit_definition(self): - return self.definition + :param element_accumulator: Accumulator data type - # - def size(self): - ''' - Size of the packed Params structure - ''' - return struct.calcsize(self.fmt) - - # - def alignment(self): - return MaxAlignment(self.fmt) - - # - def initialize(self, host_workspace, offset, arguments): - return offset + self.size() - -################################################################################################# - -# - - -class LinearCombinationFunctorArguments: - def __init__(self, alpha=1.0, beta=0.0): - self.alpha = alpha - self.beta = beta - self.alpha_ptr = 0 - self.beta_ptr = 0 - -# - - -class LinearCombinationFunctor(Functor): - def __init__(self): + :param element_epilogue: data type used to compute linear combination + """ + tag = "cutlass::epilogue::thread::LinearCombination" + def __init__( + self, element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: # TODO bind ScaleType super().__init__() - self.decl = """ - cutlass::epilogue::thread::LinearCombination< - float, - 1, - float, - float - >""" - self.identifier = 'linear_combination' - self.fmt = "ffPP" + if element_accumulator is None: + element_accumulator = element_output + if element_epilogue is None: + element_epilogue = element_output + + self.element_output = element_output + self.element_accumulator = element_accumulator + self.element_epilogue = element_epilogue - # - def size(self): - ''' - Size of the packed Params structure - ''' - return struct.calcsize(self.fmt) + self.template_arguments = [ + DataTypeTag[element_output], str(epilogue_vector_length), + DataTypeTag[element_accumulator], DataTypeTag[element_epilogue] + ] - # - def alignment(self): - return MaxAlignment(self.fmt) + # get epilogue output op type + c_element_epilogue = dtype2ctype[self.element_epilogue] + element_epilogue = self.element_epilogue - # - def initialize(self, host_workspace, offset, arguments): + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha_data", ctypes.c_longlong*2), + ("beta_data", ctypes.c_longlong*2), + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + def __init__(self, alpha, beta, *args) -> None: + self.alpha = element_epilogue(alpha).storage + self.beta = element_epilogue(beta).storage + self.epilogue_type = _EpilogueOutputOpParams + + def emit(self): + return super().emit(self.tag, self.template_arguments) - offset = AlignedOffset(offset, self.alignment()) - struct.pack_into( - self.fmt, - host_workspace, offset, - arguments.alpha, arguments.beta, arguments.alpha_ptr, arguments.beta_ptr) +class LinearCombinationClamp(LinearCombination): + """ + Applies a linear combination operator to an array of elements then clamps + the output before converting to the output element type. - return offset + self.size() + D = alpha * accumulator + beta * source + uniform + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes + when there are not enough data to store + + :param element_accumulator: Accumulator data type + + :param element_epilogue: data type used to compute linear combination + """ + tag = "cutlass::epilogue::thread::LinearCombinationClamp" + def __init__( + self, element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: + # Base constructor + super().__init__( + element_output, epilogue_vector_length, + element_accumulator, element_epilogue) + + c_element_epilogue = dtype2ctype[self.element_epilogue] + element_epilogue = self.element_epilogue + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + def __init__(self, alpha, beta, *args) -> None: + self.alpha = element_epilogue(alpha).storage + self.beta = element_epilogue(beta).storage + self.epilogue_type = _EpilogueOutputOpParams + + +class FastLinearCombinationClamp(EpilogueFunctorBase): + """ + Applies a linear combination operator to an array of elements then clamps + the output before converting to the output element type. + + D = alpha * accumulator + beta * source + + Note: The below method only when problem_size_K <= 256 for signed int8 gemm + or problem_size_K <= 128 for unsigned int8 gemm. The default approach is + above. + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes + when there are not enough data to store + """ + tag = "cutlass::epilogue::thread::FastLinearCombinationClamp" + def __init__(self, element_output, epilogue_vector_length, *args) -> None: + super().__init__() + + self.template_arguments = [ + DataTypeTag[element_output], str(epilogue_vector_length) + ] + + self.element_accumulator = cutlass.int32 + self.element_epilogue = cutlass.float32 + + # get epilogue output op + c_element_epilogue = dtype2ctype[self.element_epilogue] + element_epilogue = self.element_epilogue + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + def __init__(self, alpha, beta, *args) -> None: + self.alpha = element_epilogue(alpha).storage + self.beta = element_epilogue(beta).storage + self.epilogue_type = _EpilogueOutputOpParams + + def emit(self): + return super().emit(self.tag, self.template_arguments) + + +class LinearCombinationGeneric(LinearCombination): + """ + Applies a linear combination operator followed by an activation function + to an array of elements. + + D = activation(alpha * accumulator + beta * source) + + :param activation_functor: input activation functor + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes + when there are not enough data to store + + :param element_accumulator: Accumulator data type + + :param element_epilogue: data type used to compute linear combination + """ + tag = "cutlass::epilogue::thread::LinearCombinationGeneric" + def __init__( + self, activation_functor, + element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: + super().__init__( + element_output, epilogue_vector_length, + element_accumulator, element_epilogue) + + self.template_arguments = [ + activation_functor.emit(),] + self.template_arguments + + self.activation_functor = activation_functor + self.element_epilogue = element_epilogue + + # get epilogue output op + self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue) + + +class ActivationFunctor: + """ + Base class for frequently used activation functions + """ + def __init__(self, element_compute) -> None: + pass + @staticmethod + def numpy(x: np.ndarray): + raise NotImplementedError() + + def emit(self): + return self.tag + + @staticmethod + def epilogue_output_op(element_epilogue): + c_element_epilogue = dtype2ctype[element_epilogue] + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + def __init__(self, alpha, beta, *args) -> None: + self.alpha = element_epilogue(alpha).storage + self.beta = element_epilogue(beta).storage + return _EpilogueOutputOpParams + +# identity operator +class identity(ActivationFunctor): + def numpy(x: np.ndarray): + return x + +# ReLu operator, +class relu(ActivationFunctor): + tag = "cutlass::epilogue::thread::ReLu" + + def __init__(self, element_compute): + super().__init__(element_compute) + class _Arguments(ctypes.Structure): + _fields_ = [ + ("threshold", dtype2ctype[element_compute]) + ] + def __init__(self, threshold=0.) -> None: + self.threshold = element_compute(threshold).storage + self.argument_type = _Arguments + + def emit_visitor(self): + return "cutlass::ReLUVisitor" + + @staticmethod + def numpy(x: np.ndarray): + return np.maximum(x, 0) + +# Leaky ReLu operator +class leaky_relu(ActivationFunctor): + tag = "cutlass::epilogue::thread::LeakyReLU" + + def __init__(self, element_compute) -> None: + super().__init__(element_compute) + class _Arguments(ctypes.Structure): + _fields_ = [ + ("leaky_alpha", dtype2ctype[element_compute]) + ] + def __init__(self, leaky_alpha) -> None: + self.leaky_alpha = element_compute(leaky_alpha).storage + self.argument_type = _Arguments + + def emit_visitor(self): + return "cutlass::LeakyReLUVisitor" + + @staticmethod + def numpy(x: np.ndarray, leaky_alpha): + return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha + + def epilogue_output_op(self, element_epilogue): + c_element_epilogue = dtype2ctype[element_epilogue] + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ("leaky_alpha", c_element_epilogue) + ] + def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None: + self.alpha = element_epilogue(alpha).storage + self.beta = element_epilogue(beta).storage + self.alpha_ptr = 0 + self.beta_ptr = 0 + self.leaky_alpha = element_epilogue(leaky_alpha).storage + return _EpilogueOutputOpParams + +# Tanh operator +class tanh(ActivationFunctor): + tag = "cutlass::epilogue::thread::Tanh" + + def __init__(self, element_compute) -> None: + super().__init__(element_compute) + class _Arguments(ctypes.Structure): + _fields_ = [ + ("tmp", ctypes.c_int) + ] + def __init__(self, *args) -> None: + self.tmp = 0 + self.argument_type = _Arguments + + def emit_visitor(self): + return "cutlass::TanhVisitor" + + @staticmethod + def numpy(x: np.ndarray): + return np.tanh(x) + +def sigmoid_op(x: np.ndarray): + return 1. / (1. + np.exp(-x)) + +# Sigmoid operator +class sigmoid(ActivationFunctor): + tag = "cutlass::epilogue::thread::Sigmoid" + + @staticmethod + def numpy(x: np.ndarray): + return sigmoid_op(x) + +# SiLu operator +class silu(ActivationFunctor): + tag = "cutlass::epilogue::thread::SiLu" + + @staticmethod + def numpy(x: np.ndarray): + return x * sigmoid_op(x) + +# Hardswish operator +class hardswish(ActivationFunctor): + tag = "cutlass::epilogue::thread::HardSwish" + + @staticmethod + def numpy(x: np.ndarray): + relu6 = np.minimum(np.maximum(x + 3., 0), 6.) + return x * relu6 / 6. + +# GELU operator +class gelu(ActivationFunctor): + tag = "cutlass::epilogue::thread::GELU" + + @staticmethod + def numpy(x: np.ndarray): + return 0.5 * x * (1 + erf(x / np.sqrt(2.))) + +# reduction operator +def reduction_op(tensor, direction, math, factor): + batch, m, n = tensor.shape + if math == "Add": + if direction == "row": + num_cta_n = (n + factor - 1) // factor + reduction = np.transpose( + np.sum(tensor.reshape(batch, m, num_cta_n, factor), axis=-1), + axes=[0, 2, 1]).flatten() + elif direction == "column": + num_cta_m = (m + factor - 1) // factor + reduction = np.sum( + tensor.reshape(batch, num_cta_m, factor, n), axis=-2).flatten() + else: + raise NotImplementedError + return reduction + else: + raise NotImplementedError + +# # GELU operator implemented using the taylor series approximation +# class GELU_taylor(ActivationFunctor): +# tag = "cutlass::epilogue::thread::GELU_taylor" + +# # Computes backwards pass for GELU operator +# class dGELU(ActivationFunctor): +# tag = "cutlass::epilogue::thread::dGELU" + +################################################################################ +# Epilogue Visitor +################################################################################ + + +class LayerNorm(EpilogueFunctorBase): + """ + Apply a linear combination operator to an array of elements + D = alpha * accumulator + beta * source + + :param element_output: data type used to load and store tensors + + :param epilogue_vector_length: number of elements computed per operation. + Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes + when there are not enough data to store + + :param element_accumulator: Accumulator data type + + :param element_epilogue: data type used to compute linear combination + """ + KernelTemplate = """ + +cutlass::epilogue::threadblock::EpilogueVisitorLayerNorm< + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + ${operation_name}_default::kThreadCount, + ${operation_name}_default::Epilogue::OutputTileIterator, + ${operation_name}_default::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, + ${element_compute}, // element_compute + ${element_variance}, // element_variance + ${element_mean}, // element_mean + ${element_layer_norm_compute}, // element_layer_norm_compute + ${epilogue_functor}, + ${shifted_k}>; +""" + headers = ["gemm/gemm_universal_with_visitor.h", + "epilogue/epilogue_visitor_with_layernorm.h"] + def __init__( + self, elementwise_functor, + element_variance=None, element_mean=None, + element_layer_norm_compute=None, shifted_k=True) -> None: # TODO bind ScaleType + super().__init__() + + self.elementwise_functor = elementwise_functor + self.element_compute = elementwise_functor.element_epilogue + self.element_output = elementwise_functor.element_output + + if element_variance is None: + self.element_variance = self.element_output + if element_mean is None: + self.element_mean = self.element_output + if element_layer_norm_compute is None: + self.element_layer_norm_compute = self.element_compute + if shifted_k: + self.shifted_k = "true" + else: + self.shifted_k = "false" + + # get epilogue output op + elementwise_params_type = self.elementwise_functor.epilogue_type + + class _EpilogueVisitorParams(ctypes.Structure): + _fields_ = [ + ("element_wise", elementwise_params_type), + ("ptr_Variance", ctypes.c_void_p), + ("ptr_Mean_", ctypes.c_void_p), + ("ptr_Shifted_K_", ctypes.c_void_p), + ("extent", MatrixCoord_) + ] + def __init__(self, elementwise_params, variance, mean, shift_k, extent) -> None: + self.element_wise = elementwise_params + if isinstance(variance, np.ndarray): + self.buffer_variance = NumpyFrontend.argument(variance, False) + self.buffer_mean = NumpyFrontend.argument(mean, False) + self.buffer_shift_k = NumpyFrontend.argument(shift_k, False) + self.ptr_Variance = int(self.buffer_variance.ptr) + self.ptr_Mean_ = int(self.buffer_mean.ptr) + self.ptr_Shifted_K_ = int(self.buffer_shift_k.ptr) + self.extent = MatrixCoord_(extent[0], extent[1]) + + self.host_variance = variance + self.host_mean = mean + self.host_shift_k = shift_k + + def sync(self, stream_sync=True): + if stream_sync: + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + # if hasattr(self, "host_variance"): + err, = cuda.cuMemcpyDtoH( + self.host_variance, cuda.CUdeviceptr(self.ptr_Variance), + self.host_variance.size * self.host_variance.itemsize) + err, = cuda.cuMemcpyDtoH( + self.host_mean, cuda.CUdeviceptr(self.ptr_Mean_), + self.host_mean.size * self.host_mean.itemsize) + err, = cuda.cuMemcpyDtoH( + self.host_shift_k, cuda.CUdeviceptr(self.ptr_Shifted_K_), + self.host_shift_k.size * self.host_shift_k.itemsize) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + self.epilogue_type = _EpilogueVisitorParams + + def emit(self, operation): + values = { + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'operation_name': operation.procedural_name(), + 'element_compute': DataTypeTag[self.element_compute], + 'element_variance': DataTypeTag[self.element_variance], + 'element_mean': DataTypeTag[self.element_mean], + 'element_layer_norm_compute': DataTypeTag[self.element_layer_norm_compute], + 'epilogue_functor': self.elementwise_functor.emit(), + 'shifted_k': self.shifted_k + } + return SubstituteTemplate(self.KernelTemplate, values) + + + +class AccumulatorOp: + Template = """ +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpAccumulator<${element_accumulator}, ${elements_per_access}>; +""" + counter = 0 + def __init__(self, element_accumulator, elements_per_access) -> None: + self.element_accumulator = element_accumulator + self.elements_per_access = elements_per_access + + self.instance_name = "AccumulatorOp%d" % AccumulatorOp.counter + AccumulatorOp.counter += 1 + + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("tmp", ctypes.c_int) + ] + def __init__(self): + self.tmp = 0 + + self.argument_type = _Arguments + + def emit(self, *args): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "elements_per_access": str(self.elements_per_access) + } + return SubstituteTemplate(self.Template, values) + + +class LinearCombinationOp: + Template = """ +${visitor_a} + +${visitor_b} + +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpLinearCombination< + ${element_accumulator}, ${element_compute}, + ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_compute, + elements_per_access, visitor_a, visitor_b) -> None: + # + self.element_accumulator = element_accumulator + self.element_compute = element_compute + self.elements_per_access = elements_per_access + self.visitor_a = visitor_a + self.visitor_b = visitor_b + + self.instance_name = "LinearCombinationOp%d" % LinearCombinationOp.counter + LinearCombinationOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("alpha", dtype2ctype[self.element_compute]), + ("beta", dtype2ctype[self.element_compute]), + ("visitor_a", self.visitor_a.argument_type), + ("visitor_b", self.visitor_b.argument_type) + ] + def __init__(self, alpha, beta, visitor_a_arg, visitor_b_arg) -> None: + self.alpha = element_compute(alpha).storage + self.beta = element_compute(beta).storage + self.visitor_a = visitor_a_arg + self.visitor_b = visitor_b_arg + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_compute": DataTypeTag[self.element_compute], + "elements_per_access": str(self.elements_per_access), + "visitor_a_name": self.visitor_a.instance_name, + "visitor_b_name": self.visitor_b.instance_name, + "visitor_a": self.visitor_a.emit(operation), + "visitor_b": self.visitor_b.emit(operation) + } + return SubstituteTemplate(self.Template, values) + +class VectorAdd: + def __init__(self, *args) -> None: + class _Arguments(ctypes.Structure): + _fields_ = [ + ("tmp", ctypes.c_int) + ] + def __init__(self, *args) -> None: + self.tmp = 0 + self.argument_type = _Arguments + + def emit(self): + return "cutlass::VectorAdd" + +class VectorMult: + def __init__(self, *args) -> None: + class _Arguments(ctypes.Structure): + _fields_ = [ + ("tmp", ctypes.c_int) + ] + def __init__(self, *args) -> None: + self.tmp = 0 + self.argument_type = _Arguments + + def emit(self): + return "cutlass::VectorMult" + + +class BinaryOp: + Template = """ +${visitor_a} + +${visitor_b} + +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpBinary< + ${element_accumulator}, ${element_compute}, + ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}, ${binary_op}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_compute, + elements_per_access, visitor_a, visitor_b, binary_op) -> None: + # + self.element_accumulator = element_accumulator + self.element_compute = element_compute + self.elements_per_access = elements_per_access + self.visitor_a = visitor_a + self.visitor_b = visitor_b + self.binary_op = binary_op + + self.instance_name = "BinaryOp%d" % BinaryOp.counter + BinaryOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("binary_param", binary_op.argument_type), + ("visitor_a", self.visitor_a.argument_type), + ("visitor_b", self.visitor_b.argument_type) + ] + def __init__(self, binary_param, visitor_a_arg, visitor_b_arg) -> None: + self.binary_param = binary_param + self.visitor_a = visitor_a_arg + self.visitor_b = visitor_b_arg + + self.argument_type = _Arguments + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_compute": DataTypeTag[self.element_compute], + "elements_per_access": str(self.elements_per_access), + "visitor_a_name": self.visitor_a.instance_name, + "visitor_b_name": self.visitor_b.instance_name, + "visitor_a": self.visitor_a.emit(operation), + "visitor_b": self.visitor_b.emit(operation), + "binary_op": self.binary_op.emit() + } + return SubstituteTemplate(self.Template, values) + + +class Mult: + def __init__(self, element_compute) -> None: + class _Arguments(ctypes.Structure): + _fields_ = [ + ("alpha", dtype2ctype[element_compute]) + ] + def __init__(self, alpha) -> None: + self.alpha = element_compute(alpha).storage + + self.argument_type = _Arguments + + def emit_visitor(self): + return "cutlass::Mult" + +class UnaryOp: + Template = """ +${visitor} + +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpUnary< + ${element_accumulator}, ${element_compute}, + ${elements_per_access}, ${visitor_name}, ${unary_op}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_compute, + elements_per_access, visitor, unary_op) -> None: + # + self.element_accumulator = element_accumulator + self.element_compute = element_compute + self.elements_per_access = elements_per_access + self.visitor = visitor + self.unary_op = unary_op + + self.instance_name = "UnaryOp%d" % UnaryOp.counter + UnaryOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("unary_param", unary_op.argument_type), + ("visitor_arg", self.visitor.argument_type) + ] + def __init__(self, unary_param, visitor_arg) -> None: + self.unary_param = unary_param + self.visitor_arg = visitor_arg + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_compute": DataTypeTag[self.element_compute], + "elements_per_access": str(self.elements_per_access), + "visitor_name": self.visitor.instance_name, + "unary_op": self.unary_op.emit_visitor(), + "visitor": self.visitor.emit(operation) + } + return SubstituteTemplate(self.Template, values) + + + +class RowBroadcastOp: + Template = """ +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowBroadcast< + ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_fragment) -> None: + self.element_accumulator = element_accumulator + self.element_fragment = element_fragment + + self.instance_name = "RowBroadcastOp%d" % RowBroadcastOp.counter + RowBroadcastOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("broadcast_ptr", ctypes.c_void_p), + ("batch_stride", ctypes.c_longlong) + ] + def __init__(self, broadcast_ptr, batch_stride=0): + self.broadcast_ptr = int(broadcast_ptr) + self.batch_stride = batch_stride + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_fragment": DataTypeTag[self.element_fragment], + "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator" + } + return SubstituteTemplate(self.Template, values) + + +class ColumnBroadcastOp: + Template = """ +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpColumnBroadcast< + ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_fragment) -> None: + self.element_accumulator = element_accumulator + self.element_fragment = element_fragment + + self.instance_name = "ColumnBroadcastOp%d" % ColumnBroadcastOp.counter + ColumnBroadcastOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("broadcast_ptr", ctypes.c_void_p), + ("batch_stride", ctypes.c_longlong) + ] + def __init__(self, broadcast_ptr, batch_stride=0): + self.broadcast_ptr = int(broadcast_ptr) + self.batch_stride = batch_stride + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_fragment": DataTypeTag[self.element_fragment], + "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator" + } + return SubstituteTemplate(self.Template, values) + + +class TensorInputOp: + Template = """ +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpTensorInput< + ${element_accumulator}, ${input_tile_iterator}>; +""" + counter = 0 + def __init__(self, element_accumulator) -> None: + self.element_accumulator = element_accumulator + + self.instance_name = "TensorInputOp%d" % TensorInputOp.counter + TensorInputOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("input_ptr", ctypes.c_void_p), + ("ldt", ctypes.c_int), + ("batch_stride", ctypes.c_longlong) + ] + def __init__(self, input_ptr, ldt, batch_stride=0) -> None: + self.input_ptr = int(input_ptr) + self.ldt = ldt + self.batch_stride = batch_stride + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "input_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator" + } + return SubstituteTemplate(self.Template, values) + +class TensorOutputOp: + Template = """ +${visitor} + +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpTensorOutput< + ${element_accumulator}, ${output_tile_iterator}, ${visitor_name}>; +""" + counter = 0 + def __init__(self, element_accumulator, visitor) -> None: + self.element_accumulator = element_accumulator + self.visitor = visitor + + self.instance_name = "TensorOutputOp%d" % TensorOutputOp.counter + TensorOutputOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("output_ptr", ctypes.c_void_p), + ("ldt", ctypes.c_int), + ("batch_stride", ctypes.c_longlong), + ("visitor_arg", self.visitor.argument_type) + ] + def __init__(self, output_ptr, ldt, visitor_arg, batch_stride=0) -> None: + self.output_ptr = int(output_ptr) + self.ldt = int(ldt) + self.visitor_arg = visitor_arg + self.batch_stride = batch_stride + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + "element_accumulator": DataTypeTag[self.element_accumulator], + "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator", + "visitor_name": self.visitor.instance_name, + "visitor": self.visitor.emit(operation) + } + return SubstituteTemplate(self.Template, values) + + +class ColumnReductionOp: + Template = """ +${visitor} + +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpColumnReduction< + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + ${element_accumulator}, ${element_reduction}, ${element_reduction_accumulator}, + ${output_tile_iterator}, ${visitor_name}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_reduction, + element_reduction_accumulator, visitor) -> None: + self.element_accumulator = element_accumulator + self.element_reduction = element_reduction + self.element_reduction_accumulator = element_reduction_accumulator + self.visitor = visitor + + self.instance_name = "ColumnReductionOp%d" % ColumnReductionOp.counter + ColumnReductionOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("reduction_ptr", ctypes.c_void_p), + ("batch_stride", ctypes.c_longlong), + ("visitor_arg", self.visitor.argument_type) + ] + def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None: + self.reduction_ptr = reduction_ptr + self.batch_stride = batch_stride + self.visitor_arg = visitor_arg + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_reduction": DataTypeTag[self.element_reduction], + "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator], + "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator", + "visitor_name": self.visitor.instance_name, + "visitor": self.visitor.emit(operation) + } + return SubstituteTemplate(self.Template, values) + + +class RowReductionOp: + Template = """ +${visitor} + +using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowReduction< + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + ${element_accumulator}, ${element_reduction}, ${element_reduction_accumulator}, + ${output_tile_iterator}, ${visitor_name}>; +""" + counter = 0 + def __init__(self, element_accumulator, element_reduction, + element_reduction_accumulator, visitor) -> None: + self.element_accumulator = element_accumulator + self.element_reduction = element_reduction + self.element_reduction_accumulator = element_reduction_accumulator + self.visitor = visitor + + self.instance_name = "RowReductionOp%d" % RowReductionOp.counter + RowReductionOp.counter += 1 + + class _Arguments(ctypes.Structure): + _fields_ = [ + ("reduction_ptr", ctypes.c_void_p), + ("batch_stride", ctypes.c_longlong), + ("visitor_arg", self.visitor.argument_type) + ] + def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None: + self.reduction_ptr = reduction_ptr + self.visitor_arg = visitor_arg + self.batch_stride = batch_stride + + self.argument_type = _Arguments + + def emit(self, operation): + values = { + "instance_name": self.instance_name, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + "element_accumulator": DataTypeTag[self.element_accumulator], + "element_reduction": DataTypeTag[self.element_reduction], + "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator], + "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator", + "visitor_name": self.visitor.instance_name, + "visitor": self.visitor.emit(operation) + } + return SubstituteTemplate(self.Template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py index 4361d7ea..51d30ed4 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py @@ -116,6 +116,12 @@ class GemmArguments(ArgumentBase): else: self.problem_size = cutlass.gemm.GemmCoord( problem_size.m(), problem_size.n(), problem_size.k()) + + # if the number of elements in C = problem_size.n + # C is treated as the bias + if hasattr(self, "tensor_c_numel"): + if (self.tensor_c_numel == self.problem_size.n() and + self.problem_size.m() != 1): self.bias = True # get the leading dimension self.lda = operation.A.layout.packed(self.problem_size.mk()).stride() @@ -123,27 +129,69 @@ class GemmArguments(ArgumentBase): self.ldc = operation.C.layout.packed(self.problem_size.mn()).stride() self.ldd = self.ldc + # stride 0 trick + if self.bias: + self.ldc = 0 + if 'output_op' in kwargs.keys() and \ gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel: - self.alpha = kwargs['output_op'].alpha - self.beta = kwargs['output_op'].beta + self.output_op = kwargs['output_op'] else: - self.alpha = 1.0 - self.beta = 0.0 + self.output_op = self.operation.epilogue_type(1.0, 0.0) # get number of slices on k dimension self.gemm_mode = gemm_mode - if 'split_k_slices' in kwargs.keys(): - self.split_k_slices = kwargs['split_k_slices'] - else: - self.split_k_slices = 1 + if gemm_mode in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]: + if 'split_k_slices' in kwargs.keys(): + self.batch_count = kwargs['split_k_slices'] + else: + self.batch_count = 1 + self.split_k_slices = self.batch_count - self.batch_count = self.split_k_slices + if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]: + if 'batch' in kwargs.keys(): + self.batch_count = kwargs['batch'] + else: + self.batch_count = 1 self.batched_stride_A = self.problem_size.m() * self.problem_size.k() self.batched_stride_B = self.problem_size.n() * self.problem_size.k() self.batched_stride_C = self.problem_size.m() * self.problem_size.n() self.batched_stride_D = self.problem_size.m() * self.problem_size.n() + if self.bias: + self.batched_stride_C = self.problem_size.n() + + # support GEMM Mode Array + if gemm_mode == cutlass.gemm.Mode.Array: + self.ptr_A_array = [] + self.ptr_B_array = [] + self.ptr_C_array = [] + self.ptr_D_array = [] + + ptr_A_addr = int(self.ptr_A) + ptr_B_addr = int(self.ptr_B) + ptr_C_addr = int(self.ptr_C) + ptr_D_addr = int(self.ptr_D) + + stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8 + stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8 + stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8 + stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8 + for _ in range(self.batch_count): + self.ptr_A_array.append(ptr_A_addr) + self.ptr_B_array.append(ptr_B_addr) + self.ptr_C_array.append(ptr_C_addr) + self.ptr_D_array.append(ptr_D_addr) + + ptr_A_addr += stride_A + ptr_B_addr += stride_B + ptr_C_addr += stride_C + ptr_D_addr += stride_D + + self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64) + self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64) + self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64) + self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64) if isinstance(self.operation, GemmOperationUniversal): self.initialize() @@ -195,28 +243,28 @@ class GemmArguments(ArgumentBase): self.grid_tiled_shape.z ) ) - - argument_type, epilogue_type = get_gemm_arguments( - self.operation.element_epilogue) - - if self.operation.element_epilogue == cutlass.float16: - self.alpha = cutlass.float16(self.alpha).storage - self.beta = cutlass.float16(self.beta).storage - elif self.operation.element_epilogue == cutlass.int32: - self.alpha = int(self.alpha) - self.beta = int(self.beta) - - output_op = epilogue_type(self.alpha, self.beta, 0, 0) - - arguments = argument_type( - self.gemm_mode, problem_size_, self.batch_count, output_op, - int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D), - self.batched_stride_A, self.batched_stride_B, self.batched_stride_C, - self.batched_stride_D, - self.lda, self.ldb, self.ldc, self.ldd, - self.lda, self.ldb, self.ldc, self.ldd, - 0, 0, 0 - ) + if self.gemm_mode == cutlass.gemm.Mode.Array: + arguments = self.operation.argument_type( + self.gemm_mode, problem_size_, self.batch_count, self.output_op, + int(self.ptr_A_array_buffer.ptr), + int(self.ptr_B_array_buffer.ptr), + int(self.ptr_C_array_buffer.ptr), + int(self.ptr_D_array_buffer.ptr), + 0, 0, 0, 0, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) + else: + arguments = self.operation.argument_type( + self.gemm_mode, problem_size_, self.batch_count, self.output_op, + int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D), + self.batched_stride_A, self.batched_stride_B, self.batched_stride_C, + self.batched_stride_D, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size @@ -381,13 +429,12 @@ class GemmGroupedArguments: else: self.alpha = 1.0 self.beta = 0.0 + + if 'output_op' in kwargs.keys(): + self.output_op = kwargs['output_op'] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) - if self.operation.element_epilogue == cutlass.float16: - self.alpha = cutlass.float16(self.alpha).storage - self.beta = cutlass.float16(self.beta).storage - elif self.operation.element_epilogue == cutlass.int32: - self.alpha = int(self.alpha) - self.beta = int(self.beta) # get host problem size self.host_problem_size_ptr = np.array( @@ -398,12 +445,7 @@ class GemmGroupedArguments: self.initialize() def get_arguments(self): - - argument_type, epilogue_type = get_gemm_grouped_arguments( - self.operation.element_epilogue) - self.output_op = epilogue_type(self.alpha, self.beta, 0, 0) - - return argument_type( + return self.operation.argument_type( self.problem_size_buffer.ptr, self.problem_count, self.total_tiles, self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr, self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr, @@ -492,16 +534,6 @@ ${operation_name}(${operation_name}${operation_suffix}::Params params) { #: number of threads per threadblock self.threads: int = operation.tile_description.num_threads - if (operation.epilogue_functor in - [ - EpilogueFunctor.LinearCombination, - EpilogueFunctor.FastLinearCombinationClamp, - EpilogueFunctor.LinearCombinationClamp - ]): - self.output_op = LinearCombinationFunctor() - else: - raise ValueError("unknown epilogue functor") - # def emit(self): return self.emitter.emit(self.operation) @@ -568,9 +600,11 @@ extern "C" { def __init__(self, operation: 'GemmOperation'): super(GemmRTUniversal, self).__init__(operation) self.emitter = EmitGemmUniversalInstance( - '_type', operation.direct_store) + '_type', operation.direct_store, operation.visitor) + + self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor) self.argtype = [ - ctypes.POINTER(get_gemm_arguments(operation.element_epilogue)[0]), + ctypes.POINTER(self.argument_type), ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p ] @@ -673,8 +707,8 @@ class GemmRTGrouped(GemmRTbase): self.extra_funcs = ['precompute'] self.emitter = EmitGemmGroupedInstance('_type') - self.argtype = [ctypes.POINTER(get_gemm_grouped_arguments( - operation.element_epilogue)[0]), ctypes.c_int, ctypes.c_void_p] + self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p] def host_precompute(self, arguments, workspace_bytes): self.precompute.argtype = [ @@ -717,7 +751,7 @@ class GemmOperationBase: def __init__( self, gemm_kind, arch, tile_description: TileDescription, A: TensorDescription, B: TensorDescription, C: TensorDescription, - element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination, + epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): #: operation kind @@ -749,7 +783,7 @@ class GemmOperationBase: #: Operand C self.C: TensorDescription = copy.deepcopy(C) self.switched = False - self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor self.swizzling_functor = swizzling_functor() @@ -757,6 +791,11 @@ class GemmOperationBase: self.direct_store = kwargs["direct_store"] else: self.direct_store = False + + if "visitor" in kwargs: + self.visitor = kwargs["visitor"] + else: + self.visitor = False def run(self, arguments: GemmArguments) -> cuda.CUresult: """ @@ -895,22 +934,26 @@ class GemmOperationBase: class GemmOperationUniversal(GemmOperationBase): - def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue, - epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, + epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, - A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs) + A, B, C, epilogue_functor, swizzling_functor, **kwargs) self.rt_module = GemmRTUniversal(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type class GemmOperationGrouped(GemmOperationBase): - def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue, - epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, + epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description, - A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs) + A, B, C, epilogue_functor, swizzling_functor, **kwargs) assert "precompute_mode" in kwargs.keys( ), "missing keyword arguement 'precompute_mode'." self.precompute_mode = kwargs["precompute_mode"] self.rt_module = GemmRTGrouped(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type ################################################################################################### # @@ -918,228 +961,14 @@ class GemmOperationGrouped(GemmOperationBase): # ################################################################################################### -# - - -class EmitGemmInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix=''): - self.operation_suffix = operation_suffix - self.includes = [] - self.gemm_template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::Gemm< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - false, - ${math_operation} - ${residual} - >; -""" - self.gemm_complex_template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${transform_a}, - ${transform_b}, - ${math_operation} - ${residual} - >; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // - operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min( - operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - residual = '' - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': operation.swizzling_functor.tag(), - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'residual': residual - } - - template = self.gemm_complex_template if operation.is_complex() else self.gemm_template - - return SubstituteTemplate(template, values) - -################################################################################################### - - -class EmitSparseGemmInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix=''): - self.operation_suffix = operation_suffix - self.includes = [] - self.gemm_template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, - ${swizzling_functor}, - ${stages}, - ${align_a}, - ${align_b}, - false, - ${math_operation} - ${residual} - >; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // - operation.tile_description.warp_count[idx] for idx in range(3)] - - epilogue_vector_length = int(min( - operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) - - residual = '' - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - 'swizzling_functor': operation.swizzling_functor.tag(), - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'residual': residual - } - - template = self.gemm_template - - return SubstituteTemplate(template, values) - -################################################################################################### - - # class EmitGemmUniversalInstance: ''' Responsible for emitting a CUTLASS template definition''' - def __init__(self, operation_suffix='', direct_store=False): + def __init__(self, operation_suffix='', direct_store=False, visitor=False): self.operation_suffix = operation_suffix self.direct_store = direct_store + self.visitor = visitor self.includes = [ "cutlass/cutlass.h", "cutlass/numeric_types.h", @@ -1150,46 +979,15 @@ class EmitGemmUniversalInstance: "cutlass/gemm/device/gemm_universal_adapter.h", "cutlass/gemm/kernel/default_gemm_universal.h", ] + if self.visitor: + self.includes += [ + "gemm/gemm_universal_with_visitor.h", + "epilogue/epilogue_visitor_with_layernorm.h", + "epilogue/epilogue_visitor_generic.h" + ] if self.direct_store: self.includes.append( "cutlass/epilogue/threadblock/default_epilogue_direct_store.h") - self.builtin_epilogue_functor_template = """ - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - > -""" - self.builtin_epilogue_functor_template_clamp = """ - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length} - > -""" - self.gemm_template = """ -// Gemm operator ${operation_name} -using ${operation_name}_base = - typename cutlass::gemm::kernel::DefaultGemmUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, - ${element_c}, ${layout_c}, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}, - ${swizzling_functor}, - ${stages}, - ${math_operation} ->::GemmKernel; - -// Define named type -struct ${operation_name}${operation_suffix} : - public ${operation_name}_base { }; -""" self.gemm_template_interleaved = """ // Gemm operator ${operation_name} using ${operation_name}_base = @@ -1241,6 +1039,42 @@ using ${operation_name}_base = ${operation_name}_default::ThreadblockSwizzle >; +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_visitor = """ +// Gemm operator ${operation_name} +using ${operation_name}_default = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${elementwise_epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +${epilogue_visitor} + +using ${operation_name}_Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + ${operation_name}_EpilogueVisitor, + typename ${operation_name}_default::Epilogue>::Epilogue; + +using ${operation_name}_base = + cutlass::gemm::kernel::GemmUniversalwithEpilogueVisitor< + ${operation_name}_default::Mma, + ${operation_name}_Epilogue, + ${operation_name}_default::ThreadblockSwizzle + >; + // Define named type struct ${operation_name}${operation_suffix} : public ${operation_name}_base { }; @@ -1284,32 +1118,12 @@ ${compile_guard_end} (operation.A.layout, operation.B.layout, operation.C.layout) if self.direct_store: gemm_template = self.gemm_template_direct_store + elif self.visitor: + gemm_template = self.gemm_template_visitor else: gemm_template = self.gemm_template_interleaved # - # Support built-in epilogue functors or user-defined functions - if isinstance(operation.epilogue_functor, enum.Enum): - - epilogue_vector_length = \ - min(operation.C.alignment * - DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] - - values = { - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - } - if operation.epilogue_functor == EpilogueFunctor.FastLinearCombinationClamp: - epilogue_functor = SubstituteTemplate( - self.builtin_epilogue_functor_template_clamp, values) - else: - epilogue_functor = SubstituteTemplate( - self.builtin_epilogue_functor_template, values) - else: - epilogue_functor = self.epilogue_functor.emit_declaration() - # - values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, @@ -1331,7 +1145,6 @@ ${compile_guard_end} 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_functor': epilogue_functor, 'swizzling_functor': operation.swizzling_functor.tag(), 'stages': str(operation.tile_description.stages), 'align_a': str(operation.A.alignment), @@ -1341,6 +1154,12 @@ ${compile_guard_end} 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] } + if self.visitor: + values['epilogue_visitor'] = operation.epilogue_functor.emit(operation) + values['elementwise_epilogue_functor'] = operation.epilogue_functor.elementwise_functor.emit() + else: + values['epilogue_functor'] = operation.epilogue_functor.emit() + return SubstituteTemplate(gemm_template, values) ################################################################################################### @@ -1348,185 +1167,6 @@ ${compile_guard_end} # -class EmitGemmPlanarComplexInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix=''): - self.operation_suffix = operation_suffix - self.includes = [] - self.template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, - ${element_c}, cutlass::layout::RowMajor, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombinationPlanarComplex< - ${element_c}, - ${alignment_c}, - ${element_accumulator}, - ${element_epilogue} - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - ${stages}, - ${math_operator} - >::GemmKernel; - - struct ${operation_name} : - public Operation_${operation_name} { }; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // - operation.tile_description.warp_count[idx] for idx in range(3)] - - # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major - transposed_layout_A = TransposedLayout[operation.A.layout] - transposed_layout_B = TransposedLayout[operation.B.layout] - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.B.element], - 'layout_a': LayoutTag[transposed_layout_B], - 'transform_a': ComplexTransformTag[operation.B.complex_transform], - 'alignment_a': str(operation.B.alignment), - 'element_b': DataTypeTag[operation.A.element], - 'layout_b': LayoutTag[transposed_layout_A], - 'transform_b': ComplexTransformTag[operation.A.complex_transform], - 'alignment_b': str(operation.A.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'alignment_c': str(operation.C.alignment), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages), - 'math_operator': 'cutlass::arch::OpMultiplyAdd' - } - - return SubstituteTemplate(self.template, values) - -################################################################################################### - -# - - -class EmitGemmPlanarComplexArrayInstance: - ''' Responsible for emitting a CUTLASS template definition''' - - def __init__(self, operation_suffix=''): - self.operation_suffix = operation_suffix - self.includes = [] - self.template = """ - // Gemm operator ${operation_name} - using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< - ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, - ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, - ${element_c}, cutlass::layout::RowMajor, - ${element_accumulator}, - ${opcode_class}, - ${arch}, - cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, - cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, - cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombinationPlanarComplex< - ${element_c}, - ${alignment_c}, - ${element_accumulator}, - ${element_epilogue} - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - ${stages}, - ${math_operator} - >::GemmArrayKernel; - - struct ${operation_name} : public Operation_${operation_name} { }; -""" - - # - def instance_template(self): - return """ -${compile_guard_start} - manifest.append(new ${gemm_kind}< - cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> - >("${operation_name}")); -${compile_guard_end} -""" - - # - def emit(self, operation): - - warp_shape = [operation.tile_description.threadblock_shape[idx] // - operation.tile_description.warp_count[idx] for idx in range(3)] - - # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major - transposed_layout_A = TransposedLayout[operation.A.layout] - transposed_layout_B = TransposedLayout[operation.B.layout] - - values = { - 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.B.element], - 'layout_a': LayoutTag[transposed_layout_B], - 'transform_a': ComplexTransformTag[operation.B.complex_transform], - 'alignment_a': str(operation.B.alignment), - 'element_b': DataTypeTag[operation.A.element], - 'layout_b': LayoutTag[transposed_layout_A], - 'transform_b': ComplexTransformTag[operation.A.complex_transform], - 'alignment_b': str(operation.A.alignment), - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'alignment_c': str(operation.C.alignment), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': str(operation.tile_description.stages), - 'math_operator': 'cutlass::arch::OpMultiplyAdd' - } - - return SubstituteTemplate(self.template, values) - -################################################################################################### - -# - - class EmitGemmGroupedInstance: ''' Responsible for emitting a CUTLASS template definition''' @@ -1541,14 +1181,6 @@ class EmitGemmGroupedInstance: "cutlass/gemm/kernel/gemm_grouped.h", "cutlass/gemm/kernel/default_gemm_grouped.h" ] - self.builtin_epilogue_functor_template = """ - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - > -""" self.gemm_template = """ // Gemm operator ${operation_name} using ${operation_name}_base = @@ -1598,23 +1230,8 @@ ${compile_guard_end} # # Support built-in epilogue functors or user-defined functions - if isinstance(operation.epilogue_functor, enum.Enum): - - epilogue_vector_length = \ - min(operation.C.alignment * - DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] - - values = { - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], - } - epilogue_functor = SubstituteTemplate( - self.builtin_epilogue_functor_template, values) - else: - epilogue_functor = self.epilogue_functor.emit_declaration() - # - + epilogue_functor = operation.epilogue_functor.emit() + values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, diff --git a/tools/library/scripts/pycutlass/src/pycutlass/library.py b/tools/library/scripts/pycutlass/src/pycutlass/library.py index 3ba16752..61b59a6c 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/library.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/library.py @@ -478,27 +478,6 @@ SharedMemPerCC = { ################################################################################################### -# - - -def SubstituteTemplate(template, values): - text = template - changed = True - while changed: - changed = False - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - newtext = re.sub(regex, value, text) - if newtext != text: - changed = True - text = newtext - return text - -################################################################################################### - -# - - class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() @@ -557,22 +536,6 @@ SymmKindNames = { # -class EpilogueFunctor(enum.Enum): - LinearCombination = enum_auto() - LinearCombinationClamp = enum_auto() - FastLinearCombinationClamp = enum_auto() - - -# -EpilogueFunctorTag = { - EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', - EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', - EpilogueFunctor.FastLinearCombinationClamp: 'cutlass::epilogue::thread::FastLinearCombinationClamp' -} - -# - - class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() Identity2 = enum_auto() @@ -700,7 +663,7 @@ class MathInstruction: class TileDescription: - def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute): + def __init__(self, threadblock_shape, stages, warp_count, math_instruction): self.threadblock_shape = threadblock_shape #: number of pipeline stages @@ -710,11 +673,6 @@ class TileDescription: self.warp_count: list[int] = warp_count self.math_instruction = math_instruction - #: minimum compute capability - self.minimum_compute_capability: int = min_compute - #: maximum compute capability - self.maximum_compute_capability: int = max_compute - #: number threads per threadblock self.num_threads: int = 32 for cnt in self.warp_count: diff --git a/tools/library/scripts/pycutlass/src/pycutlass/parser.py b/tools/library/scripts/pycutlass/src/pycutlass/parser.py new file mode 100644 index 00000000..744149e8 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/parser.py @@ -0,0 +1,619 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +from typing import Generic, TypeVar +from treelib import Tree +import numpy as np + +from pycutlass import * +import pycutlass + +import ast +import textwrap +import inspect + +################################################################################ +# Type annotation for input arguments +################################################################################ + +Ttype = TypeVar("Ttype") +Dtype = TypeVar("Dtype") + +class NDArray(np.ndarray, Generic[Ttype, Dtype]): + pass + +################################################################################ +# Operations +################################################################################ + +operators = { + ast.Add: "Add", + ast.Div: "Div", + ast.Eq: "Equal", + ast.Mult: "Mult" +} + +################################################################################ +# AST Node abstractions +################################################################################ +class UnaryNode: + cnt = 0 + # Concept: this is created by the BinOp Node in python ast + def __init__(self, + element_accumulator, element_compute, elements_per_access, + node, args) -> None: + if isinstance(node, BinOpNode): + self.op = node.op + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + self.op = node.func.id + elif isinstance(node.func, ast.Attribute): + self.op = node.func.value.id + else: + raise TypeError + else: + raise TypeError + self.tag = "Unary" + self.op + str(UnaryNode.cnt) + self.id = self.op + str(UnaryNode.cnt) + self.args = args + UnaryNode.cnt += 1 + + self.type = "tensor" + + self.epilogue_op = getattr(pycutlass, self.op)(element_compute) + + # data types + self.element_accumulator = element_accumulator + self.element_compute = element_compute + self.elements_per_access = elements_per_access + + def get_epilogue_node(self, visitors): + self.epilogue_node = UnaryOp( + self.element_accumulator, self.element_compute, + self.elements_per_access, *visitors, self.epilogue_op) + + def get_argument(self, visitor_args, kwargs): + epilogue_ops = [] + for arg in self.args: + try: + epilogue_ops.append(kwargs[arg]) + except: + epilogue_ops.append(arg) # direct arguments like constant + self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(*epilogue_ops), *visitor_args) + + +class BinOpNode: + cnt = 0 + # Concept: this is created by the BinOp Node in python ast + def __init__(self, + element_accumulator, element_compute, elements_per_access, + node) -> None: + self.op = operators[type(node.op)] + self.tag = "Binary" + self.op + str(BinOpNode.cnt) + self.id = self.op + str(BinOpNode.cnt) + self.args = None + BinOpNode.cnt += 1 + + self.type = "tensor" + + self.epilogue_op = getattr(pycutlass, "Vector"+self.op)(element_compute) + + # data types + self.element_accumulator = element_accumulator + self.element_compute = element_compute + self.elements_per_access = elements_per_access + + def get_epilogue_node(self, visitors): + self.epilogue_node = BinaryOp( + self.element_accumulator, self.element_compute, + self.elements_per_access, *visitors, self.epilogue_op) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(self.args), *visitor_args) + + +class NameNode: + # Concept: this is created by the Name Node in python ast + def __init__(self, node) -> None: + try: + self.id = node.id + except: + self.id = node.targets[0].id + self.tag = self.id + +class ScalarInputNode(NameNode): + # Concept: scalar + def __init__(self, node) -> None: + super().__init__(node) + self.tag = "Scalar:" + self.tag + self.type = "scalar" + +class AccumulatorNode(NameNode): + # Concept: VisitorOpAccumulator + def __init__(self, + element_accumulator, elements_per_access, node) -> None: + super().__init__(node) + self.tag = "Accum:" + self.tag + self.type = "tensor" + + self.element_accumulator = element_accumulator + self.elements_per_access = elements_per_access + + def get_epilogue_node(self, visitors): + self.epilogue_node = AccumulatorOp( + self.element_accumulator, self.elements_per_access) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type() + +class TensorInputNode(NameNode): + # Concept: VisitorOpTensorInput + def __init__(self, element_accumulator, node) -> None: + super().__init__(node) + self.tag = "TensorInput:" + self.tag + self.type = "tensor" + self.element_accumulator = element_accumulator + + def get_epilogue_node(self, *args): + self.epilogue_node = TensorInputOp(self.element_accumulator) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type( + kwargs[self.id + "_ptr"], kwargs["problem_size"][1], + kwargs["problem_size"][0] * kwargs["problem_size"][1]) + +class RowBroadcastNode(NameNode): + # Concept: VisitorOpRowBroadcast + def __init__(self, element_accumulator, element_fragment, node) -> None: + super().__init__(node) + # + self.tag = "RowBroadcast:" + self.tag + self.type = "tensor" + self.element_accumulator = element_accumulator + self.element_fragment = element_fragment + + def get_epilogue_node(self, *args): + self.epilogue_node = RowBroadcastOp( + self.element_accumulator, self.element_fragment) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1]) + +class ColumnBroadcastNode(NameNode): + # Concept: VisitorOpColumnBroadcast + def __init__(self, element_accumulator, element_fragment, node) -> None: + super().__init__(node) + self.tag = "ColumnBroadcast:" + self.tag + self.type = "tensor" + self.element_accumulator = element_accumulator + self.element_fragment = element_fragment + + def get_epilogue_node(self, *args): + self.epilogue_node = ColumnBroadcastOp( + self.element_accumulator, self.element_fragment) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][0]) + +class TensorOutputNode(NameNode): + # Concept: VisitorOpTensorOutput + def __init__(self, element_accumulator, node) -> None: + super().__init__(node) + self.tag = "TensorOutput:" + self.tag + self.type = "tensor" + self.element_accumulator = element_accumulator + + def get_epilogue_node(self, visitors): + self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1], *visitor_args, kwargs["problem_size"][0] * kwargs["problem_size"][1]) + +class RowReductionNode: + # Concept: RowReductionOp + def __init__(self, element_accumulator, element_reduction, + element_reduction_accumulator, id, factor) -> None: + # + self.id = id + self.tag = "RowReduction:" + self.id + self.type = "tensor" + self.element_accumulator = element_accumulator + self.element_reduction = element_reduction + self.element_reduction_accumulator = element_reduction_accumulator + self.factor = factor + + def get_epilogue_node(self, visitors): + self.epilogue_node = RowReductionOp( + self.element_accumulator, self.element_reduction, + self.element_reduction_accumulator, *visitors) + + def get_batch_stride(self, problem_size): + return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], *visitor_args, self.get_batch_stride(kwargs["problem_size"])) + +class ColumnReductionNode: + # Concept: ColumnReductionOp + def __init__(self, element_accumulator, element_reduction, + element_reduction_accumulator, id, factor) -> None: + # + self.id = id + self.tag = "ColumnReduction:" + self.id + self.type = "tensor" + self.element_accumulator = element_accumulator + self.element_reduction = element_reduction + self.element_reduction_accumulator = element_reduction_accumulator + self.factor = factor + + def get_epilogue_node(self, visitors): + self.epilogue_node = ColumnReductionOp( + self.element_accumulator, self.element_reduction, + self.element_reduction_accumulator, *visitors) + + def get_batch_stride(self, problem_size): + return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor) + + def get_argument(self, visitor_args, kwargs): + self.argument = self.epilogue_node.argument_type(kwargs[self.id + '_ptr'], *visitor_args, self.get_batch_stride(kwargs["problem_size"])) + +################################################################################ +# Epilogue parser function +################################################################################ +class EpilogueAST(ast.NodeVisitor): + def __init__(self, epilogue, + tile_description, + element_accumulator, elements_per_access, + element_compute, element_output) -> None: + # + + self.tile_description = tile_description + self.element_accumulator = element_accumulator + self.elements_per_access = elements_per_access + self.element_compute = element_compute + self.element_output = element_output + self.epilogue = epilogue + + self.source = textwrap.dedent(inspect.getsource(epilogue.__call__)) + self.ast_tree = ast.parse(self.source) + self.epilogue_tree = Tree() + + + # print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose + + # input arguments + self.input_args = {} + # return nodes + self.returns = [] + # reduction source nodes + self.reduction_source = {} + + # stack used to keep the parent node id + self.stack = [] + + # visit the AST + self.visit(self.ast_tree) + + # visit the name node + def visit_Name(self, node): + # append the return ids into self.returns + if self.stack[-1] == "return": + self.returns.append(node.id) + else: + # accum is produced from accumulator node + if node.id == "accum": + name_node = AccumulatorNode( + self.element_accumulator, self.elements_per_access, node) + else: + # for input nodes + if node.id in self.input_args.keys(): + type = self.input_args[node.id][0] + if type == "tensor": + name_node = TensorInputNode(self.element_accumulator, node) + elif type == "row": + name_node = RowBroadcastNode(self.element_accumulator, self.element_compute, node) + elif type == "column": + name_node = ColumnBroadcastNode(self.element_accumulator, self.element_compute, node) + elif type == "scalar": + name_node = ScalarInputNode(node) + else: + raise ValueError(type) + # for output nodes + else: + name_node = TensorOutputNode(self.element_accumulator, node) + self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node, parent=self.stack[-1]) + + def visit_Assign(self, node): + pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id) + if pre_assign_node is None: + # The assign is to a root node + # skip the reduction nodes + if isinstance(node.value, ast.Call): + if isinstance(node.value.func, ast.Name): + func_type = node.value.func.id + elif isinstance(node.value.func, ast.Attribute): + func_type = node.value.func.value.id + else: + raise TypeError + if func_type == 'reduction_op': + self.reduction_source[node.value.args[0].id] = [node.value.args[1].value, node.value.args[2].value, node.targets[0].id] + return + name_node = TensorOutputNode(self.element_accumulator, node) + self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node) + self.stack.append(name_node.id) + else: + if node.targets[0].id in self.returns or node.targets[0].id in self.reduction_source.keys(): + self.stack.append(node.targets[0].id) + else: + self.stack.append(pre_assign_node.predecessor(self.epilogue_tree.identifier)) + self.epilogue_tree.remove_node(node.targets[0].id) + + # get child tag + self.visit(node.value) + self.stack.pop() + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + func_type = node.func.id + elif isinstance(node.func, ast.Attribute): + func_type = node.func.value.id + else: + raise TypeError + if func_type == "reduction_op": + self.visit(node.args[0]) + else: + arg_list = [] + for idx, arg in enumerate(node.args): + if idx == 0: continue + if isinstance(arg, ast.Constant): + arg_list.append(arg.value) + elif isinstance(arg, ast.Name): + arg_list.append(arg.id) + else: + raise TypeError + + unary_node = UnaryNode(self.element_accumulator, self.element_compute, self.elements_per_access, node, arg_list) + self.epilogue_tree.create_node(unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node) + self.stack.append(unary_node.id) + self.visit(node.args[0]) + self.stack.pop() + + def visit_BinOp(self, node): + binop = BinOpNode(self.element_accumulator, self.element_compute, + self.elements_per_access, node) + self.epilogue_tree.create_node(binop.tag, binop.id, data=binop, parent=self.stack[-1]) + self.stack.append(binop.id) + self.visit(node.left) + self.visit(node.right) + self.stack.pop() + + def visit_Return(self, node): + self.stack.append("return") + self.visit(node.value) + self.stack.pop() + + # # A function definition + def visit_FunctionDef(self, node: ast.FunctionDef): + # visit args + for arg in node.args.args: + if arg.arg == "self": continue + if isinstance(arg.annotation, ast.Constant): + self.input_args[arg.arg] = [arg.annotation.value, ] + # visit the assign in the reverse order + for idx in range(len(node.body)): + self.visit(node.body[-1-idx]) + + # + # Tree optimization pass + # + + # pass 1: lower Binary to Unary + def pass_binary_2_unary(self, tree, nid): + node = tree.get_node(nid) + if isinstance(node.data, BinOpNode): + lhs_node = tree.get_node(node.successors(tree.identifier)[0]) + left_type = lhs_node.data.type + rhs_node = tree.get_node(node.successors(tree.identifier)[1]) + right_type = rhs_node.data.type + + if left_type == "scalar" and right_type == "tensor": + node.data = UnaryNode( + self.element_accumulator, self.element_compute, + self.elements_per_access, + node.data, [lhs_node.data.id,]) + node.tag = node.data.tag + tree.remove_node(lhs_node.data.id) + self.pass_binary_2_unary(tree, rhs_node.data.id) + + elif left_type == "tensor" and right_type == "scalar": + node.data = UnaryNode( + self.element_accumulator, self.element_compute, + self.elements_per_access, + node.data, [rhs_node.id,]) + node.tag = node.data.tag + tree.remove_node(rhs_node.data.id) + self.pass_binary_2_unary(tree, lhs_node.data.id) + + else: + self.pass_binary_2_unary(tree, lhs_node.data.id) + self.pass_binary_2_unary(tree, rhs_node.data.id) + else: + for child in node.successors(tree.identifier): + self.pass_binary_2_unary(tree, child) + + # pass 2: inject reduction nodes + def pass_inject_reduction(self, tree, nid): + node = tree.get_node(nid) + if isinstance(node.data, TensorOutputNode): + if node.data.id in self.reduction_source.keys(): + direction = self.reduction_source[node.data.id][0] + target = self.reduction_source[node.data.id][-1] + if direction == 'row': + reduction_node = RowReductionNode( + self.element_accumulator, self.element_output, + self.element_accumulator, target, self.tile_description.threadblock_shape[1]) + elif direction == "column": + reduction_node = ColumnReductionNode( + self.element_accumulator, self.element_output, + self.element_accumulator, target, self.tile_description.threadblock_shape[0]) + else: + raise ValueError(direction) + child_nid = node.successors(tree.identifier)[0] + # if this output node is injected only for reduction + if node.data.id not in self.returns: + # get reduction config from disc + node.data = reduction_node + node.tag = reduction_node.tag + self.pass_inject_reduction(tree, child_nid) + # if this output node is also a tensor output, inject reduction as its children + else: + # get child node + tree.create_node(reduction_node.tag, reduction_node.id, data=reduction_node, parent=node.data.id) + tree.move_node(child_nid, reduction_node.id) + child = tree.get_node(child_nid) + for grand_child in child.successors(tree.identifier): + self.pass_inject_reduction(tree, grand_child) + else: + for child in node.successors(tree.identifier): + self.pass_inject_reduction(tree, child) + else: + for child in node.successors(tree.identifier): + self.pass_inject_reduction(tree, child) + + def pass_inject_epilogue_op(self, tree, nid): + node = tree.get_node(nid) + visitors = [] + for child in node.successors(tree.identifier): + visitors.append(self.pass_inject_epilogue_op(tree, child)) + + node.data.get_epilogue_node(visitors) + return node.data.epilogue_node + + def get_arguments(self, tree, nid, kwargs): + node = tree.get_node(nid) + visitor_args = [] + for child in node.successors(tree.identifier): + visitor_args.append(self.get_arguments(tree, child, kwargs)) + + node.data.get_argument(visitor_args, kwargs) + return node.data.argument + +class EpilogueVisitTree: + KernelTemplate = """ +${visitor} + +using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>; +""" + def __init__(self, elementwise_functor, tile_description, + element_accumulator, elements_per_access, + element_compute, element_output) -> None: + # + # data types + self.tile_description = tile_description + self.element_accumulator = element_accumulator + self.elements_per_access = elements_per_access + self.element_compute = element_compute + self.element_output = element_output + # TODO: deprecate this + self.elementwise_functor = elementwise_functor + pass + + def initialize(self): + function = EpilogueAST(self, self.tile_description, + self.element_accumulator, self.elements_per_access, + self.element_compute, self.element_output) + # + tree = function.epilogue_tree + self.tree = tree + # self.tree.show() # for debug + function.pass_binary_2_unary(self.tree, self.tree.root) + # self.tree.show() # for debug + function.pass_inject_reduction(self.tree, self.tree.root) + # self.tree.show() # for debug + function.pass_inject_epilogue_op(self.tree,self.tree.root) + + visitor = self.tree.get_node(self.tree.root).data.epilogue_node + self.visitor = visitor + + class _Argument(ctypes.Structure): + _fields_ = [ + ("visitor_arg", visitor.argument_type) + ] + def __init__(self, **kwargs) -> None: + # process input args + _kwargs = {} + for input_key in function.input_args.keys(): + if input_key == "accum": + continue + if function.input_args[input_key][0] == "scalar": + # _kwargs[input_key] = kwargs[input_key] + continue + # tensor input + else: + setattr(self, "buffer_tensor_" + input_key, NumpyFrontend.argument(kwargs[input_key], False)) + setattr(self, input_key + "_ptr", int(getattr(self, "buffer_tensor_" + input_key).ptr)) + _kwargs[input_key+"_ptr"] = getattr(self, input_key + "_ptr") + # process the return args + for ret in function.returns: + setattr(self, "buffer_tensor_" + ret, NumpyFrontend.argument(kwargs[ret], True)) + setattr(self, ret + "_ptr", int(getattr(self, "buffer_tensor_" + ret).ptr)) + _kwargs[ret+"_ptr"] = getattr(self, ret + "_ptr") + setattr(self, "host_tensor_" + ret, kwargs[ret]) + + _kwargs.update(kwargs) + function.get_arguments(tree, tree.root, _kwargs) + self.visitor_arg = tree.get_node(tree.root).data.argument + + def sync(self, stream_sync=True): + if stream_sync: + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + for ret in function.returns: + err, = cuda.cuMemcpyDtoH( + getattr(self, "host_tensor_" + ret), cuda.CUdeviceptr(getattr(self, ret + "_ptr")), + getattr(self, "host_tensor_" + ret).size * getattr(self, "host_tensor_" + ret).itemsize + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + pass + + self.epilogue_type = _Argument + + def emit(self, operation): + values = { + 'visitor': self.visitor.emit(operation), + 'operation_name': operation.procedural_name(), + 'visitor_name': self.visitor.instance_name + } + return SubstituteTemplate(self.KernelTemplate, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py index a5f7217a..9c5db3a1 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py @@ -58,6 +58,13 @@ class ReductionArguments: destination: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', source: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', **kwargs) -> None: + # tensor_C can be interpreted as the bias with bias=True in keyword args + if "bias" in kwargs.keys(): + self.bias = kwargs["bias"] + else: + # by default, tensor_C is not bias + self.bias = False + self.operation = operation #: pointer to the workspace self.ptr_workspace = workspace @@ -89,11 +96,9 @@ class ReductionArguments: problem_size[1] * DataTypeSize[operation.C.element] // 8 if "output_op" in kwargs.keys(): - self.alpha = kwargs["output_op"].alpha - self.beta = kwargs["output_op"].beta + self.output_op = kwargs['output_op'] else: - self.alpha = 1.0 - self.beta = 0.0 + self.output_op = self.operation.epilogue_type(1.0, 0.0) # get arguments self.get_arguments() @@ -109,31 +114,25 @@ class ReductionArguments: ref_workspace = ReductionArguments.get_tensor_ref( extent=[self.problem_size.row, self.problem_size.column], device_ptr=self.ptr_workspace, layout=cutlass.RowMajor) - - ref_source = ReductionArguments.get_tensor_ref( - extent=[self.problem_size.row, self.problem_size.column], - device_ptr=self.ptr_source, layout=cutlass.RowMajor) + if self.bias: + ref_source = ReductionArguments.get_tensor_ref( + extent=[0, 0], + device_ptr=self.ptr_source, layout=cutlass.RowMajor) + else: + ref_source = ReductionArguments.get_tensor_ref( + extent=[self.problem_size.row, self.problem_size.column], + device_ptr=self.ptr_source, layout=cutlass.RowMajor) ref_destination = ReductionArguments.get_tensor_ref( extent=[self.problem_size.row, self.problem_size.column], device_ptr=self.ptr_destination, layout=cutlass.RowMajor) - argument_type, epilogue_type = get_reduction_params( - self.operation.element_compute) - if self.operation.element_compute == cutlass.float16: - self.alpha = cutlass.float16(self.alpha).storage - self.beta = cutlass.float16(self.beta).storage - elif self.operation.element_compute == cutlass.int32: - self.alpha = int(self.alpha) - self.beta = int(self.beta) - - output_op = epilogue_type(self.alpha, self.beta, 0, 0) - self.c_arguments = argument_type( + self.c_arguments = self.operation.argument_type( self.problem_size, self.partitions, self.partition_stride, ref_workspace, ref_destination, ref_source, - output_op + self.output_op ) params_ = self.operation.rt_module.get_args( @@ -210,8 +209,8 @@ extern "C" { self.emitter = EmitReductionInstance('_type') self.elements_per_access = self.operation.count - self.argtype = [ctypes.POINTER( - get_reduction_params(operation.element_compute)[0])] + self.argument_type, self.epilogue_type = get_reduction_params(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type)] def emit(self): return self.emitter.emit(self.operation) @@ -247,14 +246,14 @@ class ReductionOperation: def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription, element_accumulator, element_workspace=None, - element_compute=None, epilogue_functor: EpilogueFunctor = EpilogueFunctor.LinearCombination, + element_compute=None, epilogue_functor=None, count: int = 1, partitions_per_stage: int = 4) -> None: """ Constructor """ self.shape = shape #: epilogue functor (default: LinearCombination) - self.epilogue_functor: EpilogueFunctor = epilogue_functor + self.epilogue_functor = epilogue_functor #: datatype of accumulator self.element_accumulator = element_accumulator @@ -285,6 +284,8 @@ class ReductionOperation: self.partitions_per_stage: int = partitions_per_stage self.rt_module: ReductionRT = ReductionRT(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type # def extended_name(self): @@ -363,12 +364,7 @@ class EmitReductionInstance: using ${operation_name}_base = typename cutlass::reduction::kernel::ReduceSplitK< cutlass::MatrixShape<${shape_row}, ${shape_column}>, - ${epilogue_functor}< - ${element_output}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_compute} - >, + ${epilogue_functor}, cutlass::reduction::thread::ReduceAdd< ${element_accumulator}, ${element_output}, @@ -389,7 +385,7 @@ struct ${operation_name}${operation_suffix}: 'operation_suffix': self.operation_suffix, 'shape_row': str(operation.shape.row()), 'shape_column': str(operation.shape.column()), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'epilogue_functor': operation.epilogue_functor.emit(), 'element_output': DataTypeTag[operation.element_output], 'epilogue_vector_length': str(epilogue_vector_length), 'element_accumulator': DataTypeTag[operation.element_accumulator], diff --git a/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py b/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py index 4d2b89e6..bd863ffd 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py @@ -68,4 +68,3 @@ class TensorRef: # the dtype(0) is used to overload between different data types # with the same layout self.tensor_ref = cutlass.get_tensor_ref(int(ptr), dtype(0), layout) - diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py index 5f9cd3c1..6f733b7b 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py @@ -124,7 +124,7 @@ class Conv2dLauncher: self.reduction_operation = ReductionOperation( shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment), C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator, - element_compute=operation.element_epilogue, + element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor, count=operation.C.alignment ) @@ -183,7 +183,7 @@ class Conv2dLauncher: # Get the host reference function # - self.element_compute = operation.element_epilogue + self.element_compute = operation.epilogue_functor.element_epilogue self.host_conv2d = cutlass.test.conv.host.conv2d @@ -441,7 +441,7 @@ class Conv2dLauncher: arguments = Conv2dArguments( operation=self.operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op = LinearCombinationFunctorArguments(alpha, beta), + output_op = self.operation.epilogue_type(alpha, beta), split_k_slices=problem_size.split_k_slices, split_k_mode=split_k_mode ) @@ -454,7 +454,7 @@ class Conv2dLauncher: workspace=arguments.ptr_D, destination=tensor_D, source=tensor_C, - output_op = LinearCombinationFunctorArguments(alpha, beta) + output_op = self.reduction_operation.epilogue_type(alpha, beta) ) self.operation.run(arguments) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py index 467d965f..a58cd46b 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py @@ -68,7 +68,7 @@ class TestbedGrouped: self.scope_min = -8 #: compute type - self.compute_type = operation.element_epilogue + self.compute_type = operation.epilogue_functor.element_epilogue self.accumulator_type = operation.tile_description.math_instruction.element_accumulator @@ -176,7 +176,7 @@ class TestbedGrouped: arguments = GemmGroupedArguments( operation=self.operation, problem_sizes=problem_sizes, A=tensor_As, B=tensor_Bs, C=tensor_Cs, D=tensor_Ds, - output_op=LinearCombinationFunctorArguments(alpha, beta) + output_op=self.operation.epilogue_type(alpha, beta) ) self.operation.run(arguments) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py index 344f20ec..f5c0d78e 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py @@ -143,7 +143,7 @@ class GemmUniversalLauncher: self.reduction_operation: ReductionOperation = ReductionOperation( shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment), C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator, - element_compute=operation.element_epilogue, + element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor, count=operation.C.alignment ) @@ -200,7 +200,7 @@ class GemmUniversalLauncher: self.interleaved = interleaved #: compute type - self.compute_type = operation.element_epilogue + self.compute_type = operation.epilogue_functor.element_epilogue self.accumulator_type = operation.tile_description.math_instruction.element_accumulator def print_problem_size(self, p, mode, batch_count): @@ -391,7 +391,7 @@ class GemmUniversalLauncher: arguments = GemmArguments( operation=self.operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=LinearCombinationFunctorArguments(alpha, beta), + output_op=self.operation.epilogue_type(alpha, beta), gemm_mode=mode, split_k_slices=batch_count ) @@ -403,7 +403,7 @@ class GemmUniversalLauncher: workspace=arguments.ptr_D, destination=tensor_D, source=tensor_C, - output_op=LinearCombinationFunctorArguments(alpha, beta) + output_op=self.reduction_operation.epilogue_type(alpha, beta) ) self.operation.run(arguments) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/type.py b/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py similarity index 100% rename from tools/library/scripts/pycutlass/src/pycutlass/type.py rename to tools/library/scripts/pycutlass/src/pycutlass/type_hint.py diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py b/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py index 809fcf99..edb483a9 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py @@ -34,6 +34,7 @@ import numpy as np import cutlass from pycutlass.library import TensorDescription from typing import Union +from bfloat16 import bfloat16 try: import torch torch_available = True @@ -46,7 +47,7 @@ class ReferenceModule: self.layout_B = B.layout self.layout_C = C.layout - def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0): + def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0, bias=False, batch=1): """ Compute the reference result on CPU Args: @@ -57,27 +58,38 @@ class ReferenceModule: M, N, K = problem_size.m(), problem_size.n(), problem_size.k() if isinstance(A, np.ndarray): if self.layout_A == cutlass.RowMajor: - A_row = np.reshape(A, newshape=(M, K)) + A_row = np.reshape(A, newshape=(batch, M, K)) else: - A_col = np.reshape(A, newshape=(K, M)) - A_row = np.transpose(A_col, axes=(1, 0)) + A_col = np.reshape(A, newshape=(batch, K, M)) + A_row = np.transpose(A_col, axes=(0, 2, 1)) if self.layout_B == cutlass.RowMajor: - B_row = np.reshape(B, newshape=(K, N)) + B_row = np.reshape(B, newshape=(batch, K, N)) else: - B_col = np.reshape(B, newshape=(N, K)) - B_row = np.transpose(B_col, axes=(1, 0)) + B_col = np.reshape(B, newshape=(batch, N, K)) + B_row = np.transpose(B_col, axes=(0, 2, 1)) if self.layout_C == cutlass.RowMajor: - C_row = np.reshape(C, newshape=(M, N)) + if bias: + C_row = np.reshape(C, newshape=(batch, 1, N)) + else: + C_row = np.reshape(C, newshape=(batch, M, N)) else: - C_col = np.reshape(C, newshape=(N, M)) - C_row = np.transpose(C_col, axes=(1, 0)) + if bias: + C_row = np.reshape(C, newshape=(batch, M, 1)) + else: + C_col = np.reshape(C, newshape=(batch, N, M)) + C_row = np.transpose(C_col, axes=(0, 2, 1)) - out_row = np.matmul(A_row, B_row) * alpha + C_row * beta + if A_row.dtype == bfloat16: + # numpy's einsum doesn't support bfloat16 + out_row = np.einsum("bik,bkj->bij", A_row.astype(np.float32), B_row.astype(np.float32)) * alpha + C_row * beta + out_row = out_row.astype(C_row.dtype) + else: + out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta if self.layout_C == cutlass.ColumnMajor: - out = np.transpose(out_row, axes=(1, 0)) + out = np.transpose(out_row, axes=(0, 2, 1)) else: out = out_row @@ -128,7 +140,7 @@ if torch_available: def run(self, A: Union[np.ndarray, torch.Tensor], B: Union[np.ndarray, torch.Tensor], - C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0) -> np.ndarray: + C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0, bias=False) -> np.ndarray: """ Compute the reference result on CPU """ @@ -184,7 +196,10 @@ if torch_available: B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) if self.layout_C == cutlass.TensorNHWC: - C_nhwc = C.view((k, r, s, c)) + if bias: + C_nhwc = C.view((1, 1, 1, c)) + else: + C_nhwc = C.view((k, r, s, c)) C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) elif self.kind == cutlass.conv.Operator.dgrad: if self.layout_A == cutlass.TensorNHWC: @@ -196,7 +211,10 @@ if torch_available: B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) if self.layout_C == cutlass.TensorNHWC: - C_nhwc = C.view((n, h, w, c)) + if bias: + C_nhwc = C.view((1, 1, 1, c)) + else: + C_nhwc = C.view((n, h, w, c)) C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) else: if self.layout_A == cutlass.TensorNHWC: @@ -208,7 +226,10 @@ if torch_available: B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) if self.layout_C == cutlass.TensorNHWC: - C_nhwc = C.view((n, p, q, k)) + if bias: + C_nhwc = C.view((1, 1, 1, k)) + else: + C_nhwc = C.view((n, p, q, k)) C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) if self.kind == cutlass.conv.Operator.fprop: diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index fd3309bc..24d70376 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -106,15 +112,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -156,15 +165,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index bb8eff46..e9ce7460 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -105,15 +111,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -143,15 +152,18 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 50cb2598..0351bc5c 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -30,15 +30,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) tile_description = TileDescription( threadblock_shape=[128, 128, 8], stages=4, warp_count=[4, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -68,15 +71,18 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) tile_description = TileDescription( threadblock_shape=[128, 128, 8], stages=4, warp_count=[2, 4, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index ea3ba2b0..21061729 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index 7a8e8ba3..fb4f2434 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -97,15 +97,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -135,15 +138,18 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=2, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index 43c38c81..cf46d0b5 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -79,15 +79,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -117,15 +120,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -155,15 +161,18 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 36640794..e2a4ccc3 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -105,15 +111,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -173,15 +182,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -241,15 +253,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float16) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index a48cc22c..101aaa1d 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 05d77052..412e199a 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -30,15 +30,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) tile_description = TileDescription( threadblock_shape=[128, 128, 8], stages=4, warp_count=[4, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle2 ) @@ -68,15 +71,18 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) tile_description = TileDescription( threadblock_shape=[128, 128, 8], stages=4, warp_count=[2, 4, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 4e1570c5..4585d66c 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 4c69340f..4ce627c4 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes tile_description = TileDescription( threadblock_shape=[128, 256, 64], stages=3, warp_count=[2, 4, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 ) @@ -105,15 +111,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 ) @@ -155,15 +164,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 ) @@ -193,15 +205,18 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 370abcc5..533e66c4 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -29,15 +29,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst + ) + + epilogue_functor = LinearCombination( + C.element, C.alignment, math_inst.element_accumulator, + cutlass.float16 ) operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +71,19 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst + ) + + epilogue_functor = LinearCombination( + C.element, C.alignment, math_inst.element_accumulator, + cutlass.float16 ) operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 6e9ed6c7..2399a1e1 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -105,15 +111,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[64, 256, 32], stages=3, warp_count=[1, 4, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -143,15 +152,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -193,15 +205,18 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index f92fdfb1..c932d808 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -30,15 +30,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) tile_description = TileDescription( threadblock_shape=[128, 128, 8], stages=4, warp_count=[2, 4, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -68,15 +71,18 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) tile_description = TileDescription( threadblock_shape=[128, 128, 8], stages=4, warp_count=[2, 4, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index e5520715..a69274fc 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -29,15 +29,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te tile_description = TileDescription( threadblock_shape=[128, 128, 16], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -67,15 +70,18 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, - min_compute=80, max_compute=80 + math_instruction=math_inst ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + operation = Conv2dOperation( conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=EpilogueFunctor.LinearCombination, + stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) diff --git a/tools/library/scripts/pycutlass/test/example/run_all_example.sh b/tools/library/scripts/pycutlass/test/example/run_all_example.sh new file mode 100644 index 00000000..8f68bc30 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/example/run_all_example.sh @@ -0,0 +1,80 @@ +pushd $CUTLASS_PATH/examples/40_cutlass_py/ + +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4 + +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 + +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3 + +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5 + +python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 + +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device + +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host + +python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device + +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device + +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0 + +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0 + +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 + +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu + +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh + +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid + +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu + +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish + +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 + +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 + +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3 +popd diff --git a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py index 6e2ee256..59b5549a 100644 --- a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +++ b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py @@ -49,7 +49,7 @@ class Test_Frontend(unittest.TestCase): tile_description = TileDescription( [128, 128, 8], 4, [2, 4, 1], - math_inst, 80, 80 + math_inst ) A = TensorDescription( @@ -64,10 +64,14 @@ class Test_Frontend(unittest.TestCase): cutlass.float32, cutlass.RowMajor, 1 ) + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) + self.operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=cutlass.float32, - epilogue_functor=EpilogueFunctor.LinearCombination, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 ) @@ -89,7 +93,7 @@ class Test_Frontend(unittest.TestCase): arguments = GemmArguments( operation=self.operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=LinearCombinationFunctorArguments(alpha, beta), + output_op=self.operation.epilogue_type(alpha, beta), gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 ) @@ -119,7 +123,7 @@ class Test_Frontend(unittest.TestCase): arguments = GemmArguments( operation=self.operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=LinearCombinationFunctorArguments(alpha, beta), + output_op=self.operation.epilogue_type(alpha, beta), gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 ) diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py index 59bf9bb3..b4505c65 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py @@ -17,7 +17,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 128, 64], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -33,15 +33,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase): alignment=4 ) - element_epilogue = cutlass.float32 - - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -58,7 +58,7 @@ class GemmBF16TensorOpSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 128, 32], stages=6, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -74,15 +74,15 @@ class GemmBF16TensorOpSm80(unittest.TestCase): alignment=8 ) - element_epilogue = cutlass.float32 - - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, cutlass.float32) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py index 284ac928..5bef482d 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py @@ -18,7 +18,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -36,13 +36,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.BatchedIdentitySwizzle operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, direct_store=True ) @@ -60,7 +62,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 64], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -78,13 +80,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -101,7 +105,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 256, 64], stages=3, warp_count=[2, 4, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -119,13 +123,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -142,7 +148,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[256, 128, 64], stages=3, warp_count=[4, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -160,13 +166,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -183,7 +191,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 64, 64], stages=3, warp_count=[2, 1, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -201,13 +209,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float16 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -224,7 +234,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 64, 32], stages=10, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -242,13 +252,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float16 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -265,7 +277,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[256, 128, 64], stages=3, warp_count=[4, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -283,13 +295,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -306,7 +320,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 64, 64], stages=3, warp_count=[2, 1, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -324,13 +338,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -347,7 +363,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 256, 64], stages=3, warp_count=[2, 4, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -365,13 +381,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -388,7 +406,7 @@ class GemmF16Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 256, 64], stages=3, warp_count=[2, 4, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -406,13 +424,15 @@ class GemmF16Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py index 20c39be3..960bdd39 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py @@ -19,7 +19,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -37,13 +37,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -61,7 +63,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -79,13 +81,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -102,7 +106,7 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 64, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -120,13 +124,15 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py index 04591ab2..1e1778a9 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py @@ -17,7 +17,7 @@ class GemmF64TensorOpSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[32, 32, 16], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) # alignment 1 restricted for double @@ -36,13 +36,15 @@ class GemmF64TensorOpSm80(unittest.TestCase): element_epilogue = cutlass.float64 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -59,7 +61,7 @@ class GemmF64TensorOpSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 64, 16], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) # alignment 1 restricted for double @@ -78,13 +80,15 @@ class GemmF64TensorOpSm80(unittest.TestCase): element_epilogue = cutlass.float64 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py index 6024f83c..451a91ac 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py @@ -18,7 +18,7 @@ class GemmGroupedSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -37,14 +37,15 @@ class GemmGroupedSm80(unittest.TestCase): ) element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( - tile_description.minimum_compute_capability, + 80, tile_description, A, B, C, - element_epilogue, epilogue_functor, swizzling_functor, precompute_mode=precompute_mode ) @@ -64,7 +65,7 @@ class GemmGroupedSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 64, 16], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -83,14 +84,15 @@ class GemmGroupedSm80(unittest.TestCase): ) element_epilogue = cutlass.float64 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( - tile_description.minimum_compute_capability, + 80, tile_description, A, B, C, - element_epilogue, epilogue_functor, swizzling_functor, precompute_mode=precompute_mode ) @@ -110,7 +112,7 @@ class GemmGroupedSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 64, 8], stages=4, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -129,14 +131,15 @@ class GemmGroupedSm80(unittest.TestCase): ) element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( - tile_description.minimum_compute_capability, + 80, tile_description, A, B, C, - element_epilogue, epilogue_functor, swizzling_functor, precompute_mode=precompute_mode ) @@ -156,7 +159,7 @@ class GemmGroupedSm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 32], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -175,14 +178,15 @@ class GemmGroupedSm80(unittest.TestCase): ) element_epilogue = cutlass.float32 - epilogue_functor = EpilogueFunctor.LinearCombination + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) swizzling_functor = cutlass.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( - tile_description.minimum_compute_capability, + 80, tile_description, A, B, C, - element_epilogue, epilogue_functor, swizzling_functor, precompute_mode=precompute_mode ) diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py index b41b78fd..0a76198a 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py @@ -1,5 +1,6 @@ import pycutlass from pycutlass import * +from pycutlass.epilogue import LinearCombinationClamp from pycutlass.test import * import unittest @@ -17,7 +18,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[64, 64, 64], stages=6, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -33,15 +34,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): alignment=8 ) - element_epilogue = cutlass.float32 - - epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp + epilogue_functor = FastLinearCombinationClamp( + C.element, C.alignment + ) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -58,7 +59,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 128], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -74,15 +75,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): alignment=16 ) - element_epilogue = cutlass.float32 - - epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp + epilogue_functor = FastLinearCombinationClamp( + C.element, C.alignment + ) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -99,7 +100,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 128], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -115,15 +116,15 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): alignment=16 ) - element_epilogue = cutlass.float32 - - epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp + epilogue_functor = FastLinearCombinationClamp( + C.element, C.alignment + ) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -140,7 +141,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 128], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -158,13 +159,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): element_epilogue = cutlass.int32 - epilogue_functor = EpilogueFunctor.LinearCombinationClamp + epilogue_functor = LinearCombinationClamp( + C.element, C.alignment, math_inst.element_accumulator, + element_epilogue + ) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) @@ -181,7 +185,7 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): tile_description = TileDescription( threadblock_shape=[128, 128, 128], stages=3, warp_count=[2, 2, 1], - math_instruction=math_inst, min_compute=80, max_compute=80 + math_instruction=math_inst ) A = TensorDescription( @@ -199,13 +203,16 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): element_epilogue = cutlass.int32 - epilogue_functor = EpilogueFunctor.LinearCombinationClamp + epilogue_functor = LinearCombinationClamp( + C.element, C.alignment, math_inst.element_accumulator, + element_epilogue + ) swizzling_functor = cutlass.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) diff --git a/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt b/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt index c5e51d9f..c026860a 100644 --- a/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt +++ b/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt @@ -348,3 +348,16 @@ conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1 conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 164942943 4259285988 984016853 888753301 conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2823094147 1681845497 4242738907 3244428635 conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 s8nhwc_s8nhwc_inhwc_i_i 4060010502 2881035321 3927119619 3311661122 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 1030377090 3211227145 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1479940693 2379046159 2482639965 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 1871463331 2718290800 1797658305 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 664160900 3954982568 985899371 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1924855848 1728786974 3821277575 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 1764715518 3998637379 2782670608 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 666906244 2107859856 831363691 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1575210381 2486552517 3268706408 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 2316839359 1729888024 2308314800 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 544154978 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 3191247524 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 554790212 956712535 1281779197 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3184127693 835105643 4011933753 3207244654 diff --git a/tools/library/scripts/pycutlass/test/unit/test_sm80.py b/tools/library/scripts/pycutlass/test/unit/test_sm80.py index bedb3a3a..0dd685de 100644 --- a/tools/library/scripts/pycutlass/test/unit/test_sm80.py +++ b/tools/library/scripts/pycutlass/test/unit/test_sm80.py @@ -42,8 +42,7 @@ import unittest # def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False, - epilogue_functor = EpilogueFunctor.LinearCombination, - swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): """ Test GEMM Operation based on configuration """ @@ -68,7 +67,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe tile_description = TileDescription( tiling[0], tiling[1], tiling[2], - math_inst, arch, arch + math_inst ) A = TensorDescription( @@ -84,11 +83,15 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe ) element_epilogue = data_type[3] + if epilogue_functor is None: + epilogue_functor = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, element_epilogue) if gemm_kind == GemmKind.Universal: operation = GemmOperationUniversal( arch=arch, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, + A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) if A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]: @@ -99,7 +102,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe elif gemm_kind == GemmKind.Grouped: operation = GemmOperationGrouped( arch, tile_description, A, B, C, - element_epilogue, epilogue_functor, swizzling_functor, + epilogue_functor, swizzling_functor, precompute_mode=kwargs["precompute_mode"] ) testbed = TestbedGrouped(operation=operation) @@ -110,7 +113,7 @@ def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixe def TestConv2dOperator(math_inst, alignment, tiling, arch, stride_supports=[StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided], - epilogue_functor=EpilogueFunctor.LinearCombination, + epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs): """ Test Conv2d Operation based on configurations @@ -167,20 +170,24 @@ def TestConv2dOperator(math_inst, alignment, tiling, arch, tile_description = TileDescription( threadblock_shape=tiling[0], stages=tiling[1], warp_count=tiling[2], - math_instruction=math_inst, - min_compute=arch, max_compute=arch + math_instruction=math_inst ) if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided: swizzling_functor = cutlass.StridedDgradIdentitySwizzle1 else: swizzling_functor = default_swizzling_functor + + if epilogue_functor is None: + epilogue_functor_ = LinearCombination( + C.element, C.alignment, + math_inst.element_accumulator, data_type[3]) operation = Conv2dOperation( conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, arch=arch, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=data_type[3], stride_support=stride_support, - epilogue_functor=epilogue_functor, + stride_support=stride_support, + epilogue_functor=epilogue_functor_, swizzling_functor=swizzling_functor ) @@ -369,7 +376,11 @@ class Test_SM80(unittest.TestCase): tiling = ([256, 64, 64], 4, [4, 1, 1]) data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32] - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=EpilogueFunctor.FastLinearCombinationClamp)) + epilogue_functor = FastLinearCombinationClamp( + data_type_mixed[2], alignment_mixed[2] + ) + + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=epilogue_functor)) stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32] results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True) @@ -378,59 +389,59 @@ class Test_SM80(unittest.TestCase): def SM80_SparseTensorOp_16832(self): pass - def test_SM80_PlanarComplexTensorOp_16816(self): + def SM80_PlanarComplexTensorOp_16816(self): pass - def test_SM80_SparseTensorOp_16816_fast_math(self): + def SM80_SparseTensorOp_16816_fast_math(self): pass - def test_SM80_TensorOp_1688_complex(self): + def SM80_TensorOp_1688_complex(self): pass - def test_SM80_TensorOp_1688_fast_fp32_math_complex(self): + def SM80_TensorOp_1688_fast_fp32_math_complex(self): pass - def test_SM80_TensorOp_1688_rank_k(self): + def SM80_TensorOp_1688_rank_k(self): pass - def test_SM80_TensorOp_1688_rank_k_complex(self): + def SM80_TensorOp_1688_rank_k_complex(self): pass - def test_SM80_TensorOp_1688_trmm(self): + def SM80_TensorOp_1688_trmm(self): pass - def test_SM80_TensorOp_1688_trmm_complex(self): + def SM80_TensorOp_1688_trmm_complex(self): pass - def test_SM80_TensorOp_1688_symm(self): + def SM80_TensorOp_1688_symm(self): pass - def test_SM80_TensorOp_1688_symm_complex(self): + def SM80_TensorOp_1688_symm_complex(self): pass - def test_SM80_TensorOp_884_complex(self): + def SM80_TensorOp_884_complex(self): pass - def test_SM80_TensorOp_884_complex_gaussian(self): + def SM80_TensorOp_884_complex_gaussian(self): pass - def test_SM80_TensorOp_884_rank_k(self): + def SM80_TensorOp_884_rank_k(self): pass - def test_SM80_TensorOp_884_rank_k_complex(self): + def SM80_TensorOp_884_rank_k_complex(self): pass - def test_SM80_TensorOp_884_rank_k_complex_gaussian(self): + def SM80_TensorOp_884_rank_k_complex_gaussian(self): pass - def test_SM80_TensorOp_884_trmm(self): + def SM80_TensorOp_884_trmm(self): pass - def test_SM80_TensorOp_884_trmm_complex(self): + def SM80_TensorOp_884_trmm_complex(self): pass - def test_SM80_TensorOp_884_trmm_complex_gaussian(self): + def SM80_TensorOp_884_trmm_complex_gaussian(self): pass - def test_SM80_TensorOp_884_symm(self): + def SM80_TensorOp_884_symm(self): pass - def test_SM80_TensorOp_884_symm_complex(self): + def SM80_TensorOp_884_symm_complex(self): pass - def test_SM80_TensorOp_884_symm_complex_gaussian(self): + def SM80_TensorOp_884_symm_complex_gaussian(self): pass - def test_SM80_SparseTensorOp_16864_TN(self): + def SM80_SparseTensorOp_16864_TN(self): pass - def test_SM80_TensorOp_16864_TN(self): + def SM80_TensorOp_16864_TN(self): pass - def test_SM80_SparseTensorOp_168128_TN(self): + def SM80_SparseTensorOp_168128_TN(self): pass - def test_SM80_TensorOp_16864_Interleaved(self): + def SM80_TensorOp_16864_Interleaved(self): pass - def test_SM80_TensorOp_168256(self): + def SM80_TensorOp_168256(self): pass - def test_SM80_Simt_complex(self): + def SM80_Simt_complex(self): pass diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index 4636fd36..27636661 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -1195,13 +1195,13 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { break; case NumericTypeID::kBF16: { - float tmp = *reinterpret_cast(bytes.data());; + float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; case NumericTypeID::kTF32: { - float tmp = *reinterpret_cast(bytes.data());; + float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; diff --git a/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h index 7a387eaa..762486fe 100644 --- a/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +++ b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -183,7 +183,7 @@ __global__ void GemmPlanarComplex( ComplexC d_ij; d_ij.real() = convert_op(result.real()); - d_ij.imag() = convert_op(result.imag());; + d_ij.imag() = convert_op(result.imag()); tensor_d.at(coord) = d_ij; } diff --git a/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h index c1e5c094..a7477f17 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +++ b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -172,7 +172,7 @@ void GemmPlanarComplex( complex result = alpha * acc + beta * src; d_ij.real() = convert_op(result.real()); - d_ij.imag() = convert_op(result.imag());; + d_ij.imag() = convert_op(result.imag()); tensor_d.at(coord) = d_ij; }